• 大小: 1.14MB
    文件类型: .zip
    金币: 1
    下载: 0 次
    发布日期: 2023-09-02
  • 语言: Python
  • 标签: Keras  Attention  Python  

资源简介

该代码为基于Keras的attention实战,环境配置: Wn10+CPU i7-6700 、Pycharm 2018、 python 3.6 、、numpy 1.14.5 、Keras 2.0.2 Matplotlib 2.2.2 经过小编亲自调试,可以使用,适合初学者从代码的角度了解attention机制。

资源截图

代码片段和文件信息

import numpy as np

from attention_utils import get_activations get_data

np.random.seed(1337)  # for reproducibility
from keras.models import *
from keras.layers import Input Dense merge

input_dim = 32


def build_model():
    inputs = Input(shape=(input_dim))

    # ATTENTION PART STARTS HERE
    attention_probs = Dense(input_dim activation=‘softmax‘ name=‘attention_vec‘)(inputs)
    attention_mul = merge([inputs attention_probs] output_shape=32 name=‘attention_mul‘ mode=‘mul‘)
    # ATTENTION PART FINISHES HERE

    attention_mul = Dense(64)(attention_mul)
    output = Dense(1 activation=‘sigmoid‘)(attention_mul)
    model = Model(input=[inputs] output=output)
    return model


if __name__ == ‘__main__‘:
    N = 10000
    inputs_1 outputs = get_data(N input_dim)

    m = build_model()
    m.compile(optimizer=‘adam‘ loss=‘binary_crossentropy‘ metrics=[‘accuracy‘])
    print(m.summary())

    m.fit([inputs_1] outputs epochs=20 batch_size=64 validation_split=0.5)

    testing_inputs_1 testing_outputs = get_data(1 input_dim)

    # Attention vector corresponds to the second matrix.
    # The first one is the Inputs output.
    attention_vector = get_activations(m testing_inputs_1
                                       print_shape_only=True
                                       layer_name=‘attention_vec‘)[0].flatten()
    print(‘attention =‘ attention_vector)

    # plot part.
    import matplotlib.pyplot as plt
    import pandas as pd

    pd.Dataframe(attention_vector columns=[‘attention (%)‘]).plot(kind=‘bar‘
                                                                   title=‘Attention Mechanism as ‘
                                                                         ‘a function of input‘
                                                                         ‘ dimensions.‘)
    plt.show()

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2017-10-15 13:08  keras-attention-mechanism-master\
     文件        1163  2017-10-15 13:08  keras-attention-mechanism-master\.gitignore
     文件       11357  2017-10-15 13:08  keras-attention-mechanism-master\LICENSE
     文件        4880  2017-10-15 13:08  keras-attention-mechanism-master\README.md
     目录           0  2017-10-15 13:08  keras-attention-mechanism-master\assets\
     文件       45984  2017-10-15 13:08  keras-attention-mechanism-master\assets\1.png
     文件      215990  2017-10-15 13:08  keras-attention-mechanism-master\assets\attention_1.png
     文件      437259  2017-10-15 13:08  keras-attention-mechanism-master\assets\graph_multi_attention.png
     文件      443997  2017-10-15 13:08  keras-attention-mechanism-master\assets\graph_single_attention.png
     文件       47113  2017-10-15 13:08  keras-attention-mechanism-master\assets\lstm_after.png
     文件       51615  2017-10-15 13:08  keras-attention-mechanism-master\assets\lstm_before.png
     文件        1865  2017-10-15 13:08  keras-attention-mechanism-master\attention_dense.py
     文件        3430  2017-10-15 13:08  keras-attention-mechanism-master\attention_lstm.py
     文件        2615  2017-10-15 13:08  keras-attention-mechanism-master\attention_utils.py
     文件          60  2017-10-15 13:08  keras-attention-mechanism-master\requirements.txt

评论

共有 条评论