资源简介
代码用于猫-非猫图片的二分类问题(附件内给出h5格式的数据集),基于Pytorch神经网络工具包,采用比较经典的逻辑回归(Logistic Regression)算法。
代码片段和文件信息
import numpy as np
import torch as t
import h5py
import os
os.environ[‘KMP_DUPLICATE_LIB_OK‘] = ‘True‘
class LogisticRegression(t.nn.Module):
def __init__(self):
super(LogisticRegression self).__init__()
self.lg = t.nn.Sequential(
t.nn.Linear(12288 1) t.nn.Sigmoid()
)
def forward(self x):
output = self.lg(x)
return output
lg_model = LogisticRegression()
cost_func = t.nn.BCELoss()
optimizer = t.optim.SGD(lg_model.parameters() lr=0.001 momentum=0.9)
epochs = 500
train_dataset = h5py.File(‘train_catvnoncat.h5‘ “r“)
train_set_x_orig = t.tensor(train_dataset[“train_set_x“][:])
train_set_y = t.tensor(np.array(train_dataset[“train_set_y“][:]))/1.0
test_dataset = h5py.File(‘test_catvnoncat.h5‘ “r“)
test_set_x_orig = t.tensor(np.array(test_dataset[“test_set_x“][:]))
test_set_y = t.tensor(np.array(test_dataset[“test_set_y“][:]))/1.0
num_train = train_set_x_orig.shape[0]
num_test = test_set_x_orig.shape[0]
train_set_x = train_set_x_orig.reshape(num_train -1)/255.0
test_set_x = test_set_x_orig.reshape(num_test -1)/255.0
train_set_y = train_set_y.reshape(num_train 1)
test_set_y = test_set_y.reshape(num_test 1)
train_loss = 0
for epoch in range(epochs):
lg_model.train()
y_out = lg_model(train_set_x)
train_loss = cost_func(y_out train_set_y)
optimizer.zero_grad()
train_loss.backward()
optimizer.step()
with t.no_grad():
y_pred = y_out.ge(0.5).float()
num_correct = (y_pred == train_set_y).sum().item()
acc_rate = num_correct * 100.0 / num_train
print(“世代数: %d 训练集正确率: %.1f%%“ % (epoch acc_rate))
lg_model.eval()
y_out = lg_model(test_set_x)
y_pred = y_out.ge(0.5).float()
num_correct = (y_pred == test_set_y).sum().item()
acc_rate = num_correct * 100.0 / num_test
print(“测试集正确率: %.1f%%“ % acc_rate)
pass
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 1991 2020-08-04 17:05 猫-非猫图二分类\Logistic_Regression.py
文件 616958 2020-07-31 15:31 猫-非猫图二分类\test_catvnoncat.h5
文件 2572022 2020-07-31 15:31 猫-非猫图二分类\train_catvnoncat.h5
目录 0 2020-08-04 17:00 猫-非猫图二分类\
相关资源
- 012345手势识别神经网络代码
- python 记录键盘按键(基于keyboard)
- python 画星星(满天星)
- python 实现 屏幕水印
- python发送gmail邮件demo
- python自动抠图(基于cv2)
- Python Libraries(python编程常用库教程)
- python猜数字游戏
- 华为-python基础教程(108页)
- python幸运抽奖
- Python新手入门详细教程(网盘)
- 《大数据数学基础(Python语言描述)
- Python 3.8.5中文指南
- python 画熊猫(基于turtle)
- 用python画第一型空间曲线
- python无限生成点(基于matplotlib.pyplo
- python 微信机器人
- python基础教程.pptx
- python opencv 银行卡识别.ipynb
- python opencv 图片更换背景. ipynb
- python学生管理系统源码(控制台)
- python 采集京东商品数据
- 多项式拟合(LSM.py)
- python实现 99乘法表
- python连接dubbo
- 百度图片爬虫(python版)
- 《Python 编程:从入门到实践》所有代
- 机器学习k means算法实现图像分割
- 帝国竞争算法python实现
- 拼音转汉字(python输入法)
评论
共有 条评论