# encoding: utf-8
# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz
# file: label_processor.py
# time: 2:53 下午
import collections
import operator
from typing import List, Union, Dict, Optional, Any, Tuple
import numpy as np
import tqdm
from kashgari.generators import CorpusGenerator
from kashgari.processors.abc_processor import ABCProcessor
from kashgari.types import TextSamplesVar
[docs]class ClassificationProcessor(ABCProcessor):
[docs] def to_dict(self) -> Dict[str, Any]:
data = super(ClassificationProcessor, self).to_dict()
data['config']['multi_label'] = self.multi_label
return data
[docs] def __init__(self,
multi_label: bool = False,
**kwargs: Any) -> None:
from kashgari.utils import MultiLabelBinarizer
super(ClassificationProcessor, self).__init__(**kwargs)
self.multi_label = multi_label
self.multi_label_binarizer = MultiLabelBinarizer(self.vocab2idx)
[docs] def build_vocab_generator(self,
generators: List[CorpusGenerator]) -> None:
from kashgari.utils import MultiLabelBinarizer
if self.vocab2idx:
return
vocab2idx: Dict[str, int] = {}
token2count: Dict[str, int] = {}
for generator in generators:
if self.multi_label:
for _, label in tqdm.tqdm(generator, desc="Preparing classification label vocab dict"):
for token in label:
count = token2count.get(token, 0)
token2count[token] = count + 1
else:
for _, label in tqdm.tqdm(generator, desc="Preparing classification label vocab dict"):
count = token2count.get(label, 0)
token2count[label] = count + 1
sorted_token2count = sorted(token2count.items(),
key=operator.itemgetter(1),
reverse=True)
token2count = collections.OrderedDict(sorted_token2count)
for token, token_count in token2count.items():
if token not in vocab2idx:
vocab2idx[token] = len(vocab2idx)
self.vocab2idx = vocab2idx
self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])
self.multi_label_binarizer = MultiLabelBinarizer(self.vocab2idx)
[docs] def get_tensor_shape(self, batch_size: int, seq_length: int) -> Tuple:
if self.multi_label:
return batch_size, len(self.vocab2idx)
else:
return (batch_size,)
if __name__ == "__main__":
pass