- 大小: 5KB
- 文件类型: .zip
- 金币: 1
- 下载: 0 次
- 发布日期: 2021-06-04
- 语言: 其他
- 标签: 3DCNN tensorflow
资源简介
Tensorflow 3D CNN
代码片段和文件信息
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
“““
Created on Thu Feb 23 10:51:28 2017
@author: cdn
“““
import numpy as np
np.random.seed(1234)
import timeit
import os
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn.cross_validation import StratifiedKFold
import tensorflow as tf
from tensorflow.contrib.layers import fully_connected convolution2d flatten dropout
from tensorflow.python.layers.pooling import max_pooling3d
from tensorflow.python.ops.nn import relusoftmax
from tensorflow.python.framework.ops import reset_default_graph
import six.moves.cPickle as pickle
from confusionmatrix import ConfusionMatrix
def onehot(t, num_classes):
    """Return a one-hot encoding of a 1-D sequence of class labels.

    Parameters
    ----------
    t : array-like of shape (n,)
        Class labels in ``range(num_classes)``. Float labels (e.g. the
        ``1.0`` / ``0.0`` values built by ``load_data``) are accepted and
        truncated to integer indices.
    num_classes : int
        Number of output columns.

    Returns
    -------
    numpy.ndarray of shape (n, num_classes)
        Float array with exactly one 1 per row, in the column given by the
        corresponding label.
    """
    # numpy refuses float values as indices, and the labels produced by this
    # script are float64 (np.ones / np.zeros), so convert to integer indices.
    labels = np.asarray(t).astype(np.intp)
    out = np.zeros((labels.shape[0], num_classes))
    # Vectorized scatter: row i receives a 1 in column labels[i].
    out[np.arange(labels.shape[0]), labels] = 1
    return out
def load_cv(cv_name='index_10fold.pkl', fold_idx=0):
    """Load the train/test/validation index split for one cross-validation fold.

    The pickle at ``cv_name`` must hold a 2-tuple ``(train_idx, test_idx)``.
    Each element is indexable by fold number and yields that fold's sample
    indices.

    The validation set of fold ``k`` is the *test* set of fold ``k - 1``.
    Fold 0 wraps around to the last fold through Python's negative indexing.
    The validation samples are removed from fold ``k``'s training indices.

    Parameters
    ----------
    cv_name : str
        Path to the pickled ``(train_idx, test_idx)`` tuple.
    fold_idx : int
        Which fold to return.

    Returns
    -------
    train_index : numpy.ndarray
        Sorted, de-duplicated training indices with the validation indices
        removed.
    test_index
        The fold's test indices, exactly as stored in the pickle.
    val_index
        The borrowed validation indices, exactly as stored in the pickle.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code. Only load split files from a
    trusted source.
    """
    # `with` guarantees the file is closed even if unpickling fails.
    with open(cv_name, 'rb') as input_doc:
        train_idx, test_idx = pickle.load(input_doc)
    train_index0 = train_idx[fold_idx]
    test_index = test_idx[fold_idx]
    # fold_idx - 1 == -1 for fold 0, i.e. the last fold's test set.
    val_index = test_idx[fold_idx - 1]
    # setdiff1d gives a sorted, unique result (deterministic order, unlike
    # iterating a set). It also keeps the training indices' dtype, even when
    # every index is removed.
    train_index = np.setdiff1d(train_index0, val_index)
    return train_index, test_index, val_index
def load_data(fold_index, mat_path='ADNI/PET/Data_PET_d3.mat',
              ad_key='AD_data_PET_d3', nc_key='NORMAL_data_PET_d3',
              cv_name='index_10fold.pkl'):
    """Load AD and NC volumes from a .mat file and split them for one CV fold.

    The AD and NC arrays are stacked, shuffled with a fixed seed, labelled
    (1 = AD, 0 = NC) and split using the index lists returned by ``load_cv``.

    Parameters
    ----------
    fold_index : int
        Cross-validation fold to return (forwarded to ``load_cv``).
    mat_path : str
        Path of the MATLAB file to read.
    ad_key, nc_key : str
        Variable names of the AD and NC arrays inside ``mat_path``. Each must
        be 4-D ``(num_subjects, X, Y, Z)``. Another dataset used with this
        script stores them as ``'AffineAD_128'`` / ``'AffineNC_128'``.
    cv_name : str
        Path of the pickled CV split (forwarded to ``load_cv``).

    Returns
    -------
    tuple
        ``(x_train, y_train, x_test, y_test, x_valid, y_valid, size_input)``.
        The labels are float64. ``size_input`` is ``[1, X, Y, Z]``.

    Notes
    -----
    Reseeds numpy's *global* RNG with 1234 as a side effect, so the shuffle
    (and therefore what the pickled CV indices point at) is identical on
    every call.
    """
    data = sio.loadmat(mat_path)
    AD_data = data[ad_key]
    NC_data = data[nc_key]
    ad_num, sizeX, sizeY, sizeZ = AD_data.shape
    nc_num, _, _, _ = NC_data.shape
    size_input = [1, sizeX, sizeY, sizeZ]

    # The CV indices address rows of the *shuffled* array, so the permutation
    # must be reproducible.
    np.random.seed(1234)
    random_idx = np.random.permutation(ad_num + nc_num)
    adnc_data = np.concatenate((AD_data, NC_data), axis=0)[random_idx]
    # AD rows come first in the concatenation -> label 1. NC rows -> label 0.
    labels = np.hstack((np.ones(ad_num), np.zeros(nc_num)))[random_idx]

    trainid, testid, validid = load_cv(cv_name=cv_name, fold_idx=fold_index)
    x_train = adnc_data[trainid]
    y_train = labels[trainid]
    x_test = adnc_data[testid]
    y_test = labels[testid]
    x_valid = adnc_data[validid]
    y_valid = labels[validid]
    return x_train, y_train, x_test, y_test, x_valid, y_valid, size_input
# ---- Cross-validation bookkeeping ----------------------------------------
n_fold = 10
# One accuracy slot per CV fold.
train_accuracy = np.zeros(n_fold)
test_accuracy = np.zeros(n_fold)
valid_accuracy = np.zeros(n_fold)
t1_time = timeit.default_timer()

# ---- Network / training hyper-parameters ---------------------------------
# The per-fold loop (`for fi in range(n_fold):`) is disabled; only the fold
# selected by `fi` below is run.
num_classes = 2
num_filters_conv1 = 10
num_filters_conv2 = 25
num_filters_conv3 = 40
num_filters_conv4 = 40
dense_num = 100
# Scalar kernel size. Presumably applied to every spatial axis of the 3-D
# convolution (the old "[height width]" note was from 2-D code) - confirm at
# the conv call.
size_conv = 3
pool_size = 2
batch_size = 5
nb_epoch = 50

# ---- Data -----------------------------------------------------------------
fi = 0  # CV fold to run
X_train, y_train, X_test, y_test, X_val, y_val, size_input = load_data(fi)
# Insert a singleton channel axis: (n, X, Y, Z) -> (n, 1, X, Y, Z).
X_train = X_train.reshape(X_train.shape[0], 1, X_train.shape[1],
                          X_train.shape[2], X_train.shape[3])
X_val = X_val.reshape(X_val.shape[0], 1, X_val.shape[1],
                      X_val.shape[2], X_val.shape[3])
X_test = X_test.reshape(X_test.shape[0], 1, X_test.shape[1],
                        X_test.shape[2], X_test.shape[3])
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_val.shape[0] ‘
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2017-04-25 15:32 tensorflow3dCNN\
文件 11695 2017-04-25 15:31 tensorflow3dCNN\3D_CNN_tensorflow.py
文件 4208 2016-11-04 23:00 tensorflow3dCNN\confusionmatrix.py
相关资源
- TensorFlow安装错误解决:ImportError: DL
- tensorflow-2.1.0-cp36-cp36m-win_amd64.whl
- tensorflow_gpu-2.1.0-cp37-cp37m-win_amd64.whl
- tensorflow的安装、图像识别应用、训练
- Opencv+Tensorflow入门人工智能处理无密完
- tensorflow下手写汉字识别及其可视化代
- tensorflow+inceptionv3网络
- tensorflow1.12.0+gpucuda 9.0
- tensorflow-2.1.0-cp37-cp37m-win_amd64.whl
- 深度学习之TensorFlow 入门、原理与进阶
- 卸载tensorflow-cpu重装tensorflow-gpu操作
- tensorflow-2.3.0-cp37-cp37m-win_amd64.whl
- tensorflow-1.15.0-cp36-cp36m-win_amd64.whl
- tensorflow麻将智能出牌源码
- 基于TensorFlow的DenseNet学习源码
- Tensorflow快速实现图像的风格迁移
- TensorFlow平台上基于LSTM神经网络的人体
- 深度学习框架Tensorflow学习与应用
- 2018斯坦福深度学习Tensorflow实战课程课
- TensorFlow实现人脸识别(5)-------利用
- TensorFlow实现人脸识别(1)------Linux下
- batch normalization 和 layer normalization
- 基于tensorflow的猫狗图片的识别分类
- DQN-Atari-Tensorflow 在Tensorflow中,使用深
- 使用tensorflow实现VGG网络训练mnist数据
- Tensorflow tf.nn.atrous_conv2d如何实现空洞
- 基于TensorFlow的CNN实现Mnist手写数字识
- 深度学习项目实战视频课程-styleT
- 电机振动故障检测tensorflow神经网络
- tensorflow语音识别完整代码
评论
共有 条评论