- Size: 24KB
- File type: .zip
- Coins: 2
- Downloads: 1
- Release date: 2021-06-17
- Language: Python
- Tags:
Resource description
This is a PyTorch re-implementation of Google's BERT model.
Code snippet and file information
# Copyright 2018 Dong-Hyun Lee, Kakao Brain.

""" Load a checkpoint file of pretrained transformer to a model in pytorch """

import numpy as np
import tensorflow as tf
import torch

#import ipdb
#from models import *

def load_param(checkpoint_file, conversion_table):
    """
    Load parameters in pytorch model from checkpoint file according to conversion_table
    checkpoint_file : pretrained checkpoint model file in tensorflow
    conversion_table : { pytorch tensor in a model : checkpoint variable name }
    """
    for pyt_param, tf_param_name in conversion_table.items():
        tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)

        # TF stores a Linear weight (kernel) as (in, out); PyTorch expects
        # (out, in), so kernels must be transposed
        if tf_param_name.endswith('kernel'):
            tf_param = np.transpose(tf_param)

        assert pyt_param.size() == tf_param.shape, \
            'Dim Mismatch: %s vs %s ; %s' % \
            (tuple(pyt_param.size()), tf_param.shape, tf_param_name)

        # assign pytorch tensor from tensorflow param
        pyt_param.data = torch.from_numpy(tf_param)

def load_model(model, checkpoint_file):
    """ Load the pytorch model from checkpoint file """

    # embedding layer
    e, p = model.embed, 'bert/embeddings/'
    load_param(checkpoint_file, {
        e.tok_embed.weight: p + "word_embeddings",
        e.pos_embed.weight: p + "position_embeddings",
        e.seg_embed.weight: p + "token_type_embeddings",
        e.norm.gamma:       p + "LayerNorm/gamma",
        e.norm.beta:        p + "LayerNorm/beta"
    })

    # Transformer blocks
    for i in range(len(model.blocks)):
        b, p = model.blocks[i], "bert/encoder/layer_%d/" % i
        load_param(checkpoint_file, {
            b.attn.proj_q.weight: p + "attention/self/query/kernel",
            b.attn.proj_q.bias:   p + "attention/self/query/bias",
            b.attn.proj_k.weight: p + "attention/self/key/kernel",
            b.attn.proj_k.bias:   p + "attention/self/key/bias",
            b.attn.proj_v.weight: p + "attention/self/value/kernel",
            b.attn.proj_v.bias:   p + "attention/self/value/bias",
            b.proj.weight:        p + "attention/output/dense/kernel",
            b.proj.bias:          p + "attention/output/dense/bias",
            b.pwff.fc1.weight:    p + "intermediate/dense/kernel",
            b.pwff.fc1.bias:      p + "intermediate/dense/bias",
            b.pwff.fc2.weight:    p + "output/dense/kernel",
            b.pwff.fc2.bias:      p + "output/dense/bias",
            b.norm1.gamma:        p + "attention/output/LayerNorm/gamma",
            b.norm1.beta:         p + "attention/output/LayerNorm/beta",
            b.norm2.gamma:        p + "output/LayerNorm/gamma",
            b.norm2.beta:         p + "output/LayerNorm/beta",
        })
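For context, here is a minimal usage sketch (not part of the archived file): it assumes models.py in this archive exposes a Config and a Transformer whose submodules match the conversion tables above, and that a Google BERT TensorFlow checkpoint has been downloaded; both file paths are placeholders.

# Minimal usage sketch (assumption: models.Config / models.Transformer come
# from this repo's models.py; both paths below are hypothetical placeholders).
import models
import checkpoint

cfg = models.Config.from_json('config/bert_base.json')  # model hyperparameters
model = models.Transformer(cfg)                         # randomly initialized BERT
# copy the pretrained TensorFlow weights into the PyTorch tensors in place
checkpoint.load_model(model, 'uncased_L-12_H-768_A-12/bert_model.ckpt')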
Attribute   Size      Date        Time   Name
----------- --------- ----------  -----  ----
Directory          0  2019-06-26  13:14  pytorchic-bert-master\
File             129  2019-06-26  13:14  pytorchic-bert-master\.gitignore
File           11355  2019-06-26  13:14  pytorchic-bert-master\LICENSE
File            5375  2019-06-26  13:14  pytorchic-bert-master\README.md
File            2793  2019-06-26  13:14  pytorchic-bert-master\checkpoint.py
File            7945  2019-06-26  13:14  pytorchic-bert-master\classify.py
Directory          0  2019-06-26  13:14  pytorchic-bert-master\config\
File             168  2019-06-26  13:14  pytorchic-bert-master\config\bert_ba
File             150  2019-06-26  13:14  pytorchic-bert-master\config\pretrain.json
File             141  2019-06-26  13:14  pytorchic-bert-master\config\train_mrpc.json
File            5608  2019-06-26  13:14  pytorchic-bert-master\models.py
File            6595  2019-06-26  13:14  pytorchic-bert-master\optim.py
File           10251  2019-06-26  13:14  pytorchic-bert-master\pretrain.py
File            8940  2019-06-26  13:14  pytorchic-bert-master\tokenization.py
File            4841  2019-06-26  13:14  pytorchic-bert-master\train.py
File            2554  2019-06-26  13:14  pytorchic-bert-master\utils.py