资源简介
针对已训练好的 TensorFlow 模型（该模型是根据自身需要训练的），将其应用到遥感影像分类中，并显示分类结果。
代码片段和文件信息
import datetime
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scio
import tensorflow as tf
# Restrict TensorFlow to GPUs 0 and 1; this only takes effect if set before
# CUDA is initialised (i.e. before the first session is created).
# NOTE(review): the scraped source read '01' with all commas stripped; '0,1'
# is the reconstructed device list - confirm against the original script.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
def discrete_matshow(data, labels_names=None, title=""):
    """Show a label map with one discrete colour per integer value.

    The colour limits are widened by 0.5 on each side so that every integer
    value from ``min(data)`` to ``max(data)`` sits in the centre of its own
    colour bin, and the colorbar is ticked at each of those integers.

    Args:
        data: 2-D array passed to ``plt.matshow``; assumed to hold integer
            class labels (values are binned as consecutive integers).
        labels_names: optional colorbar tick labels, one per integer value
            from ``min(data)`` to ``max(data)``. ``None`` or empty keeps the
            numeric ticks.
        title: optional figure title, drawn bold at font size 14.
    """
    lo = np.min(data)
    hi = np.max(data)
    # One colormap entry per integer value in [lo, hi]; int() so float-typed
    # label maps do not pass a float LUT size to get_cmap.
    cmap = plt.get_cmap('Paired', int(hi - lo + 1))
    # Set limits 0.5 outside the true range so each integer is bin-centred.
    mat = plt.matshow(data, cmap=cmap, vmin=lo - 0.5, vmax=hi + 0.5)
    # Tell the colorbar to tick at integers.
    cbar = plt.colorbar(mat, ticks=np.arange(lo, hi + 1))
    # Names printed beside the colorbar ticks.
    if labels_names:
        cbar.ax.set_yticklabels(labels_names)
    if title:
        plt.suptitle(title, fontsize=14, fontweight='bold')
def next_batch(image, ii, h, half=14):
    """Cut the square patches centred on row ``ii`` of ``image``.

    Patches are ``2 * half`` pixels on a side and are taken for every column
    centre ``j`` in ``range(half, h - half)``, giving ``h - 2 * half``
    patches. This range stops one short of the last fully-contained centre
    (``j = h - half``); the count is kept unchanged because downstream code
    may depend on it.

    Args:
        image: array indexed as ``image[row, col, channel]``.
        ii: row index of the patch centres.
        h: number of columns to scan (presumably ``image.shape[1]``).
        half: half the patch side; the default 14 gives 28x28 patches.

    Returns:
        ndarray of shape ``(h - 2*half, 2*half, 2*half, channels)``; an empty
        array when ``h <= 2 * half``.

    Raises:
        ValueError: if the rows ``ii - half .. ii + half - 1`` are not all
            inside ``image`` (a negative start would otherwise wrap around,
            and an overrun would silently yield short patches).
    """
    rows = image.shape[0]
    if ii < half or ii + half > rows:
        raise ValueError(
            "row centre {} needs rows {}..{}, but image has {} rows".format(
                ii, ii - half, ii + half - 1, rows))
    # The row band is the same for every patch; only the column window moves.
    band = image[ii - half:ii + half]
    return np.array([band[:, j - half:j + half, :]
                     for j in range(half, h - half)])
# Load the image to classify and scale it for the network.
img = cv2.imread('jimo_resize_2000.tif')
# cv2.imread reports a missing or unreadable file by returning None instead of
# raising; fail here with a clear message rather than inside cvtColor.
if img is None:
    raise FileNotFoundError("cannot read image 'jimo_resize_2000.tif'")
# OpenCV loads channels in BGR order; convert to RGB (presumably the order the
# model was trained on - confirm).
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Scale pixel values from [0, 255] to [0, 1] (result is float64).
img = np.multiply(img, 1.0 / 255.0)
print(img.shape)
m = img.shape[0]  # image height (rows)
n = img.shape[1]  # image width (columns)

print('load the model....')
# Rebuild the saved graph into the default graph; the returned Saver is used
# to restore the checkpoint weights once a session is open.
vgg_saver = tf.train.import_meta_graph('2017.09.11-03.31.ckpt.meta')
vgg_graph = tf.get_default_graph()
# Input placeholders of the restored graph, looked up by their auto-generated
# names. NOTE(review): which input each one feeds is not visible here.
x = vgg_graph.get_tensor_by_name('Placeholder:0')
z = vgg_graph.get_tensor_by_name('Placeholder_1:0')
# Output of the last fully connected layer (pre-softmax scores).
feature = vgg_graph.get_tensor_by_name("D_conv_mnist/fully_connected_2/BiasAdd:0")
print(feature)
# Per-class probabilities.
pred = tf.nn.softmax(feature)

print('extract jimo image feature...')
result = []
start_time = datetime.datetime.now()
with tf.Session() as sess:
vgg_saver.restore(sess ‘./2017.09.11-03.31.ckpt‘)
i = 14
segmentation_ = []
z_sample
评论
共有 条评论