资源简介

python3, tensorflow >= 1.3 简单的英文聊天机器人基于深度学习seq2seq, 可以直接跑,结果不是很准确

资源截图

代码片段和文件信息

#! /usr/bin/python
# -*- coding: utf8 -*-
“““Sequence to Sequence Learning for Twitter/Cornell Chatbot.

References
----------
http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/
“““
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *

import tensorflow as tf
import numpy as np
import time

###============= prepare data
from data.twitter import data
metadata idx_q idx_a = data.load_data(PATH=‘data/twitter/‘)                   # Twitter
# from data.cornell_corpus import data
# metadata idx_q idx_a = data.load_data(PATH=‘data/cornell_corpus/‘)          # Cornell Moive
(trainX trainY) (testX testY) (validX validY) = data.split_dataset(idx_q idx_a)

trainX = trainX.tolist()
trainY = trainY.tolist()
testX = testX.tolist()
testY = testY.tolist()
validX = validX.tolist()
validY = validY.tolist()

trainX = tl.prepro.remove_pad_sequences(trainX)
trainY = tl.prepro.remove_pad_sequences(trainY)
testX = tl.prepro.remove_pad_sequences(testX)
testY = tl.prepro.remove_pad_sequences(testY)
validX = tl.prepro.remove_pad_sequences(validX)
validY = tl.prepro.remove_pad_sequences(validY)

###============= parameters
xseq_len = len(trainX)#.shape[-1]
yseq_len = len(trainY)#.shape[-1]
assert xseq_len == yseq_len
batch_size = 32
n_step = int(xseq_len/batch_size)
xvocab_size = len(metadata[‘idx2w‘]) # 8002 (0~8001)
emb_dim = 1024

w2idx = metadata[‘w2idx‘]   # dict  word 2 index
idx2w = metadata[‘idx2w‘]   # list index 2 word

unk_id = w2idx[‘unk‘]   # 1
pad_id = w2idx[‘_‘]     # 0

start_id = xvocab_size  # 8002
end_id = xvocab_size+1  # 8003

w2idx.update({‘start_id‘: start_id})
w2idx.update({‘end_id‘: end_id})
idx2w = idx2w + [‘start_id‘ ‘end_id‘]

xvocab_size = yvocab_size = xvocab_size + 2

“““ A data for Seq2Seq should look like this:
input_seqs : [‘how‘ ‘are‘ ‘you‘ ‘]
decode_seqs : [‘‘ ‘I‘ ‘am‘ ‘fine‘ ‘]
target_seqs : [‘I‘ ‘am‘ ‘fine‘ ‘‘ ‘]
target_mask : [1 1 1 1 0]
“““

print(“encode_seqs“ [idx2w[id] for id in trainX[10]])
target_seqs = tl.prepro.sequences_add_end_id([trainY[10]] end_id=end_id)[0]
    # target_seqs = tl.prepro.remove_pad_sequences([target_seqs] pad_id=pad_id)[0]
print(“target_seqs“ [idx2w[id] for id in target_seqs])
decode_seqs = tl.prepro.sequences_add_start_id([trainY[10]] start_id=start_id remove_last=False)[0]
    # decode_seqs = tl.prepro.remove_pad_sequences([decode_seqs] pad_id=pad_id)[0]
print(“decode_seqs“ [idx2w[id] for id in decode_seqs])
target_mask = tl.prepro.sequences_get_mask([target_seqs])[0]
print(“target_mask“ target_mask)
print(len(target_seqs) len(decode_seqs) len(target_mask))

###============= model
def model(encode_seqs decode_seqs is_train=True reuse=False):
    with tf.variable_scope(“model“ reuse=reuse):
        # for chatbot you can use the same embedding layer
        # for translation you may want to use 2 seperated embedding layers
        with tf.variable_scope(“embedding“) as vs:
        

 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2017-11-05 23:01  seq2seq-chatbot-master\
     文件         156  2017-11-05 23:01  seq2seq-chatbot-master\.gitignore
     文件        1247  2017-11-05 23:01  seq2seq-chatbot-master\README.md
     目录           0  2017-11-05 23:01  seq2seq-chatbot-master\data\
     文件          91  2017-11-05 23:01  seq2seq-chatbot-master\data\__init__.py
     目录           0  2017-11-05 23:01  seq2seq-chatbot-master\data\cornell_corpus\
     文件       11453  2017-11-05 23:01  seq2seq-chatbot-master\data\cornell_corpus\data.py
     目录           0  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\
     文件        7459  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\data.py
     文件    10433840  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\idx_a.npy
     文件    10433840  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\idx_q.npy
     文件     2877112  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\metadata.pkl
     文件         119  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\pull
     文件         101  2017-11-05 23:01  seq2seq-chatbot-master\data\twitter\pull_raw_data
     文件        9656  2017-11-05 23:01  seq2seq-chatbot-master\main_simple_seq2seq.py

评论

共有 条评论