Source code for kashgari.embeddings.transformer_embedding
# encoding: utf-8
# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz
# file: transformer_embedding.py
# time: 11:41 AM
import codecs
import json
from typing import Dict, List, Any, Optional
from bert4keras.models import build_transformer_model
from kashgari.embeddings.abc_embedding import ABCEmbedding
from kashgari.logger import logger
class TransformerEmbedding(ABCEmbedding):
    """
    TransformerEmbedding is based on bert4keras.
    The embedding itself is wrapped into our simple embedding interface so that it can be used like any other embedding.
    """
    def to_dict(self) -> Dict[str, Any]:
        info_dic = super(TransformerEmbedding, self).to_dict()
        info_dic['config']['vocab_path'] = self.vocab_path
        info_dic['config']['config_path'] = self.config_path
        info_dic['config']['checkpoint_path'] = self.checkpoint_path
        info_dic['config']['model_type'] = self.model_type
        return info_dic
    def __init__(self,
                 vocab_path: str,
                 config_path: str,
                 checkpoint_path: str,
                 model_type: str = 'bert',
                 **kwargs: Any):
        """
        Args:
            vocab_path: vocab file path, for example `vocab.txt`
            config_path: model config path, for example `config.json`
            checkpoint_path: model weight path, for example `model.ckpt-100000`
            model_type: transformer model type, one of {bert, albert, nezha, gpt2_ml, t5}
            kwargs: additional params
        """
        self.vocab_path = vocab_path
        self.config_path = config_path
        self.checkpoint_path = checkpoint_path
        self.model_type = model_type
        self.vocab_list: List[str] = []
        kwargs['segment'] = True
        super(TransformerEmbedding, self).__init__(**kwargs)
    def load_embed_vocab(self) -> Optional[Dict[str, int]]:
        # Build the token -> index mapping from the vocab file:
        # one token per line, indexed in file order.
        token2idx: Dict[str, int] = {}
        with codecs.open(self.vocab_path, 'r', 'utf8') as reader:
            for line in reader:
                token = line.strip()
                self.vocab_list.append(token)
                token2idx[token] = len(token2idx)

        top_words = [k for k, v in list(token2idx.items())[:50]]
        logger.debug('------------------------------------------------')
        logger.debug("Loaded transformer model's vocab")
        logger.debug(f'config_path     : {self.config_path}')
        logger.debug(f'vocab_path      : {self.vocab_path}')
        logger.debug(f'checkpoint_path : {self.checkpoint_path}')
        logger.debug(f'Top 50 words    : {top_words}')
        logger.debug('------------------------------------------------')
        return token2idx
    def build_embedding_model(self,
                              *,
                              vocab_size: Optional[int] = None,
                              force: bool = False,
                              **kwargs: Dict) -> None:
        if self.embed_model is None:
            config_path = self.config_path
            with open(config_path, 'r') as f:
                config = json.loads(f.read())
            # Checkpoints name the maximum sequence length differently,
            # so check both keys.
            if 'max_position' in config:
                self.max_position = config['max_position']
            else:
                self.max_position = config.get('max_position_embeddings')

            bert_model = build_transformer_model(config_path=self.config_path,
                                                 checkpoint_path=self.checkpoint_path,
                                                 model=self.model_type,
                                                 application='encoder',
                                                 return_keras_model=True)
            # Freeze all transformer layers so the pretrained weights are not updated.
            for layer in bert_model.layers:
                layer.trainable = False
            self.embed_model = bert_model
            self.embedding_size = bert_model.output.shape[-1]
if __name__ == "__main__":
    pass
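# --- Usage sketch (illustrative, not part of the module) ---------------------
# A minimal example of pointing the embedding at a pretrained BERT checkpoint.
# The directory and file names below are hypothetical placeholders; replace
# them with the files from a real checkpoint release before running.
#
#     embedding = TransformerEmbedding(vocab_path='<bert_dir>/vocab.txt',
#                                      config_path='<bert_dir>/bert_config.json',
#                                      checkpoint_path='<bert_dir>/bert_model.ckpt',
#                                      model_type='bert')
#     embedding.build_embedding_model()   # builds the frozen bert4keras encoder
#     print(embedding.embedding_size)     # hidden size read from the loaded model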