• 大小: 0.01M
    文件类型: .py
    金币: 1
    下载: 0 次
    发布日期: 2021-05-13
  • 语言: Python
  • 标签: 其他  

资源简介

vgg_easy.py

资源截图

代码片段和文件信息

import os
import numpy as np
import tensorflow as tf
from PIL import Image
from skimage import io transform
import glob
os.environ[“CUDA_VISIBLE_DEVICES“] = “0“
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
path = ‘./easy/‘          # for one dataset cross validation
train_path = ‘./example/train/‘ # for train and test set
test_path = ‘./example/test/‘
w = 224
h = 224
c = 3
n_class = 4

def read_img(path):
    cate   = [path + x for x in os.listdir(path) if os.path.isdir(path + x)]
    imgs   = []
    labels = []
    label_list = np.eye(n_class)
    for idx folder in enumerate(cate):            #search folder
        for im in glob.glob(folder + ‘/*.jpg‘):    #change doc type if necessary
            img = io.imread(im)
            img = transform.resize(img (w h c))
            imgs.append(img)                        # (sum2242243)
            labels.append(label_list[idx])          # (sum4)                         
    return np.asarray(imgs np.float32) np.asarray(labels np.float32)

#------------------train and test set------------
data label = read_img(train_path)
   
num_example = data.shape[0]                        
arr = np.arange(num_example)                    
np.random.shuffle(arr)                           # random sequence
x_train = data[arr]
y_train = label[arr]

s = num_example

data_t label_t = read_img(test_path)
   
s_test = data_t.shape[0]                        
arr = np.arange(s_test)                    
np.random.shuffle(arr)                           # random sequence
x_val = data_t[arr]
y_val = label_t[arr]

# ------------------one dataset cross validation ----------
#data label = read_img(path)
   
#num_example = data.shape[0]                        
#arr = np.arange(num_example)                    
#np.random.shuffle(arr)                           # random sequence
#data = data[arr]
#label = label[arr]

#ratio = 0.8
#s = np.int(num_example * ratio)
#x_train = data[:s]                         # (sum_train2242243)
#y_train = label[:s]                        # (sum_train4)
#x_val   = data[s:]
#y_val   = label[s:]    

#------------------vgg16 structure----------------
 
x = tf.placeholder(tf.float32 shape=[None h w c])
y = tf.placeholder(tf.float32 shape=[None n_class])     

#----------------- conv1 ------------------------

w_conv1_1 = tf.Variable(tf.truncated_normal([3 3 3 64] stddev=0.1))
b_conv1_1 = tf.Variable(tf.constant(0.1 shape=[64]))
L_conv1_1 = tf.nn.relu(tf.nn.conv2d(x w_conv1_1strides=[1 1 1 1] padding=‘SAME‘) + b_conv1_1)

w_conv1_2 = tf.Variable(tf.truncated_normal([3 3 64 64] stddev=0.1))
b_conv1_2 = tf.Variable(tf.constant(0.1 shape=[64]))
L_conv1_2 = tf.nn.relu(tf.nn.conv2d(L_conv1_1 w_conv1_2strides=[1 1 1 1] padding=‘SAME‘) + b_conv1_2)

L_pool1 = tf.nn.max_pool(L_conv1_2 ksize=[1 2 2 1] strides=[1 2 2 1] padding=‘SAME‘)
#----------------- conv2 ------------------------

w_conv2_1 = tf.Variable(tf.truncated_normal([3 3 64 128

评论

共有 条评论