资源简介
梯度下降纯手工实现 MLP CNN RNN SEQ2SEQ识别手写体MNIST数据集十分类问题代码详解.
代码片段和文件信息
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data # 导入下载数据集手写体
mnist = input_data.read_data_sets(‘../MNIST_data/‘ one_hot=True) # 下载数据集
class CNNNet:
def __init__(self):
self.x = tf.placeholder(dtype=tf.float32 shape=[None 28 28 1] name=‘input_x‘)
self.y = tf.placeholder(dtype=tf.float32 shape=[None 10] name=‘input_y‘)
self.w1 = tf.Variable(
tf.truncated_normal(shape=[3 3 1 16] dtype=tf.float32 stddev=tf.sqrt(1 / 16) name=‘w1‘))
self.b1 = tf.Variable(tf.zeros(shape=[16] dtype=tf.float32 name=‘b1‘))
self.w2 = tf.Variable(
tf.truncated_normal(shape=[3 3 16 32] dtype=tf.float32 stddev=tf.sqrt(1 / 32) name=‘w2‘))
self.b2 = tf.Variable(tf.zeros(shape=[32] dtype=tf.float32 name=‘b2‘))
self.fc_w1 = tf.Variable(
tf.truncated_normal(shape=[28 * 28 * 32 128] dtype=tf.float32 stddev=tf.sqrt(1 / 128) name=‘fc_w1‘))
self.fc_b1 = tf.Variable(tf.zeros(shape=[128] dtype=tf.float32 name=‘fc_b1‘))
self.fc_w2 = tf.Variable(
tf.truncated_normal(shape=[128 10] dtype=tf.float32 stddev=tf.sqrt(1 / 10) name=‘fc_w2‘))
self.fc_b2 = tf.Variable(tf.zeros(shape=[10] dtype=tf.float32 name=‘fc_b2‘))
def forward(self):
self.conv1 = tf.nn.relu(
tf.nn.conv2d(self.x self.w1 strides=[1 1 1 1] padding=‘SAME‘ name=‘conv1‘) + self.b1)
self.conv2 = tf.nn.relu(
tf.nn.conv2d(self.conv1 self.w2 strides=[1 1 1 1] padding=‘SAME‘ name=‘conv2‘) + self.b2)
self.flat = tf.reshape(self.conv2 [-1 28 * 28 * 32])
self.fc1 = tf.nn.relu(tf.matmul(self.flat self.fc_w1) + self.fc_b1)
self.fc2 = tf.matmul(self.fc1 self.fc_w2) + self.fc_b2
self.output = tf.nn.softmax(self.fc2)
def backward(self):
self.cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.fc2 labels=self.y))
self.opt = tf.train.AdamOptimizer().minimize(self.cost)
def acc(self):
self.acc2 = tf.equal(tf.argmax(self.output 1) tf.argmax(self.y 1))
self.accaracy = tf.reduce_mean(tf.cast(self.acc2 dtype=tf.float32))
if __name__ == ‘__main__‘:
net = CNNNet()
net.forward()
net.backward()
net.acc()
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
for i in range(10000):
ax ay = mnist.train.next_batch(100)
ax_batch = ax.reshape([-1 28 28 1])
loss output accaracy _ = sess.run(fetches=[net.cost net.output net.accaracy net.opt]
feed_dict={net.x: ax_batch net.y: ay})
# print(loss)
# print(accaracy)
if i % 10 == 0:
test_ax test_ay = mnist.test.next_batch(100)
test_ax_batch = test_
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 21566 2018-11-02 22:06 gradient_descent.png
文件 1849 2018-11-01 11:33 gradient_descent.py
文件 1648877 2018-10-30 09:53 MNIST_data\t10k-images-idx3-ubyte.gz
文件 4542 2018-10-30 09:53 MNIST_data\t10k-labels-idx1-ubyte.gz
文件 9912422 2018-10-30 09:53 MNIST_data\train-images-idx3-ubyte.gz
文件 28881 2018-10-30 09:53 MNIST_data\train-labels-idx1-ubyte.gz
文件 3213 2018-11-02 11:21 SEQ2SEQ.py
文件 2640 2018-11-01 19:56 RNNNet.py
文件 3357 2018-11-01 19:56 CNNNet.py
文件 2205 2018-11-01 12:44 MLPNet.py
目录 0 2018-11-02 22:07 MNIST_data
----------- --------- ---------- ----- ----
11629552 11
相关资源
- 吴恩达UFLDL教程
- 深度学习:智能时代的核心驱动力量
- Deep Learning with R
- 深度学习基础(Fundamentals of Deep Lear
- 使用tensorflow实现CNN-RNN-GAN代码
- 金融股票深度学习论文整理
- 对《Secureml A system for scalable privacy-p
- Hands-On Machine Learning with Scikit-Learn Ke
- 深度学习方法及应用Deep Learning Metho
- 深度学习源代码162566
- 概率统计超入门
- 黄海广博士整理的吴恩达深度学习笔
- Neural Networks and Deep Learning-神经网络与
- 深度学习资料+官方文档
- 动手学深度学习源代码
- 深度学习框架PyTorch:入门与实践 PD
- Reinforcement Learning an Introduction,2018正
- 深度学习卷积神经网络代码
- 深度学习/图像识别/TensorFlow
- fashion-mnist数据集
- 深度学习 [deep learning] AI圣经 Deep Lea
- 一天弄懂深度学习-李宏毅PPT+PDF超级高
- 解压后的MNIST数据集
- 深度学习车牌识别模型.zip
- 深度学习 智能时代的核心驱动力量
- imdb.npz数据集
- tensorflow实战+实战Google深度学习框架
- 基于深度学习的目标检测程序
- 吴恩达老师深度学习第一课神经网络
- 深度学习 AI圣经 中文高清版 带完整目
评论
共有 条评论