资源简介
基于Tensorflow多层神经网络的MNIST手写数字识别(数据集源码).rar

代码片段和文件信息
# @author ZwwIot
#!/usr/bin/env python
# coding: utf-8
# In[17]:
import tensorflow as tf
# 导入 MNIST 数据集
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(“/data/“ one_hot = True)
# In[18]:
# 参数设置
learning_rate = 0.001
training_epochs = 25
batch_size = 100
display_step = 1
# 网络参数
n_hidden_1 = 256# 1层网络神经元数
n_hidden_2 = 256# 2层网络神经元数
n_input = 784# MNIST data 输入 (img shape: 28*28)
n_classes = 10# MNIST 类别 (0-9 一共10类)
saver = tf.train.Saver()# 保存
model_path = “log/520model.ckpt“
# tf Graph input
x = tf.placeholder(“float“ [None n_input])
y = tf.placeholder(“float“ [None n_classes])
# In[19]:
# Create model
def multilayer_perceptron(x weights biases):
# Hidden layer with RELU activation
layer_1 = tf.add(tf.matmul(x weights[‘h1‘])biases[‘b1‘])
layer_1 = tf.nn.relu(layer_1)
# Hidden layer with RELU activation
layer_2 = tf.add(tf.matmul(layer_1 weights[‘h2‘])biases[‘b2‘])
layer_2 = tf.nn.relu(layer_2)
# Output layer with linear activation
out_layer = tf.matmul(layer_2 weights[‘out‘]) + biases[‘out‘]
return out_layer
# In[20]:
# Store layers weight & bias
weights = {
‘h1‘: tf.Variable(tf.random_normal([n_input n_hidden_1]))
‘h2‘: tf.Variable(tf.random_normal([n_hidden_1 n_hidden_2]))
‘out‘: tf.Variable(tf.random_normal([n_hidden_2 n_classes]))
}
biases = {
‘b1‘: tf.Variable(tf.random_normal([n_hidden_1]))
‘b2‘: tf.Variable(tf.random_normal([n_hidden_2]))
‘out‘: tf.Variable(tf.random_normal([n_classes]))
}
# In[21]:
# 构建模型
pred = multilayer_perceptron(x weights biases)
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pred labels = y))
optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(cost)
# 初始化变量
init = tf.global_variables_initializer()
# In[25]:
# 启动session
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 启动循环开始训练
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)# 每一轮训练多少批次
# 遍历全部数据集
for i in range(total_batch):
batch_xs batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_ c = sess.run([optimizer cost] feed_dict={x: batch_xs y: batch_ys})
# 计算平均值以使误差值更平均
avg_cost += c / total_batch
# print(“I:“ ‘%04d‘ % (epoch + 1) “cost=“ “{:.9f}“.format(avg_cost))
# 显示训练中的详细信息
if (epoch+1) % display_step == 0:
print(“Epoch:“ ‘%04d‘ % (epoch+1) “cost=“ “{:.9f}“.format(avg_cost))
print(“Finished!“)
# 测试 model
correct_prediction = tf.equal(tf.argmax(pred 1) tf.argmax(y 1))
# 计算准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction “float“))
print(“Accuracy:“ accuracy.eval({x: mnist.test.images y: mnist.test.labels}))
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 4248 2018-12-23 23:55 MNIST多层分类.py
文件 7840016 2016-11-02 19:39 MNIST DATA\MNIST DATA\test-images
文件 10008 2016-11-02 19:39 MNIST DATA\MNIST DATA\test-labels
文件 47040016 2016-11-02 19:39 MNIST DATA\MNIST DATA\train-images
文件 60008 2016-11-02 19:39 MNIST DATA\MNIST DATA\train-labels
目录 0 2018-12-23 23:56 MNIST DATA\MNIST DATA
目录 0 2018-12-23 23:56 MNIST DATA
----------- --------- ---------- ----- ----
54954296 7
相关资源
- Listary Pro 破解最新版
- Allway Sync Pro 10.5.8注册码 序列号 激活
- 百度文库破解软件
- winfrom自定义设计器源码
- 操作系统哲学家就餐问题(界面+源码
- 康耐视电子表格实战
- getdata破解补丁
- 一键反修复远程桌面.rar
- STM32F407ZGT6实现HAL库硬件I2C读写EEPROM功
- arcGis10.2
- PS技术 在学校里 学三年 也学不到这么
- Hi3520D300 硬件设计用户指南
- Hi3536 Linux开发环境用户指南
- activiti-explorer流程图设计器汉化文件
- 数据库课程设计停车场管理系统
- m×n的长方阵迷宫问题完美求解
- 针对MPLAB® X IDE使用PICKit™ 3在线调
- 东软股份教育事业部解决方案
- 东软数字化校园网解决方案成功应用
- 奶瓶(beini)无限免费破解增强版 使
- 交通灯multisim仿真(附图)
- powerdesigner 15.1 license key
- powerdesigner15.0的注册码license key
- visio软件64位破解版本
- Internet Explorer 11 Windows 系统 各版本
-
开机速度优化工具Startup Dela
yer3.0中 - tomcat 8.0 32位 绿色版
- 四路抢答器
- SolidWorks-100多个
- delphi源码-检测是否运行了多个程序
评论
共有 条评论