-
大小: 2.19MB文件类型: .rar金币: 1下载: 0 次发布日期: 2023-08-13
- 语言: 其他
- 标签: RtFrecords
资源简介
tensorflow下 自制rfrecords数据集采用one-hot编码做图像分类源码
代码片段和文件信息
# -*- coding: utf-8 -*-
“““
Created on Sat Feb 23 23:21:44 2019
@author: Administrator
“““
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import OneHotEncoder
from RTFrcord_read_data import read_and_decode
############################################################################################
height=100
weight=100
#############################################################################################
batch_size=432
#定义初始化权重和偏置函数
def weight_variable(shape):
return(tf.Variable(tf.random_normal(shapestddev=0.01)))
def bias_variable(shape):
return(tf.Variable(tf.constant(0.1shape=shape)))
#定义输入数据和dropout占位符
X=tf.placeholder(tf.float32[batch_sizeheight weight3])
y_=tf.placeholder(tf.float32[batch_size8])
keep_pro=tf.placeholder(tf.float32)
#搭建网络
def model(Xkeep_pro):
w1=weight_variable([55332])
b1=bias_variable([32])
conv1=tf.nn.relu(tf.nn.conv2d(Xw1strides=[1111]padding=‘SAME‘)+b1)
pool1=tf.nn.max_pool(conv1ksize=[1441]strides=[1441]padding=‘SAME‘)
w2=weight_variable([553264])
b2=bias_variable([64])
conv2=tf.nn.relu(tf.nn.conv2d(pool1w2strides=[1111]padding=‘SAME‘)+b2)
pool2=tf.nn.max_pool(conv2ksize=[1441]strides=[1441]padding=‘SAME‘)
tensor=tf.reshape(pool2[batch_size-1])
dim=tensor.get_shape()[1].value
w3=weight_variable([dim1024])
b3=bias_variable([1024])
fc1=tf.nn.relu(tf.matmul(tensorw3)+b3)
h_fc1=tf.nn.dropout(fc1keep_pro)
w4=weight_variable([10248])
b4=bias_variable([8])
y_conv=tf.nn.softmax(tf.matmul(h_fc1w4)+b4)
return(y_conv)
#定义网络,并设置损失函数和训练器
y_conv=model(Xkeep_pro)
cost=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y_conv)reduction_indices=[1]))
train_step=tf.train.AdamOptimizer(0.001).minimize(cost)
#计算准确率
correct_prediction=tf.equal(tf.argmax(y_conv1)tf.argmax(y_1))
accuracy=tf.reduce_mean(tf.cast(correct_predictiontf.float32))
#读取tfrecords数据
imagelabel=read_and_decode(“train1.tfrecords“)
#定义会话,并开始训练
with tf.Session() as sess:
tf.global_variables_initializer().run()
#定义多线程
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(coord=coord)
#定义训练图像和标签
example=np.zeros((batch_sizeheight weight3))
l=np.zeros((batch_size1))
try:
#将数据存入example和l并将转化成one_hot形式
for epoch in range(batch_size):
example[epoch]l[epoch]=sess.run([imagelabel])
print(l)
enc=OneHotEncoder()
l=enc.fit_transform(l)
l=l.toarray()
print(l)
for i in range(100):
#开始训练
sess.run(train_stepfeed_dict={X:exampley_:lkeep_pro:0.5})
if i%10==0:
print(‘train step‘‘%04d ‘ %(i+1)‘Accuracy=‘sess.run(accuracyfeed_dict={X:exampley_:lkeep_pro:0.5}))
except tf.errors.OutOfRangeError:
print(‘done!‘)
finally:
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 3234 2019-02-23 23:47 RTFrcords\data_classification.py
文件 1743 2019-02-23 23:47 RTFrcords\RTFrcord_read_data.py
文件 1819 2019-02-23 23:08 RTFrcords\RTFrcord_save_data.py
文件 13016413 2019-02-23 23:09 RTFrcords\train1.tfrecords
文件 893 2019-02-23 23:27 RTFrcords\__pycache__\RTFrcord_read_data.cpython-36.pyc
目录 0 2019-02-23 23:27 RTFrcords\__pycache__
目录 0 2019-02-23 23:27 RTFrcords
----------- --------- ---------- ----- ----
13024102 7
- 上一篇:LOIC(低轨道离子加农炮) 工具使用方法
- 下一篇:哈工大王义和近世代数答案
评论
共有 条评论