Resource description

A ResNet-110 written in TensorFlow and trained on CIFAR-10, plus an Inception-v3 network (dataset not included). DenseNet is still being written (updates to follow).
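The ResNet.py file itself is not shown in this preview. Assuming it follows the standard CIFAR-10 layout from the ResNet paper (one 3x3 stem convolution, three stages of n basic blocks with two 3x3 convolutions each, and a final fully connected layer), the "110" in ResNet-110 is the weighted-layer count 6n + 2 with n = 18. A small sketch of that arithmetic; both helper names are hypothetical:

```python
def cifar_resnet_depth(n: int) -> int:
    """Weighted-layer count of a CIFAR-style ResNet with n basic blocks per stage.

    Assumed layout: 1 stem conv + 3 stages * n blocks * 2 convs + 1 FC layer.
    """
    return 1 + 3 * 2 * n + 1


def blocks_per_stage(depth: int) -> int:
    """Invert depth = 6n + 2; reject depths that do not fit the assumed layout."""
    if (depth - 2) % 6 != 0:
        raise ValueError(f"depth {depth} is not of the form 6n + 2")
    return (depth - 2) // 6


for depth in (20, 32, 56, 110):
    print(depth, blocks_per_stage(depth))
```

Under this assumption, ResNet-110 uses 18 basic blocks in each of its three stages.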


Code snippet and file information

#!/usr/bin/env python
# -*- coding: UTF-8 -*-

import numpy as np
import tensorflow as tf
import argparse
import os


from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *

"""
CIFAR10 DenseNet example. See: http://arxiv.org/abs/1608.06993
Code is developed based on Yuxin Wu's ResNet implementation: https://github.com/ppwwyyxx/tensorpack/tree/master/examples/ResNet
Results using DenseNet (L=40, K=12) on Cifar10 with data augmentation: ~5.77% test error.

Running time:
On one TITAN X GPU (CUDA 7.5 and cudnn 5.1) the code should run ~5 iters/s with a batch size of 64.
"""

BATCH_SIZE = 64

class Model(ModelDesc):
    def __init__(self, depth):
        super(Model, self).__init__()
        # 3 dense blocks: depth = 3 * N + 4, so N = number of layers per block
        self.N = int((depth - 4) / 3)
        self.growthRate = 12

    def _get_inputs(self):
        return [InputDesc(tf.float32, [None, 32, 32, 3], 'input'),
                InputDesc(tf.int32, [None], 'label')
               ]

    def _build_graph(self, input_vars):
        image, label = input_vars
        image = image / 128.0 - 1  # scale pixels from [0, 255] to roughly [-1, 1)

        def conv(name, l, channel, stride):
            # 3x3 conv without bias, He-style init with fan = 3 * 3 * channel
            return Conv2D(name, l, channel, 3, stride=stride,
                          nl=tf.identity, use_bias=False,
                          W_init=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/channel)))

        def add_layer(name, l):
            # BN -> ReLU -> 3x3 conv producing growthRate new feature maps,
            # which are concatenated onto the input along the channel axis
            with tf.variable_scope(name) as scope:
                c = BatchNorm('bn1', l)
                c = tf.nn.relu(c)
                c = conv('conv1', c, self.growthRate, 1)
                l = tf.concat([c, l], 3)
            return l

        def add_transition(name, l):
            # BN -> ReLU -> 1x1 conv (keeps the channel count) -> 2x2 average pooling
            shape = l.get_shape().as_list()
            in_channel = shape[3]
            with tf.variable_scope(name) as scope:
                l = BatchNorm('bn1', l)
                l = tf.nn.relu(l)
                l = Conv2D('conv1', l, in_channel, 1, stride=1, use_bias=False, nl=tf.nn.relu)
                l = AvgPooling('pool', l, 2)
            return l


        def dense_net(name):
            l = conv('conv0', image, 16, 1)
            with tf.variable_scope('block1') as scope:
                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)
                l = add_transition('transition1', l)

            with tf.variable_scope('block2') as scope:
                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)
                l = add_transition('transition2', l)

            with tf.variable_scope('block3') as scope:
                for i in range(self.N):
                    l = add_layer('dense_layer.{}'.format(i), l)
            # classifier: BN -> ReLU -> global average pooling -> 10-way linear layer
            l = BatchNorm('bnlast', l)
            l = tf.nn.relu(l)
            l = GlobalAvgPooling('gap', l)
            logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)

            return logits

        logits = dense_net('dense_net')
        # ... (the snippet is truncated here in the preview)
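The channel and spatial bookkeeping of the model above can be checked without TensorFlow. The plain-Python sketch below (the function `densenet_shapes` is hypothetical, not part of tensorpack or the uploaded files) mirrors the snippet's arithmetic: N = (depth - 4) / 3 dense layers per block, each adding growthRate channels, and transitions that keep the channel count while halving height and width.

```python
def densenet_shapes(depth=40, growth_rate=12, init_channels=16, image_size=32):
    """Trace (stage, height/width, channels) through the 3-block CIFAR-10 DenseNet.

    Assumes the snippet's layout: a stride-1 3x3 stem conv with 16 filters,
    3 dense blocks of N = (depth - 4) // 3 layers each, and a transition
    (1x1 conv keeping the channel count, then 2x2 average pooling) between blocks.
    """
    if (depth - 4) % 3 != 0:
        raise ValueError(f"depth {depth} is not of the form 3N + 4")
    n = (depth - 4) // 3
    size, channels = image_size, init_channels
    trace = [("conv0", size, channels)]
    for b in range(1, 4):
        channels += n * growth_rate  # each add_layer concatenates growth_rate maps
        trace.append((f"block{b}", size, channels))
        if b < 3:                    # no transition after the last block
            size //= 2               # 2x2 average pooling halves H and W
            trace.append((f"transition{b}", size, channels))
    return trace


for stage in densenet_shapes():
    print(stage)
```

For depth 40 and growth rate 12 this gives N = 12 and 448 channels at 8x8 entering global average pooling, which the final FullyConnected layer maps to 10 logits.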

Type       Size  Date        Time   Name
----  ---------  ----------  -----  ----
File        455  2018-05-04  20:50  Architecture\.idea\Architecture.iml
File        153  2018-05-04  20:54  Architecture\.idea\codestyles\codestyleConfig.xml
File        198  2018-05-04  20:54  Architecture\.idea\codestyles\Project.xml
File         84  2018-05-04  20:54  Architecture\.idea\dictionaries\Yel.xml
File        185  2018-05-04  20:54  Architecture\.idea\misc.xml
File        276  2018-05-04  20:50  Architecture\.idea\modules.xml
File      40968  2018-05-15  19:15  Architecture\.idea\workspace.xml
File      14878  2018-05-11  10:52  Architecture\inception_v3.py
File       3798  2018-05-15  16:08  Architecture\training.py
Dir           0  2018-05-04  20:54  Architecture\.idea\codestyles
Dir           0  2018-05-04  20:54  Architecture\.idea\dictionaries
Dir           0  2018-05-15  19:15  Architecture\.idea
Dir           0  2018-05-15  18:00  Architecture
File         97  2018-05-15  17:40  Architecture\checkpoint
File       6419  2018-03-18  04:03  Architecture\cifar10-densenet.py
File       9206  2018-05-14  10:29  Architecture\data_utils.py
File        734  2018-05-15  16:53  Architecture\DenseNet.py
File       5588  2017-05-03  14:44  Architecture\helper.py
File   13963728  2018-05-15  17:40  Architecture\image_classification.data-00000-of-00001
File      11949  2018-05-15  17:40  Architecture\image_classification.index
File    8393363  2018-05-15  17:40  Architecture\image_classification.meta
File      11696  2018-05-15  18:00  Architecture\ResNet.py
File       7910  2018-05-15  16:17  Architecture\__pycache__\data_utils.cpython-36.pyc
Dir           0  2018-05-15  16:17  Architecture\__pycache__
----  ---------  ----------  -----  ----
       22471685                     24 entries

