资源简介
该代码是在学习深度学习的过程中,自行编写的代码,利用cnn网络来对mnist手写字符进行高精度的识别,并加入了详细的注释,非常适合作为初次接触深度学习的新手入门。欢迎下载。
代码片段和文件信息
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 一些包括网络结构在内的超参数,全局变量
INPUT_NODE=784
OUTPUT_NODE=10
IMAGE_SIZE=28
NUM_CHANNELS=1
NUM_LABELS=10
CONV1_DEEP=32
CONV1_SIZE=5
CONV2_DEEP=64
CONV2_SIZE=5
FC_SIZE=512
BATCH_SIZE=20
LEARNING_RATE_base=0.06
LEARNING_RATE_DECAY=0.999
REGULARIZATION_RATE=0.0001
TRAINING_STEPS=20
MOVING_AVERAGE_DECAY=0.99
# 前向传播的计算过程
def inference(input_tensortrainavg_class=Noneregularizer=Nonereuse=False):
with tf.variable_scope(‘layer1_conv1‘reuse=reuse):
conv1_weights=tf.get_variable(
‘weight‘[CONV1_SIZECONV1_SIZENUM_CHANNELSCONV1_DEEP]
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv1_biases=tf.get_variable(
‘bias‘[CONV1_DEEP]initializer=tf.constant_initializer(0.0))
if avg_class==None:
conv1=tf.nn.conv2d(
input_tensorconv1_weightsstrides=[1221]padding=‘SAME‘)
relu1=tf.nn.relu(tf.nn.bias_add(conv1conv1_biases))
else:
conv1=tf.nn.conv2d(
input_tensoravg_class.average(conv1_weights)strides=[1221]padding=‘SAME‘)
relu1=tf.nn.relu(tf.nn.bias_add(conv1avg_class.average(conv1_biases)))
with tf.name_scope(‘layer2_pool1‘):
pool1=tf.nn.max_pool(
relu1ksize=[1221]strides=[1221]padding=‘SAME‘)
with tf.variable_scope(‘layer3_conv2‘reuse=reuse):
conv2_weights=tf.get_variable(
‘weight‘[CONV2_SIZECONV2_SIZECONV1_DEEPCONV2_DEEP]
initializer=tf.truncated_normal_initializer(stddev=0.1))
conv2_baises=tf.get_variable(
‘bias‘[CONV2_DEEP]initializer=tf.constant_initializer(0.0))
if avg_class==None:
conv2=tf.nn.conv2d(
relu1conv2_weightsstrides=[1221]padding=‘SAME‘)
relu2=tf.nn.relu(tf.nn.bias_add(conv2conv2_baises))
else:
conv2=tf.nn.conv2d(
relu1avg_class.average(conv2_weights)strides=[1221]padding=‘SAME‘)
relu2=tf.nn.relu(tf.nn.bias_add(conv2avg_class.average(conv2_baises)))
with tf.name_scope(‘layer4-pool2‘):
pool2=tf.nn.max_pool(
relu2ksize=[1221]strides=[1221]padding=‘SAME‘)
pool_shape=pool2.get_shape().as_list()
nodes=pool_shape[1]*pool_shape[2]*pool_shape[3]
reshaped=tf.reshape(pool2[-1nodes])
# print(type(reshaped))
with tf.variable_scope(‘layer5_fc1‘reuse=reuse):
fc1_weights=tf.get_variable(
‘weight‘[nodesFC_SIZE]
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer!=None:
tf.add_to_collection(‘losses‘regularizer(fc1_weights))
fc1_biases=tf.get_variable(
‘bias‘[FC_SIZE]initializer=tf.constant_initializer(0.1))
if avg_class==None:
fc1=tf.nn.relu(tf.matmu
相关资源
- DeepLabV3-Tensorflow-master
- 基于TensorFlow实现CNN文本分类实验指导
- 人工智能算法实现mnist手写数字识别
- 利用CNN网络实现mnist图像分类,手动实
- tensorflow2.0 yolo3目标检测算法
- MNIST手写体数字训练/测试数据集(图
- tensorflow制作自己的灰度图像数据集并
- anaconda下安装tensorflow(注:不同版本
- 北京大学曹健老师-人工智能实践:
- Deep Learning With Python - Jason Brownlee
- Python-自然场景文本检测PSENet的一个
- Python-用PyTorch10实现FasterRCNN和MaskRCNN比
- Python-高效准确的EAST文本检测器的一个
- Python-TensorFlow弱监督图像分割
- Python-基于tensorflow实现的用textcnn方法
- Python-FastSCNN的PyTorch实现快速语义分割
- Python-subpixel利用Tensorflow的一个子像素
- 【官方文档】TensorFlow Python API docume
- 基于深度学习堆栈自动编码器模型的
- 性别模型库 simple_CNN.81-0.96.hdf5
- lightened_cnn_S 5M模型
- bayes分类python
- knn算法识别mnist图片-python3
-
tensorflow画风迁移代码 st
yle transfer - TBCNN 源码
- Ubuntu18.04LTS下安装 Caffe-GPU版本及 Ana
- faster rcnn(python+caffe)源代码
- 简单粗暴 TensorFlow
- [PDF] Reinforcement Learning With Open AI Tens
- tensorflow目标检测代码
评论
共有 条评论