资源简介
该代码是在学习深度学习的过程中,自行编写的代码,利用cnn网络来对mnist手写字符进行高精度的识别,并加入了详细的注释,非常适合作为初次接触深度学习的新手入门。欢迎下载。
代码片段和文件信息
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 一些包括网络结构在内的超参数,全局变量
INPUT_NODE=784
OUTPUT_NODE=10
IMAGE_SIZE=28
NUM_CHANNELS=1
NUM_LABELS=10
CONV1_DEEP=32
CONV1_SIZE=5
CONV2_DEEP=64
CONV2_SIZE=5
FC_SIZE=512
BATCH_SIZE=20
LEARNING_RATE_base=0.06
LEARNING_RATE_DECAY=0.999
REGULARIZATION_RATE=0.0001
TRAINING_STEPS=20
MOVING_AVERAGE_DECAY=0.99
# 前向传播的计算过程
def inference(input_tensortrainavg_class=Noneregularizer=Nonereuse=False):
with tf.variable_scope(‘layer1_conv1‘reuse=reuse):
conv1_weights=tf.get_variable(
‘weight‘[CONV1_SIZECONV1_SIZENUM_CHANNELSCONV1_DEEP]
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv1_biases=tf.get_variable(
‘bias‘[CONV1_DEEP]initializer=tf.constant_initializer(0.0))
if avg_class==None:
conv1=tf.nn.conv2d(
input_tensorconv1_weightsstrides=[1221]padding=‘SAME‘)
relu1=tf.nn.relu(tf.nn.bias_add(conv1conv1_biases))
else:
conv1=tf.nn.conv2d(
input_tensoravg_class.average(conv1_weights)strides=[1221]padding=‘SAME‘)
relu1=tf.nn.relu(tf.nn.bias_add(conv1avg_class.average(conv1_biases)))
with tf.name_scope(‘layer2_pool1‘):
pool1=tf.nn.max_pool(
relu1ksize=[1221]strides=[1221]padding=‘SAME‘)
with tf.variable_scope(‘layer3_conv2‘reuse=reuse):
conv2_weights=tf.get_variable(
‘weight‘[CONV2_SIZECONV2_SIZECONV1_DEEPCONV2_DEEP]
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv2_baises=tf.get_variable(
‘bias‘[CONV2_DEEP]initializer=tf.constant_initializer(0.0))
if avg_class==None:
conv2=tf.nn.conv2d(
relu1conv2_weightsstrides=[1221]padding=‘SAME‘)
relu2=tf.nn.relu(tf.nn.bias_add(conv2conv2_baises))
else:
conv2=tf.nn.conv2d(
relu1avg_class.average(conv2_weights)strides=[1221]padding=‘SAME‘)
relu2=tf.nn.relu(tf.nn.bias_add(conv2avg_class.average(conv2_baises)))
with tf.name_scope(‘layer4-pool2‘):
pool2=tf.nn.max_pool(
relu2ksize=[1221]strides=[1221]padding=‘SAME‘)
pool_shape=pool2.get_shape().as_list()
nodes=pool_shape[1]*pool_shape[2]*pool_shape[3]
reshaped=tf.reshape(pool2[-1nodes])
# print(type(reshaped))
with tf.variable_scope(‘layer5_fc1‘reuse=reuse):
fc1_weights=tf.get_variable(
‘weight‘[nodesFC_SIZE]
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer!=None:
tf.add_to_collection(‘losses‘regularizer(fc1_weights))
fc1_biases=tf.get_variable(
‘bias‘[FC_SIZE]initializer=tf.constant_initializer(0.1))
if avg_class==None:
fc1=tf.nn.relu(tf.matmu
相关资源
- CNN_源代码
- 基于Mnist数据集的贝叶斯分类器
- MNIST数据集获取 input_data.py
- faster rcnn end-to-end loss曲线的绘制
- 卷积神经网络回归模型
- Tensorflow-BiLSTM分类
- TensorFlow实现人脸识别(3)--------对人
- pytorch-基于RNN的MNIST手写数据集识别
- 读取自己的mnist数据集代码mnist.py
- Python-手势识别使用在TensorFlow中卷积神
- Python+Tensorflow+CNN实现车牌识别的
- 基于TensorFlow实现的闲聊机器人
- CBAM_MNIST.py
- 利用keras实现的cnn卷积神经网络对手写
- AI智能五子棋Python代码
- TensorFlow 实现 Yolo
- 保存图片为 mnist格式
- 基于tensorflow的遥感影像分类
- 安装步骤。提取码也在里面
- 神经网络模型python模板
- autoencoder自编码器tensorflow代码
- CNN卷积神经网络TensorFlow代码
- Tensorflow笔记-中国大学全部讲义源代码
- DeepLab-ResNet-101
- 基于tensorflow的二分类的python实现注释
- tensorflow的ckpt文件转pb模型文件
- tensorflow-C3D-ucf101网络
- lstm_tensorflow
- 《TensorFlow2深度学习》
- TensorFlow Python API documentation.pdf
评论
共有 条评论