Source code for kashgari.embeddings.bert_embedding

# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: bert_embedding.py
# time: 2:49 下午

import os
from typing import Dict, Any

from kashgari.embeddings.transformer_embedding import TransformerEmbedding


[docs]class BertEmbedding(TransformerEmbedding): """ BertEmbedding is a simple wrapped class of TransformerEmbedding. If you need load other kind of transformer based language model, please use the TransformerEmbedding. """
[docs] def to_dict(self) -> Dict[str, Any]: info_dic = super(BertEmbedding, self).to_dict() info_dic['config']['model_folder'] = self.model_folder return info_dic
[docs] def __init__(self, model_folder: str, **kwargs: Any): """ Args: model_folder: path of checkpoint folder. kwargs: additional params """ self.model_folder = model_folder vocab_path = os.path.join(self.model_folder, 'vocab.txt') config_path = os.path.join(self.model_folder, 'bert_config.json') checkpoint_path = os.path.join(self.model_folder, 'bert_model.ckpt') kwargs['vocab_path'] = vocab_path kwargs['config_path'] = config_path kwargs['checkpoint_path'] = checkpoint_path kwargs['model_type'] = 'bert' super(BertEmbedding, self).__init__(**kwargs)
if __name__ == "__main__": pass