-
大小: 8KB文件类型: .py金币: 2下载: 1 次发布日期: 2021-06-17
- 语言: Python
- 标签: tensorflow
资源简介
VGG,V3,RESNET迁移学习,tensorflow和keras写的程序
代码片段和文件信息
# -*- coding: utf-8 -*-
import os
from keras.utils import plot_model
from keras.applications.resnet50 import ResNet50
from keras.applications.vgg19 import VGG19
from keras.applications.inception_v3 import InceptionV3
from keras.layers import DenseFlattenGlobalAveragePooling2D
from keras.models import Modelload_model
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
class PowerTransferMode:
#数据准备
def DataGen(self dir_path img_row img_col batch_size is_train):
if is_train:
datagen = ImageDataGenerator(rescale=1./255 #值将在执行其他处理前乘到整个图像上,
# 我们的图像在RGB通道都是0~255的整数,
# 这样的操作可能使图像的值过高或过低,所以我们将这个值定为0~1之间的数。
zoom_range=0.25 #随机缩放的幅度
rotation_range=15. #数据提升时图片随机转动的角度
channel_shift_range=25. #随机通道偏移的幅度
width_shift_range=0.02 #数据提升时图片随机水平偏移的幅度
height_shift_range=0.02 #数据提升时图片随机竖直偏移的幅度
horizontal_flip=True #水平旋转
fill_mode=‘constant‘) #当进行变换时超出边界的点将根据本参数给定的方法进行处理
else:
datagen = ImageDataGenerator(rescale=1./255)
generator = datagen.flow_from_directory(
dir_path target_size=(img_row img_col)
batch_size=batch_size
#class_mode=‘binary‘
shuffle=is_train)
return generator
#ResNet模型
def ResNet50_model(self lr=0.005 decay=1e-6 momentum=0.9 nb_classes=2 img_rows=197 img_cols=197 RGB=True is_plot_model=False):
color = 3 if RGB else 1
base_model = ResNet50(weights=‘imagenet‘ include_top=False pooling=None input_shape=(img_rows img_cols color)
classes=nb_classes)
#冻结base_model所有层,这样就可以正确获得bottleneck特征
for layer in base_model.layers:
layer.trainable = False
x = base_model.output
#添加自己的全链接分类层
x = Flatten()(x)
#x = GlobalAveragePooling2D()(x)
#x = Dense(1024 activation=‘relu‘)(x)
predictions = Dense(nb_classes activation=‘softmax‘)(x)
#训练模型
model = Model(inputs=base_model.input outputs=predictions)
sgd = SGD(lr=lr decay=decay momentum=momentum nesterov=True)
model.compile(loss=‘categorical_crossentropy‘ optimizer=sgd metrics=[‘accuracy‘])
#绘制模型
if is_plot_model:
plot_model(model to_file=‘resnet50_model.png‘show_shapes=True)
return model
#VGG模型
def VGG19_model(self lr=0.005 decay=1e-6 momentum=0.9 nb_classes=2 img_rows=197 img_cols=197 RGB=True is_plot_model=False):
color = 3 if RGB else 1
base_model = VGG19(weights=‘imagenet‘ include_top=False pooling=
相关资源
- tensorflow制作自己的灰度图像数据集并
- anaconda下安装tensorflow(注:不同版本
- 北京大学曹健老师-人工智能实践:
- Deep Learning With Python - Jason Brownlee
- Python-自然场景文本检测PSENet的一个
- Python-高效准确的EAST文本检测器的一个
- Python-TensorFlow弱监督图像分割
- Python-基于tensorflow实现的用textcnn方法
- Python-subpixel利用Tensorflow的一个子像素
- 【官方文档】TensorFlow Python API docume
-
tensorflow画风迁移代码 st
yle transfer - 简单粗暴 TensorFlow
- [PDF] Reinforcement Learning With Open AI Tens
- tensorflow目标检测代码
- 基于Python的手写字体识别系统
- 基于Tensorflow的人脸识别源码
- python TensorFlow 官方文档中文版
- Python-在TensorFlow中实现实现图像卷积网
- tensorflow-1.9.0-cp37-cp37m-win_amd64.whl
- Faster-RCNN-TensorFlow-Python3.5-master
- 聊天机器人tensorflow
- caffe模型转化为tensorflow模型
- Python-一个非常简单的BiLSTMCRF模型用于
- Python-Tensorflow仿AlphaGo框架实现的AI围棋
- Mask R-CNN源码(TensorFlow版本)
- 基于python3 tensorflow DBN_and_RNN的实现
- tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl
- Hands-On Machine Learning with Scikit-Learn an
- python3中文识别词库模型
- tensorflow -1.4-py2.7 -cpu 版本
评论
共有 条评论