- 大小: 3KB | 文件类型: .rar | 金币: 1 | 下载: 0 次 | 发布日期: 2021-06-10
- 语言: 其他
- 标签: CGAN tensorflow
资源简介
条件生成对抗网络(CGAN), tensorflow实现
代码片段和文件信息
# -*- coding: utf-8 -*-
from __future__ import division, print_function, absolute_import

import os
import random

import cv2
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Uncomment to pin the process to a specific GPU:
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Command-line flags (TensorFlow 1.x ``tf.app.flags`` API).
flags = tf.app.flags
flags.DEFINE_integer("iter", 1000000, "Iteration to train [1000000]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_string("model_path", './model/cgan.model', "Save model path ['./model/cgan.model']")
flags.DEFINE_boolean("is_train", False, "Train or test [False]")
flags.DEFINE_integer("test_number", None, "The number that want to generate if None generate randomly [None]")
FLAGS = flags.FLAGS

# Load MNIST with one-hot labels; the one-hot label vector is the condition
# fed to both the generator and the discriminator.
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

Z_dim = 100                          # width of the generator's noise input
X_dim = mnist.train.images.shape[1]  # pixels per flattened image
y_dim = mnist.train.labels.shape[1]  # number of classes (one-hot width)
h_dim = 128                          # hidden-layer width shared by G and D
def xavier_init(size):
    """Return a random-normal initial value for a weight tensor of shape ``size``.

    The standard deviation is ``1 / sqrt(fan_in / 2)`` with ``fan_in = size[0]``.
    Note: that equals ``sqrt(2 / fan_in)``, which is the He et al. scaling
    rather than Glorot/Xavier, despite the function's name.

    Args:
        size: shape list, e.g. ``[fan_in, fan_out]``.

    Returns:
        A ``tf.random_normal`` tensor of shape ``size``.
    """
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
# ---- Discriminator Net model ----
# Image input. Its width must be X_dim to match D_W1's input width
# (X_dim + y_dim) after concatenation with y, so use X_dim instead of a
# hard-coded 784.
X = tf.placeholder(tf.float32, shape=[None, X_dim])
# One-hot condition labels.
y = tf.placeholder(tf.float32, shape=[None, y_dim])

# One hidden layer over the concatenation [x, y], then a single logit.
D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

# Discriminator parameters (presumably the var_list for D's optimizer,
# which is not visible here).
theta_D = [D_W1, D_W2, D_b1, D_b2]
def discriminator(x, y):
    """Score (image, label) pairs as real vs. fake.

    Uses the module-level weights ``D_W1, D_b1, D_W2, D_b2``, so every call
    shares the same parameters.

    Args:
        x: batch of images; its width plus y's width must equal X_dim + y_dim.
        y: matching one-hot labels.

    Returns:
        ``(D_prob, D_logit)``: the sigmoid probability and the raw logit,
        each of shape ``[batch, 1]``.
    """
    inputs = tf.concat(axis=1, values=[x, y])
    D_h1 = tf.nn.relu(tf.matmul(inputs, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)
    return D_prob, D_logit
# ---- Generator Net model ----
# Noise input z.
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

# One hidden layer over the concatenation [z, y], producing X_dim pixel values.
G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

# Generator parameters (presumably the var_list for G's optimizer,
# which is not visible here).
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z, y):
    """Generate images from noise ``z`` conditioned on one-hot labels ``y``.

    Uses the module-level weights ``G_W1, G_b1, G_W2, G_b2``.

    Args:
        z: noise batch; its width plus y's width must equal Z_dim + y_dim.
        y: one-hot condition labels.

    Returns:
        ``G_prob``: pixel values in (0, 1) after a sigmoid, shape ``[batch, X_dim]``.
    """
    inputs = tf.concat(axis=1, values=[z, y])
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)
    return G_prob
def sample_Z(m, n):
    """Draw an ``m`` x ``n`` NumPy array of generator noise, uniform on [-1, 1).

    Args:
        m: number of rows (batch size).
        n: number of columns (noise dimension).

    Returns:
        ``np.ndarray`` of shape ``(m, n)``.
    """
    return np.random.uniform(-1., 1., size=[m, n])
def plot(samples):
    """Draw flattened 28x28 images on a 4x4 grid and return the figure.

    Args:
        samples: iterable of images, each reshapeable to ``(28, 28)``.
            At most 16 fit the 4x4 grid; more will overflow the GridSpec.

    Returns:
        The matplotlib ``Figure`` containing the grid.
    """
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        # Draw on ``fig`` explicitly rather than on pyplot's implicit
        # "current figure/axes", so the returned figure is the one drawn on.
        ax = fig.add_subplot(gs[i])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
# Wire the graph: G produces images conditioned on y, and D scores both the
# real pair (X, y) and the generated pair (G_sample, y). Both discriminator
# calls use the same module-level D weights.
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)
D_loss_real
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
文件 5453 2017-08-16 19:20 cgan_tensorflow.py
文件 973 2017-08-27 12:57 README.md
----------- --------- ---------- ----- ----
6426 2
相关资源
- (WGAN、WGAN_gp) Wasserstein GAN
- tensorflow识别花朵
- keras tensorflow lstm 多变量序列的预测
- 基于TensorFlow的Faster_R-CNN源码
- 在TensorFlow框架下实现DBN网络源码
- Tensorflow垃圾邮件分类
- TensorFlow视频教程
- tensorflow实现猫狗识别
- 双线性池化（Bilinear pooling）TensorFlow 版
- tensorflow 1.3 lstm训练和预测铁路客运数
- Tensorflow
- yolov1的tensorflow实现
- 基于 TensorFlow 实现 BN（Batch Normalization）的代码
- Tensorflow下构建LSTM模型
- 利用tensorflow实现3DCNN
- TensorFlow安装错误解决:ImportError: DL
- tensorflow-2.1.0-cp36-cp36m-win_amd64.whl
- tensorflow_gpu-2.1.0-cp37-cp37m-win_amd64.whl
- tensorflow的安装、图像识别应用、训练
- Opencv+Tensorflow入门人工智能处理无密完
- tensorflow下手写汉字识别及其可视化代
- tensorflow+inceptionv3网络
- tensorflow 1.12.0 + GPU（CUDA 9.0）
- tensorflow-2.1.0-cp37-cp37m-win_amd64.whl
- 深度学习之TensorFlow 入门、原理与进阶
- 卸载tensorflow-cpu重装tensorflow-gpu操作
- tensorflow-2.3.0-cp37-cp37m-win_amd64.whl
- tensorflow-1.15.0-cp36-cp36m-win_amd64.whl
- tensorflow麻将智能出牌源码
- 基于TensorFlow的DenseNet学习源码
评论
共有 条评论