Source code for kashgari.embeddings.abc_embedding

# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: abc_embedding.py
# time: 2:43 PM

import json
from typing import Dict, List, Any, Optional, Union

import numpy as np
import tensorflow as tf
import tqdm

import kashgari
from kashgari.generators import CorpusGenerator
from kashgari.logger import logger
from kashgari.processors import ABCProcessor

L = tf.keras.layers


class ABCEmbedding:
    def to_dict(self) -> Dict[str, Any]:
        config: Dict[str, Any] = {
            'segment': self.segment,
            'embedding_size': self.embedding_size,
            'max_position': self.max_position,
            **self.kwargs
        }
        return {
            '__class_name__': self.__class__.__name__,
            '__module__': self.__class__.__module__,
            'config': config,
            'embed_model': json.loads(self.embed_model.to_json())
        }
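
    # Serialization note: ``to_dict`` captures the embedding config plus the
    # Keras model architecture as JSON; ``_override_load_model`` below is its
    # counterpart and rebuilds ``self.embed_model`` from that JSON when a
    # saved embedding is restored.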

    def __init__(self,
                 segment: bool = False,
                 embedding_size: int = 100,
                 max_position: Optional[int] = None,
                 **kwargs: Any):

        self.embed_model: tf.keras.Model = None  # type: ignore

        self.segment: bool = segment  # type: ignore
        self.kwargs = kwargs

        self.embedding_size: int = embedding_size  # type: ignore
        self.max_position: int = max_position  # type: ignore
        self.vocab2idx = self.load_embed_vocab()
        self._text_processor: Optional[ABCProcessor] = None

    def _override_load_model(self, config: Dict) -> None:
        embed_model_json_str = json.dumps(config['embed_model'])
        self.embed_model = tf.keras.models.model_from_json(embed_model_json_str,
                                                           custom_objects=kashgari.custom_objects)

    def setup_text_processor(self, processor: ABCProcessor) -> None:
        self._text_processor = processor
        self.build_embedding_model(vocab_size=processor.vocab_size)
        self._text_processor.segment = self.segment
        if self.vocab2idx:
            self._text_processor.vocab2idx = self.vocab2idx
            self._text_processor.idx2vocab = {v: k for k, v in self.vocab2idx.items()}

    def get_seq_length_from_corpus(self,
                                   generators: List[CorpusGenerator],
                                   *,
                                   use_label: bool = False,
                                   cover_rate: float = 0.95) -> int:
        """
        Calculate proper sequence length according to the corpus

        Args:
            generators:
            use_label:
            cover_rate:

        Returns:

        """
        seq_lens = []
        for gen in generators:
            for sentence, label in tqdm.tqdm(gen, desc="Calculating sequence length"):
                if use_label:
                    seq_lens.append(len(label))
                else:
                    seq_lens.append(len(sentence))
        if cover_rate == 1.0:
            target_index = -1
        else:
            target_index = int(cover_rate * len(seq_lens))
        sequence_length = sorted(seq_lens)[target_index]
        logger.debug(f'Calculated sequence length = {sequence_length}')
        return sequence_length

    def load_embed_vocab(self) -> Optional[Dict[str, int]]:
        """
        Load vocab dict from embedding layer

        Returns:
            vocab dict or None
        """
        raise NotImplementedError

    def build_embedding_model(self,
                              *,
                              vocab_size: Optional[int] = None,
                              force: bool = False,
                              **kwargs: Dict) -> None:
        raise NotImplementedError
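
    # Subclasses are expected to assign a ``tf.keras.Model`` to
    # ``self.embed_model`` here (typically skipping the rebuild unless
    # ``force`` is set). For a plain trainable embedding this is usually an
    # Input layer feeding an ``Embedding(vocab_size, self.embedding_size)``
    # layer; see the hedged DemoEmbedding sketch under ``__main__`` below.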

    def embed(self,
              sentences: List[List[str]],
              *,
              debug: bool = False) -> np.ndarray:
        """
        batch embed sentences

        Args:
            sentences: Sentence list to embed
            debug: show debug info
        Returns:
            vectorized sentence list
        """
        if self._text_processor is None:
            raise ValueError('Call `embedding.setup_text_processor()` before calling the embed function.')

        tensor_x = self._text_processor.transform(sentences,
                                                  segment=self.segment,
                                                  seq_length=self.max_position)
        if debug:
            logger.debug(f'sentence tensor: {tensor_x}')
        embed_results = self.embed_model.predict(tensor_x)
        return embed_results


if __name__ == "__main__":
    pass
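
    # Minimal usage sketch (illustration only, not part of the library):
    # ``DemoEmbedding`` is a hypothetical subclass that wires a plain trainable
    # Embedding layer into ``self.embed_model``. Real subclasses such as
    # BareEmbedding or WordEmbedding live elsewhere in kashgari.embeddings.
    class DemoEmbedding(ABCEmbedding):
        def load_embed_vocab(self) -> Optional[Dict[str, int]]:
            # No pre-built vocabulary shipped with this embedding.
            return None

        def build_embedding_model(self,
                                  *,
                                  vocab_size: Optional[int] = None,
                                  force: bool = False,
                                  **kwargs: Dict) -> None:
            if self.embed_model is None or force:
                input_tensor = L.Input(shape=(None,), dtype=tf.int64)
                output_tensor = L.Embedding(vocab_size, self.embedding_size)(input_tensor)
                self.embed_model = tf.keras.Model(input_tensor, output_tensor)

    embedding = DemoEmbedding(embedding_size=8)
    embedding.build_embedding_model(vocab_size=20)
    # ``embed()`` requires a fitted text processor (see ``setup_text_processor``),
    # so this demo feeds already-numericalized indices to the Keras model directly.
    demo_x = np.array([[1, 2, 3, 0], [4, 5, 0, 0]])
    print(embedding.embed_model.predict(demo_x).shape)  # -> (2, 4, 8)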