资源简介
本资源为纯python实现mnist手写体识别的代码,为作者本人所写,供深度学习初学者共同交流探讨,欢迎二次创作,网络为三层,可达到97%上准确率,模型可以选择多种训练方式,学习率,激活函数,损失函数等我都写了相关函数,可以选择,模型也可以自由变换,只需要改一下前面常量参数值就行。升级版本正在打包测试过程中,完成后可以自行选择batch—size大小等,具体介绍可以看我置顶博文介绍
代码片段和文件信息
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import h5py
#plt.rcParams[“font.family“]=“SimHei“
#加载keras内部mnist数据集
mnist=tf.keras.datasets.mnist
(train_xtrain_y)(test_xtest_y)=mnist.load_data()
train_xtest_x=train_x/255test_x/255
#网络模型结构参数
width_input=784 #输入层神经网络节点数=28*28
width_net1=100 #第一层神经网络节点数
width_net2=100 #第二层神经网络节点数
width_net3=10 #输出层神经网络节点数
#模型训练参数
epoch=50
way_dec_lr=1 #input:1 or 2
“““
学习率更新方式,选1,表示每lr_dec_epoch轮固定按lr_dec_rate比例减少学习率
选择2,表示记录5次学习率大小,当当前轮次loss值大于前nub次(包括本次)loss平均值
时,学习率自动降为当前学习率0.1倍,当学习率降为last_lr时,训练终止,保存模型
“““
nub=3 #设置记录nub次loss值
last_lr=0.0001 #方式2时,最终截止学习率值
learn_rate=0.01 #默认学习率
init_learn_rate=0.01 #初始学习率
lr_dec_epoch=10 #设置每10轮更新一次学习率
lr_dec_rate=0.5 #跟新学习率倍数
savepath=‘data/weight4.h5‘ #保存模型地址
loadmodel=‘data/weight3.h5‘ #当为迁移学习时,载入模型地址(请确保本次训练模型结构与加载的模型一致)
isretrain=False #是否为迁移学习True or False
#隐含层的激活函数
def sigmoid(x):
return 1/(1+np.exp(-x))
#输出层的激活函数
def softmax(y):
c=np.max(y)
y=y-c
sum=np.sum(np.exp(y))
return np.exp(y)/sum
#定义均方误差损失函数定义
def loss(y_prey_grtru):
return np.sum(np.square(y_pre-y_grtru))
#定义交叉熵损失函数
def cross_entropy_loss(y_prey_grtru):
return -np.sum(y_grtru*np.log(y_pre)+(1-y_grtru)*np.log(1-y_pre))
#定义网络输入层
x=np.zeros((width_input))
#定义网络第一层
a1=np.zeros((width_net1))
#定义网络隐藏层
a2=np.zeros((width_net2))
#定义网络输出层
y=np.zeros((width_net3))
#模型权重导入
def get_model(weight_path):
h5f=h5py.File(weight_path‘r‘)
w1=h5f[‘w1‘][:]
b1=h5f[‘b1‘][:]
w2=h5f[‘w2‘][:]
b2=h5f[‘b2‘][:]
w3=h5f[‘w3‘][:]
b3=h5f[‘b3‘][:]
return w1w2w3b1b2b3
#初始化模型
def genarate_model():
w1=np.random.normal(02/width_input(width_inputwidth_net1))
b1=np.random.normal(02/width_net1(width_net1))
w2=np.random.normal(02/width_net1(width_net1width_net2))
b2=np.random.normal(02/width_net2(width_net2))
w3=np.random.normal(02/width_net2(width_net2width_net3))
b3=np.random.normal(02/width_net3(width_net3))
return w1w2w3b1b2b3
#初始化nub个临时保存模型的参数以便在早停前选取最优模型
w11=np.zeros((nubwidth_inputwidth_net1))
b11=np.zeros((nubwidth_net1))
w21=np.zeros((nubwidth_net1width_net2))
b21=np.zeros((nubwidth_net2))
w31=np.zeros((nubwidth_net2width_net3))
b31=np.zeros((nubwidth_net3))
#模型参数生成
if isretrain:
w1w2w3b1b2b3=get_model(loadmodel)
else:
w1w2w3b1b2b3=genarate_model()
#初始化参数z(其中a=sigmoid(z))
z1=np.dot(xw1)+b1
z2=np.dot(a1w2)+b2
z3=np.dot(a2w3)+b3
#定义前向传播
def feedforward(awb):
return sigmoid(np.dot(aw)+b)
#保存模型
def save_model(savepathw_1w_2w_3b_1b_2b_3):
filename=savepath
h5f=h5py.File(filename‘w‘)
h5f.create_dataset(‘w1‘data=w_1)
h5f.create_dataset(‘w2‘data=w_2)
h5f.create_dataset(‘w3‘data=w_3)
h5f.cre
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 7151 2020-04-02 19:53 minist-network.py
文件 1821 2020-04-02 22:14 minist_pre.py
相关资源
- python实现有序边表算法.zip
- Python爬取豆瓣图书信息并保存到本地
- python实现有向图单源最短路径迪杰斯
- 文件夹下所有图片的读取以及显示p
- python 实现图片像素大小设置
- 经典遗传算法(SGA)解01背包问题的
- 第六章Python函数习题及答案--中文
- SVM鸢尾花分类Python实现.rar
- arima预测python程序
- 必应壁纸天天换python小程序.zip
- python小项目--外星人入侵
- Flask项目实战-超市商品管理平台
- pythonreader.rar
- Python Scrapy爬虫爬取微博和微信公众号
- python写盛金法求一元三次方方程解
- 老男孩Python2018基础高级进阶(28周)
- python http服务器搭建
- Python输入年份月份显示日历
- python实现百度坐标和世界经纬度坐标
- 利用OpenCV检测人脸python程序
- JSYX2.0.zip
- Python题目汇总含答案pdf
- 模态分解emd算法Python实现
- Python读取Las与转换为TXT.zip
- backup.sh.py
- BSTestRunner.pypython3
- SI模型,影响力传播模型,传染病模型
- python自动抓取网页中的pdf文件
- python爬虫网站图片
- Anaconda3 for MacOSX x64百度云
评论
共有 条评论