• Size: 12KB
    File type: .py
    Coins: 1
    Downloads: 0
    Posted: 2021-05-13
  • Language: Python
  • Tags: tensorflow

Resource Description

Builds a four-layer neural network to recognize the MNIST handwritten-digit dataset, then inserts the CBAM attention module after the first layer of the network to examine the module's effect on performance. The insertion point can be changed, so CBAM can be inserted after any layer; a sketch of the insertion follows.
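
A minimal sketch of what that insertion looks like, assuming the helper functions defined in the code snippet below (get_weights, get_biases, convolution_2d, max_pooling, convolutional_block_attention_module). The 5x5-kernel, 32-channel first layer is an illustrative assumption, not necessarily the layer sizes used in the file:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
x_image = tf.reshape(x, [-1, 28, 28, 1])

# first convolutional layer: 5x5 kernels, 32 output channels
w_conv1 = get_weights([5, 5, 1, 32])
b_conv1 = get_biases([32])
h_conv1 = tf.nn.relu(convolution_2d(x_image, w_conv1) + b_conv1)

# insert CBAM right after the first layer; calling it after a different
# feature map instead moves the insertion point anywhere in the network
h_cbam1 = convolutional_block_attention_module(h_conv1, index=1)

h_pool1 = max_pooling(h_cbam1)

Moving the convolutional_block_attention_module call (with a new index) to any other feature map is all that "arbitrary insertion" requires.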

Code Snippet and File Information

# TensorFlow 1.x APIs: tf.contrib and tensorflow.examples.tutorials
# were removed in TensorFlow 2.x
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
slim = tf.contrib.slim

# create weight variables for each layer
def get_weights(shape):
    data = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(data)

def get_biases(shape):
    data = tf.constant(0.1, shape=shape)
    return tf.Variable(data)

# 2-D convolution with stride 1 and SAME padding
def convolution_2d(x, w):
    return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')

# 2x2 max pooling with stride 2
def max_pooling(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


def combined_static_and_dynamic_shape(tensor):
    """Returns a list of static and dynamic values for the shape dimensions.

    Useful for preserving static shapes, when available, in a reshape
    operation.

    Args:
        tensor: A tensor of any type.

    Returns:
        A list of size tensor.shape.ndims, each element an integer (static
        dimension) or a scalar tensor (dynamic dimension).
    """
    static_tensor_shape = tensor.shape.as_list()
    dynamic_tensor_shape = tf.shape(tensor)
    combined_shape = []
    for index, dim in enumerate(static_tensor_shape):
        if dim is not None:
            combined_shape.append(dim)
        else:
            combined_shape.append(dynamic_tensor_shape[index])
    return combined_shape
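
# Illustrative usage of the helper above (an added example, not part of the
# original file): with a dynamic batch dimension, the combined shape keeps
# the known integers and substitutes a scalar tensor only where the static
# shape is None, so the reshape below still carries static sizes.
example_images = tf.placeholder(tf.float32, [None, 28, 28, 32])
example_shape = combined_static_and_dynamic_shape(example_images)
# example_shape == [<scalar batch tensor>, 28, 28, 32]
example_flat = tf.reshape(example_images, [example_shape[0], 28 * 28 * 32])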


def convolutional_block_attention_module(feature_map, index, reduction_ratio=0.5):
    """CBAM: convolutional block attention module.

    Args:
        feature_map: input feature map, shape [batch, height, width, channels].
        index: index of the module, used to name its variable scope.
        reduction_ratio: the first MLP layer has
            reduction_ratio * channels output units.

    Returns:
        The feature map reweighted by channel and spatial attention.
    """
    with tf.variable_scope("cbam_%s" % index):
        feature_map_shape = combined_static_and_dynamic_shape(feature_map)
        # channel attention module
        channel_avg_weights = tf.nn.avg_pool(value=feature_map,
                                             ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                             strides=[1, 1, 1, 1],
                                             padding='VALID')  # global average pooling
        channel_max_weights = tf.nn.max_pool(value=feature_map,
                                             ksize=[1, feature_map_shape[1], feature_map_shape[2], 1],
                                             strides=[1, 1, 1, 1],
                                             padding='VALID')  # global max pooling
        channel_avg_reshape = tf.reshape(channel_avg_weights,
                                         [feature_map_shape[0], 1, feature_map_shape[3]])
        channel_max_reshape = tf.reshape(channel_max_weights,
                                         [feature_map_shape[0], 1, feature_map_shape[3]])
        channel_w_reshape = tf.concat([channel_avg_reshape, channel_max_reshape], axis=1)
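
        # ---- the site's preview cuts the listing off at this point ----
        # Hedged sketch of how a standard CBAM module typically continues
        # (after Woo et al., 2018). The layer names (fc_1, fc_2,
        # spatial_conv) and exact ops below are assumptions, not recovered
        # from the original file.
        channels = feature_map_shape[3]
        # shared two-layer MLP applied to both pooled descriptors
        fc_1 = tf.layers.dense(inputs=channel_w_reshape,
                               units=int(channels * reduction_ratio),
                               activation=tf.nn.relu, name="fc_1")
        fc_2 = tf.layers.dense(inputs=fc_1, units=channels, name="fc_2")
        # merge the avg- and max-pooled branches and squash to (0, 1)
        channel_attention = tf.sigmoid(tf.reduce_sum(fc_2, axis=1))
        channel_attention = tf.reshape(channel_attention, [-1, 1, 1, channels])
        feature_map = feature_map * channel_attention

        # spatial attention module: pool across channels, 7x7 conv, sigmoid
        spatial_avg = tf.reduce_mean(feature_map, axis=3, keepdims=True)
        spatial_max = tf.reduce_max(feature_map, axis=3, keepdims=True)
        spatial_concat = tf.concat([spatial_avg, spatial_max], axis=3)
        spatial_attention = tf.layers.conv2d(spatial_concat, filters=1,
                                             kernel_size=7, padding="same",
                                             activation=tf.sigmoid,
                                             name="spatial_conv")
        return feature_map * spatial_attention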
