资源简介
代码是利用pytorch框架实现的,识别过程是利用循环神经网络RNN进行训练。
代码片段和文件信息
import torch
from torch import nnoptim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasetstransforms
# 超参数
batch_size = 100 # 批大小
learning_rate = 0.01 # 学习率
num_epoches = 20 # 训练次数
data_tf = transforms.Compose([transforms.ToTensor()transforms.Normalize([0.5][0.5])])
train_dataset = datasets.MNIST(root=‘./data‘train=Truetransform=data_tfdownload=True)
test_dataset = datasets.MNIST(root=‘./data‘train=Falsetransform=data_tf)
train_loader = DataLoader(train_datasetbatch_size=batch_sizeshuffle=True)
test_loader = DataLoader(test_datasetbatch_size=batch_sizeshuffle=False)
class Rnn(nn.Module):
def __init__(self in_dim hidden_dim n_layer n_class):
super(Rnn self).__init__()
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.lstm = nn.LSTM(in_dim hidden_dim n_layer batch_first=True)
self.classifier = nn.Linear(hidden_dim n_class)
def forward(self x):
out _ = self.lstm(x)
out = out[: -1 :]
out = self.classifier(out)
return out
model = Rnn(28 128 2 10) # 图片大小是28x28
# 定义loss和optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters() lr=learning_rate)
# 开始训练
for epoch in range(num_epoches):
running_loss = 0.0
running_acc = 0.0
for i data in enumerate(train_loader 1):
img label = data
img = img.squeeze(1)
if torch.cuda.is_available():
img = img.cuda()
label = label.cuda()
else:
img = Variable(img)
相关资源
- RNN实现的matlab代码
- 读取自己的mnist数据集代码mnist.py
- Python-RNNoiseRNN音频噪声抑制学习
- CBAM_MNIST.py
- 利用keras实现的cnn卷积神经网络对手写
- 保存图片为 mnist格式
- keras上LSTM长短期记忆网络金融时序预
- Python实现循环神经网络RNN
- 《PyTorch生成对抗网络编程》思维导图
- cpu_nms.py
- Python-使用RNN股市预测
- pytorch1.5官方英文文档PythonAPI(包含书
- Deep Learning: Recurrent Neural Networks in Py
- pytorch实现logistic回归
- 基于PyTorch的深度学习技术进步
- resnet-pytorch
- mnist_CNN 深度学习小
- python MNIST分类 tensorflow
- 股票预测 LSTM 时间序列rnn 代码程序数
- mnist_mlp.py
- 使用逻辑回归进行MNIST字符分类识别代
- 基于kNN方法的MNIST手写数字识别Tenso
- Python-PyTorch实现基于Transformer的神经机
- Python-PyTorch实现的NEAT神经进化算法
- python kNN算法实现MNIST数据集分类 k值
- RNN网络代码
- BP算法实现26个字母识别
- MLPCNN;识别手写数字集Mnist
- cnn +rnn +attention 以及CTC-loss融合的文字
- pytorch的cifar10数据集分类程序
评论
共有 条评论