• 大小: 24KB
    文件类型: .zip
    金币: 2
    下载: 1 次
    发布日期: 2021-06-17
  • 语言: Python
  • 标签:

资源简介

这是 Google BERT 模型的一个 PyTorch 重新实现。

资源截图

代码片段和文件信息

# Copyright 2018 Dong-Hyun Lee Kakao Brain.

""" Load a checkpoint file of pretrained transformer to a model in pytorch """

import numpy as np
import tensorflow as tf
import torch
#import ipdb
#from models import *

def load_param(checkpoint_file, conversion_table):
    """
    Load parameters in pytorch model from checkpoint file according to conversion_table

    Args:
        checkpoint_file: pretrained checkpoint model file in tensorflow
        conversion_table: { pytorch tensor in a model : checkpoint variable name }

    Raises:
        AssertionError: if a pytorch tensor and the (possibly transposed)
            checkpoint array loaded for it do not have the same shape.
    """
    for pyt_param, tf_param_name in conversion_table.items():
        tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)

        # for weight(kernel), we should do transpose
        # (presumably the checkpoint stores kernels with the two axes in the
        # opposite order to the pytorch weight; the shape check below catches
        # any mismatch)
        if tf_param_name.endswith('kernel'):
            tf_param = np.transpose(tf_param)

        assert pyt_param.size() == tf_param.shape, \
            'Dim Mismatch: %s vs %s ; %s' % \
                (tuple(pyt_param.size()), tf_param.shape, tf_param_name)

        # assign pytorch tensor from tensorflow param
        pyt_param.data = torch.from_numpy(tf_param)


def load_model(model, checkpoint_file):
    """ Load the pytorch model from checkpoint file

    Args:
        model: object exposing ``embed`` and an iterable ``blocks`` with the
            sub-module attributes referenced in the tables below.
        checkpoint_file: checkpoint path, passed through to ``load_param``.
    """

    # NOTE: the layer-normalization scope is spelled "LayerNorm" (capital L)
    # in released BERT checkpoints; a lower-case "layerNorm" would not match
    # any stored variable name.

    # embedding layer
    e, p = model.embed, 'bert/embeddings/'
    load_param(checkpoint_file, {
        e.tok_embed.weight: p+"word_embeddings",
        e.pos_embed.weight: p+"position_embeddings",
        e.seg_embed.weight: p+"token_type_embeddings",
        e.norm.gamma:       p+"LayerNorm/gamma",
        e.norm.beta:        p+"LayerNorm/beta",
    })

    # Transformer blocks
    for i, b in enumerate(model.blocks):
        p = "bert/encoder/layer_%d/" % i
        load_param(checkpoint_file, {
            b.attn.proj_q.weight:   p+"attention/self/query/kernel",
            b.attn.proj_q.bias:     p+"attention/self/query/bias",
            b.attn.proj_k.weight:   p+"attention/self/key/kernel",
            b.attn.proj_k.bias:     p+"attention/self/key/bias",
            b.attn.proj_v.weight:   p+"attention/self/value/kernel",
            b.attn.proj_v.bias:     p+"attention/self/value/bias",
            b.proj.weight:          p+"attention/output/dense/kernel",
            b.proj.bias:            p+"attention/output/dense/bias",
            b.pwff.fc1.weight:      p+"intermediate/dense/kernel",
            b.pwff.fc1.bias:        p+"intermediate/dense/bias",
            b.pwff.fc2.weight:      p+"output/dense/kernel",
            b.pwff.fc2.bias:        p+"output/dense/bias",
            b.norm1.gamma:          p+"attention/output/LayerNorm/gamma",
            b.norm1.beta:           p+"attention/output/LayerNorm/beta",
            b.norm2.gamma:          p+"output/LayerNorm/gamma",
            b.norm2.beta:           p+"output/LayerNorm/beta",
        })


 属性            大小     日期    时间   名称
----------- ---------  ---------- -----  ----
     目录           0  2019-06-26 13:14  pytorchic-bert-master\
     文件         129  2019-06-26 13:14  pytorchic-bert-master\.gitignore
     文件       11355  2019-06-26 13:14  pytorchic-bert-master\LICENSE
     文件        5375  2019-06-26 13:14  pytorchic-bert-master\README.md
     文件        2793  2019-06-26 13:14  pytorchic-bert-master\checkpoint.py
     文件        7945  2019-06-26 13:14  pytorchic-bert-master\classify.py
     目录           0  2019-06-26 13:14  pytorchic-bert-master\config\
     文件         168  2019-06-26 13:14  pytorchic-bert-master\config\bert_base.json
     文件         150  2019-06-26 13:14  pytorchic-bert-master\config\pretrain.json
     文件         141  2019-06-26 13:14  pytorchic-bert-master\config\train_mrpc.json
     文件        5608  2019-06-26 13:14  pytorchic-bert-master\models.py
     文件        6595  2019-06-26 13:14  pytorchic-bert-master\optim.py
     文件       10251  2019-06-26 13:14  pytorchic-bert-master\pretrain.py
     文件        8940  2019-06-26 13:14  pytorchic-bert-master\tokenization.py
     文件        4841  2019-06-26 13:14  pytorchic-bert-master\train.py
     文件        2554  2019-06-26 13:14  pytorchic-bert-master\utils.py

评论

共有 条评论