资源简介
使用VGG16网络实现对传统MNIST手写数据集的识别任务。
代码片段和文件信息
#Create Wed May 2019-5-29 19:37:16
#End 2019-5-29 21:30:35
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
mnist = input_data.read_data_sets(‘MNIST/‘ one_hot = True)
x = tf.placeholder(tf.float32 [None 784])
y = tf.placeholder(tf.float32 [None 10])
keep_prob = tf.placeholder(tf.float32)
def conv2d(name x w b):
return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x w strides = [1 1 1 1] padding = ‘SAME‘) b) name = name)
def max_pool(name x):
return tf.nn.max_pool(x ksize = [1 2 2 1] strides = [1 2 2 1] padding = ‘SAME‘ name = name)
def norm(name x):
return tf.nn.lrn(x depth_radius = None bias = 0.01 alpha = 0.001 beta = 1.0 name = name)
weights = {
‘wc1‘: tf.Variable(tf.random_normal([3 3 1 64]))
‘wc2‘: tf.Variable(tf.random_normal([3 3 64 64]))
‘wc3‘: tf.Variable(tf.random_normal([3 3 64 128]))
‘wc4‘: tf.Variable(tf.random_normal([3 3 128 128]))
‘wc5‘: tf.Variable(tf.random_normal([3 3 128 256]))
‘wc6‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc7‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc8‘: tf.Variable(tf.random_normal([3 3 256 256]))
‘wc9‘: tf.Variable(tf.random_normal([3 3 256 512]))
‘wc10‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc11‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc12‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc13‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc14‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc15‘: tf.Variable(tf.random_normal([3 3 512 512]))
‘wc16‘: tf.Variable(tf.random_normal([3 3 512 256]))
‘wd1‘: tf.Variable(tf.random_normal([4*4*256 4096]))
‘wd2‘: tf.Variable(tf.random_normal([4096 4096]))
‘out‘: tf.Variable(tf.random_normal([4096 10]))
}
biases = {
‘bc1‘: tf.Variable(tf.zeros([64]))
‘bc2‘: tf.Variable(tf.zeros([64]))
‘bc3‘: tf.Variable(tf.zeros([128]))
‘bc4‘: tf.Variable(tf.zeros([128]))
‘bc5‘: tf.Variable(tf.zeros([256]))
‘bc6‘: tf.Variable(tf.zeros([256]))
‘bc7‘: tf.Variable(tf.zeros([256]))
‘bc8‘: tf.Variable(tf.zeros([256]))
‘bc9‘: tf.Variable(tf.zeros([512]))
‘bc10‘: tf.Variable(tf.zeros([512]))
‘bc11‘: tf.Variable(tf.zeros([512]))
‘bc12‘: tf.Variable(tf.zeros([512]))
‘bc13‘: tf.Variable(tf.zeros([512]))
‘bc14‘: tf.Variable(tf.zeros([512]))
‘bc15‘: tf.Variable(tf.zeros([512]))
‘bc16‘: tf.Variable(tf.zeros([256]))
‘bd1‘: tf.Variable(tf.zeros([4096]))
‘bd2‘: tf.Variable(tf.zeros([4096]))
‘out‘: tf.Variable(tf.zeros([10]))
}
#2 4 12进行池化
def VGG16(x weights biases dropout):
x = tf.reshape(x shape = [-1 28 28 1])
conv1 = conv2d(‘conv1‘ x weights[‘wc1‘] biases[‘bc1‘])
#28*28*64
norm1 = norm(‘norm1‘ conv1)
conv2 = conv2d(‘conv2‘ norm1 weights[‘wc2‘] biases[‘b
相关资源
- mnist_acgan.py
- CNN网络识别Mnist的源码,有详细注释,
- 基于Mnist数据集的贝叶斯分类器
- MNIST数据集获取 input_data.py
- pytorch-基于RNN的MNIST手写数据集识别
- 读取自己的mnist数据集代码mnist.py
- CBAM_MNIST.py
- 利用keras实现的cnn卷积神经网络对手写
- 保存图片为 mnist格式
- 预训练数据VGG_imagenet.npy
- Vgg16 model
- mnist_CNN 深度学习小
- python MNIST分类 tensorflow
- mnist_mlp.py
- 使用逻辑回归进行MNIST字符分类识别代
- 基于kNN方法的MNIST手写数字识别Tenso
- python kNN算法实现MNIST数据集分类 k值
- BP算法实现26个字母识别
- MLPCNN;识别手写数字集Mnist
- KNN-mnist识别
- tutorials.7z
评论
共有 条评论