资源简介
Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434) which is a stabilize Generative Adversarial Networks.
代码片段和文件信息
import os
import numpy as np
import tensorflow as tf
from model import DCGAN
from utils import pp visualize show_all_variables
flags = tf.app.flags
flags.DEFINE_integer(“epoch“ 100 “训练轮次,默认100次“)
flags.DEFINE_float(“learning_rate“ 0.0002 “学习速率,默认0.0002“)
flags.DEFINE_float(“beta1“ 0.5 “Adam动量,默认0.5“)
flags.DEFINE_integer(“train_size“ np.inf “每个轮次训练的次数,默认为np.inf“)
flags.DEFINE_integer(“batch_size“ 64 “每次训练引入的数据量,默认为64“)
flags.DEFINE_integer(“input_height“ 256 “图片的输入高度,默认为256“)
flags.DEFINE_integer(“input_width“ None “图片的输入宽度,默认为空;如果为空,则跟高度一致“)
flags.DEFINE_integer(“output_height“ 128 “图片的输出高度,默认为128“)
flags.DEFINE_integer(“output_width“ None “图片的输出宽度,默认为空;如果为空,则跟高度一致“)
flags.DEFINE_string(“dataset“ “celebA“ “The name of dataset [celebA mnist lsun]“)
flags.DEFINE_string(“input_fname_pattern“ “*.jpg“ “输入的图片格式[*]“)
flags.DEFINE_string(“checkpoint_dir“ “checkpoint“ “模型的保存路径,默认为checkpoint“)
flags.DEFINE_string(“sample_dir“ “samples“ “保存例子的文件夹名,默认为samples“)
flags.DEFINE_boolean(“train“ False “如果为真则训练,否则进行测试,默认为假“)
flags.DEFINE_boolean(“crop“ True “如果为真则裁剪,否则不裁剪,默认为真“)
flags.DEFINE_boolean(“visualize“ True “是否为可视化,默认为假“)
FLAGS = flags.FLAGS
def main(_):
pp.pprint(flags.FLAGS.__flags)
if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
run_config = tf.ConfigProto(allow_soft_placement=True)
run_config.gpu_options.allow_growth=True
with tf.Session(config=run_config) as sess:
if FLAGS.dataset == ‘mnist‘:
dcgan = DCGAN(
sess
input_width=FLAGS.input_width
input_height=FLAGS.input_height
output_width=FLAGS.output_width
output_height=FLAGS.output_height
batch_size=FLAGS.batch_size
sample_num=FLAGS.batch_size
y_dim=10
dataset_name=FLAGS.dataset
input_fname_pattern=FLAGS.input_fname_pattern
crop=FLAGS.crop
checkpoint_dir=FLAGS.checkpoint_dir
sample_dir=FLAGS.sample_dir)
else:
dcgan = DCGAN(
sess
input_width=FLAGS.input_width
input_height=FLAGS.input_height
output_width=FLAGS.output_width
output_height=FLAGS.output_height
batch_size=FLAGS.batch_size
sample_num=FLAGS.batch_size
dataset_name=FLAGS.dataset
input_fname_pattern=FLAGS.input_fname_pattern
crop=FLAGS.crop
checkpoint_dir=FLAGS.checkpoint_dir
sample_dir=FLAGS.sample_dir)
show_all_variables()
if FLAGS.train:
dcgan.train(FLAGS)
else:
if not dcgan.load(FLAGS.checkpoint_dir)[0]:
raise Exception(“[!] Train a model first then run test mode“)
OPTION = 1
visualize(sess dcgan F
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 7840016 1998-01-26 23:07 DC-GAN\data\mnist\t10k-images.idx3-ubyte
文件 10008 1998-01-26 23:07 DC-GAN\data\mnist\t10k-labels.idx1-ubyte
文件 47040016 1996-11-18 23:36 DC-GAN\data\mnist\train-images.idx3-ubyte
文件 60008 1996-11-18 23:36 DC-GAN\data\mnist\train-labels.idx1-ubyte
文件 150616 2017-07-20 09:21 DC-GAN\DCGAN.png
文件 3449 2017-08-31 18:40 DC-GAN\main.py
文件 20211 2017-08-31 17:21 DC-GAN\model.py
文件 3493 2017-08-31 16:00 DC-GAN\ops.py
文件 3338 2017-07-20 09:21 DC-GAN\README.md
文件 8821 2017-08-31 17:31 DC-GAN\utils.py
目录 0 2017-08-31 17:11 DC-GAN\data\mnist
目录 0 2017-08-31 17:14 DC-GAN\data
目录 0 2017-08-31 18:41 DC-GAN
----------- --------- ---------- ----- ----
55139976 13
- 上一篇:OpenCV自带视频测试文件vtest.avi
- 下一篇:仿小米运动的运动记录界面
评论
共有 条评论