Source code for kashgari.processors.sequence_processor

# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: sequence_processor.py
# time: 12:27 PM

import collections
import operator
from typing import Dict, List, Any, Optional, Tuple, Union

import numpy as np
import tqdm
from tensorflow.keras.preprocessing.sequence import pad_sequences

from kashgari.generators import CorpusGenerator
from kashgari.logger import logger
from kashgari.processors.abc_processor import ABCProcessor
from kashgari.types import TextSamplesVar


class SequenceProcessor(ABCProcessor):
    """
    Generic processor for sequence samples.
    """
    def to_dict(self) -> Dict[str, Any]:
        data = super(SequenceProcessor, self).to_dict()
        data['config'].update({
            'build_in_vocab': self.build_in_vocab,
            'min_count': self.min_count
        })
        return data
    def __init__(self,
                 build_in_vocab: str = 'text',
                 min_count: int = 3,
                 build_vocab_from_labels: bool = False,
                 **kwargs: Any) -> None:
        """
        Args:
            build_in_vocab: initial vocab dict type, one of `text` and `labeling`.
            min_count: tokens seen fewer than this many times are excluded from the vocab.
            build_vocab_from_labels: build the vocab from the label sequences instead of the text.
            **kwargs: arguments forwarded to ABCProcessor.
        """
        super(SequenceProcessor, self).__init__(**kwargs)
        self.build_in_vocab = build_in_vocab
        self.min_count = min_count
        self.build_vocab_from_labels = build_vocab_from_labels

        if build_in_vocab == 'text':
            self._initial_vocab_dic = {
                self.token_pad: 0,
                self.token_unk: 1,
                self.token_bos: 2,
                self.token_eos: 3
            }
        elif build_in_vocab == 'labeling':
            self._initial_vocab_dic = {
                self.token_pad: 0
            }
        else:
            self._initial_vocab_dic = {}

        self._showed_seq_len_warning = False
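    # A construction sketch (hypothetical values): a 'text' vocab reserves
    # PAD/UNK/BOS/EOS entries, while a 'labeling' vocab only reserves PAD.
    #
    #   text_processor = SequenceProcessor(build_in_vocab='text')
    #   label_processor = SequenceProcessor(build_in_vocab='labeling',
    #                                       build_vocab_from_labels=True,
    #                                       min_count=1)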
    def build_vocab_generator(self,
                              generators: List[CorpusGenerator]) -> None:
        if not self.vocab2idx:
            vocab2idx = self._initial_vocab_dic
            token2count: Dict[str, int] = {}

            for gen in generators:
                for sentence, label in tqdm.tqdm(gen, desc="Preparing text vocab dict"):
                    if self.build_vocab_from_labels:
                        target = label
                    else:
                        target = sentence
                    for token in target:
                        count = token2count.get(token, 0)
                        token2count[token] = count + 1

            # Sort tokens by frequency so the most common ones get the lowest indices.
            sorted_token2count = sorted(token2count.items(),
                                        key=operator.itemgetter(1),
                                        reverse=True)
            token2count = collections.OrderedDict(sorted_token2count)

            for token, token_count in token2count.items():
                if token not in vocab2idx and token_count >= self.min_count:
                    vocab2idx[token] = len(vocab2idx)
            self.vocab2idx = vocab2idx
            self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])

            top_k_vocab = [k for (k, v) in list(self.vocab2idx.items())[:10]]
            logger.debug(f"--- Build vocab dict finished, Total: {len(self.vocab2idx)} ---")
            logger.debug(f"Top-10: {top_k_vocab}")
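    # A vocab-building sketch (assumes CorpusGenerator wraps parallel lists of
    # tokenized sentences and labels, per kashgari.generators): tokens seen
    # fewer than `min_count` times stay out of the vocab and will map to UNK
    # at transform time.
    #
    #   processor.build_vocab_generator([CorpusGenerator(x_data, y_data)])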
    def transform(self,
                  samples: TextSamplesVar,
                  *,
                  seq_length: Optional[int] = None,
                  max_position: Optional[int] = None,
                  segment: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        seq_length_from = ""
        # An ugly patch for the tf-serving use case.
        if seq_length is None and self._sequence_length_from_saved_model is not None:
            seq_length = self._sequence_length_from_saved_model

        if seq_length is None:
            seq_length_from = "max length of the samples"
            # +2 leaves room for the BOS and EOS tokens added below.
            seq_length = max([len(i) for i in samples]) + 2
        if max_position is not None and max_position < seq_length:
            seq_length_from = "max embedding seq length"
            seq_length = max_position
        if seq_length_from and not self._showed_seq_len_warning:
            logger.warning(
                f'Sequence length is None, will use the {seq_length_from}, which is {seq_length}')
            self._showed_seq_len_warning = True

        numerized_samples = []
        for seq in samples:
            if self.token_bos in self.vocab2idx:
                seq = [self.token_bos] + seq + [self.token_eos]
            else:
                seq = [self.token_pad] + seq + [self.token_pad]
            if self.token_unk in self.vocab2idx:
                unk_index = self.vocab2idx[self.token_unk]
                numerized_samples.append([self.vocab2idx.get(token, unk_index) for token in seq])
            else:
                numerized_samples.append([self.vocab2idx[token] for token in seq])

        sample_index = pad_sequences(numerized_samples, seq_length, padding='post', truncating='post')
        token_ids = np.array(sample_index)
        if segment:
            segment_ids = np.zeros(token_ids.shape, dtype=np.int32)
            return token_ids, segment_ids
        else:
            return token_ids
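    # A transform sketch (hypothetical two-sample batch): with seq_length=None
    # the padded width is the longest sample plus 2, since BOS and EOS are
    # always added around each sequence.
    #
    #   ids = processor.transform([['hello', 'world'], ['hi']])    # shape (2, 4)
    #   ids, seg = processor.transform(x_data, segment=True)       # with segment ids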
    def inverse_transform(self,  # type: ignore[override]
                          labels: Union[List[List[int]], np.ndarray],
                          *,
                          lengths: Optional[List[int]] = None,
                          threshold: float = 0.5,
                          **kwargs: Any) -> List[List[str]]:
        result = []
        for index, seq in enumerate(labels):
            labels_ = []
            for idx in seq:
                labels_.append(self.idx2vocab[idx])
            # Drop the BOS/EOS (or padding) tokens that `transform` added.
            if lengths is not None:
                labels_ = labels_[1:lengths[index] + 1]
            else:
                labels_ = labels_[1:-1]
            result.append(labels_)
        return result
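# A minimal end-to-end sketch, assuming CorpusGenerator accepts parallel
# x/y lists of token sequences (see kashgari.generators); the sentences and
# labels below are hypothetical.
def _demo_roundtrip() -> None:
    x_data = [['all', 'work', 'and', 'no', 'play'],
              ['makes', 'jack', 'a', 'dull', 'boy']]
    y_data = [['O', 'O', 'O', 'O', 'O'],
              ['O', 'B-PER', 'O', 'O', 'O']]

    processor = SequenceProcessor(min_count=1)
    processor.build_vocab_generator([CorpusGenerator(x_data, y_data)])

    token_ids = processor.transform(x_data)
    # Padded width is the max sample length (5) + 2 for BOS/EOS -> (2, 7).
    print(token_ids.shape)
    # inverse_transform strips BOS/EOS and maps indices back to tokens.
    print(processor.inverse_transform(token_ids))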
if __name__ == "__main__":
    pass