• 大小: 22.17MB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2023-07-06
  • 语言: 其他
  • 标签: GAN  手写数字  

资源简介

利用GAN原始模型,生成手写数字,包含数据集和代码,直接可以用。

资源截图

代码片段和文件信息

import tensorflow as tf
import numpy as np
from GAN.TWO import util
import os
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#读入数据
mnist=input_data.read_data_sets(‘./data‘one_hot=True)
# print(mnist)

Z=tf.placeholder(tf.float32shape=[None100])

X=tf.placeholder(tf.float32shape=[None784])
#喂入数据
G_sample=util.generator(Z)
D_realD_logit_real=util.discriminator(X)
D_fakeD_logit_fake=util.discriminator(G_sample)
#计算loss
D_real_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real
                                        labels=tf.ones_like(D_logit_real)))
D_fake_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake
                                        labels=tf.zeros_like(D_logit_fake)))
D_loss=D_fake_loss+D_real_loss

G_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake
                                        labels=tf.ones_like(D_logit_fake)))

D_optimizer=tf.train.AdamOptimizer().minimize(D_lossvar_list=util.theta_D)
G_optimizer=tf.train.AdamOptimizer().minimize(G_lossvar_list=util.theta_G)

if not os.path.exists(‘out/‘):
    os.makedirs(‘out/‘)
“““
画图
“““
def plot(samples):
    gs=gridspec.GridSpec(44)
    gs.update(wspace=0.05hspace=.05)
    for isample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis(‘off‘)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect(‘equal‘)

        plt.imshow(sample.reshape(2828)cmap=‘Greys_r‘)
print(“=====================开始训练============================“)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(10000):
        X_mb_=mnist.train.next_batch(batch_size=128)
        # print(X_mb)
        _D_loss_curr=sess.run([D_optimizerD_loss]
                               feed_dict={X:X_mbZ:util.sample_z(128100)})
        _ G_loss_curr = sess.run([G_optimizer G_loss]
                                  feed_dict={Z: util.sample_z(128 100)})
        if it%1000==0:
            print(‘====================打印出生成的数据============================‘)
            samples=sess.run(G_samplefeed_dict={Z: util.sample_z(16 100)})
            plot(samples)
            plt.show()
        if it%1000==0:
            print(‘iter={}‘.format(it))
            print(‘D_loss={}‘.format(D_loss_curr))
            print(‘G_loss={}‘.format(G_loss_curr))

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2018-07-29 05:06  TWO\
     文件        1218  2018-07-28 16:03  TWO\util.py
     文件        3591  2018-07-28 15:38  TWO\test.py
     文件        2526  2018-07-29 05:06  TWO\main.py
     目录           0  2018-07-28 15:44  TWO\out\
     文件       18932  2018-07-28 15:44  TWO\out\001.png
     文件       31842  2018-07-28 15:47  TWO\out\000.png
     目录           0  2018-07-28 16:04  TWO\__pycache__\
     文件        1534  2018-07-28 16:04  TWO\__pycache__\util.cpython-35.pyc
     目录           0  2018-07-28 15:15  TWO\data\
     文件        4542  2018-07-28 15:15  TWO\data\t10k-labels-idx1-ubyte.gz
     文件     1648877  2018-07-28 15:15  TWO\data\t10k-images-idx3-ubyte.gz
     文件       28881  2018-07-28 15:15  TWO\data\train-labels-idx1-ubyte.gz
     文件     9912422  2018-07-28 15:15  TWO\data\train-images-idx3-ubyte.gz
     文件       60008  2018-07-28 15:06  TWO\data\train-labels-idx1-ubyte
     文件    47040016  2018-07-28 15:09  TWO\data\train-images-idx3-ubyte
     文件       10008  2018-07-28 15:09  TWO\data\t10k-labels-idx1-ubyte
     文件     7840016  2018-07-28 15:12  TWO\data\t10k-images-idx3-ubyte

评论

共有 条评论