资源简介
使用RNN进行mnist的分类,使用的是一个3层的GRU作为模型
代码片段和文件信息
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def initialize_weight_bias(in_size out_size):
weight = tf.truncated_normal(shape=(in_size out_size) stddev=0.01 mean=0.0)
bias = tf.constant(0.1 shape=[out_size])
return tf.Variable(weight) tf.Variable(bias)
def model(data target dropout num_hidden=200 num_layers=3):
“““
RNN model for mnist classification.
Args:
data: input data with shape (batch_size max_time_steps cell_size).
target : label of input data with shape (batch_size num_classes).
dropout: dropout rate.
num_hidden: the number of hidden units.
num_layers: the number of RNN layers.
Returns:
“““
# establish RNN model
cells = list()
for _ in range(num_layers):
cell = tf.nn.rnn_cell.GRUCell(num_units=num_hidden)
cell = tf.nn.rnn_cell.DropoutWrapper(cell=cell output_keep_prob=1.0-dropout)
cells.append(cell)
network = tf.nn.rnn_cell.MultiRNNCell(cells=cells)
outputs last_state = tf.nn.dynamic_rnn(cell=network inputs=data dtype=tf.float32)
# get last output
outputs = tf.transpose(outputs (1 0 2))
last_output = tf.gather(outputs int(outputs.get_shape()[0])-1)
# add softmax layer
out_size = int(target.get_shape()[1])
weight bias = initialize_weight_bias(in_size=num_hidden out_size=out_size)
logits = tf.add(tf.matmul(last_output weight) bias)
return logits
def main():
# define some parameters
default_epochs = 10
default_batch_size = 64
default_dropout = 0.5
test_freq = 150 # every 150 batches
logs_path = ‘data/log‘
# get train and test data
mnist_data = input_data.read_data_sets(‘data/mnist‘ one_hot=True)
total_steps = int(mnist_data.train.num_examples/default_batch_size)
total_test_steps = int(mnist_data.test.num_examples/default_batch_size)
print(‘number of training examples: %d‘ % mnist_data.train.num_examples) # 55000
print(‘number of test examples: %d‘ % mnist_data.test.num_examples) # 10000
# fit RNN model
input_x = tf.placeholder(tf.float32 shape=(None 28 28))
- 上一篇:word2vec.py
- 下一篇:广州精细地铁线路图轨迹图可视化.py
相关资源
- MNIST手写体数字训练/测试数据集(图
- Python-本项目基于yolo3与crnn实现中文自
- Long Short-Term Memory Networks With Python
- bayes分类python
- knn算法识别mnist图片-python3
- Ubuntu18.04LTS下安装 Caffe-GPU版本及 Ana
- MLP/RNN/LSTM模型进行IMDb情感分析
- mnist手写数字识别数据集npz文件.zip
- 基于Python的手写字体识别系统
- Python学习实践-sklearn分类算法实践-M
- pytorch版本手写体识别MNIST.zip
- 基于python3 tensorflow DBN_and_RNN的实现
- 自己编写的MNIST手写字Python实现
- 文本分类代码集合含数据_TextCNN_Text
- mnist手写字体识别之BP.zip
- tensorflow操作mnist数据集源代码
- 使用knn对MNIST分类
- tensorflow手写数字识别完整版.zip
- 基于tensorflow的手写体识别python源码附
- 逻辑回归python代码
- mnist手写字体识别之随机森林.zip
- RNN python
- poetryRNN诗人
- win10+anaconda3+python3 mnist训练代码
- 代码:Python+TensorFlow+PyQt实现手写体数
- SVM分类手动鼠标手写数字-python版本
- 使用python搭建mnist全连接神经网络
- knn_search.py
- Tensorflow实现GAN生成mnist手写数字图片
- 纯python实现mnist手写体识别.zip
评论
共有 条评论