# sentencer_utils.py (2.4 KB)
#! pip install lxml

from transformers import AutoTokenizer, TFAutoModelForTokenClassification

import numpy as np
import tensorflow as tf

def dict_to_tensors(d):
    """Return a new dict with every value of *d* converted to a TensorFlow tensor.

    Keys are kept as they are; each value is passed through
    ``tf.convert_to_tensor``.
    """
    tensors = {}
    for key, value in d.items():
        tensors[key] = tf.convert_to_tensor(value)
    return tensors

class Sentencer:
    """Split texts into sentences with a token-classification model.

    The model is loaded with two labels; a token whose predicted label is
    truthy (non-zero) is treated as the last token of a sentence.
    """

    # Token strings that carry no text and are skipped when collecting
    # labelled tokens.
    _SPECIAL_TOKENS = frozenset(('<s>', '</s>', '<pad>'))

    def __init__(self, pretrained_path):
        """Load the tokenizer and the segmentation model.

        Args:
            pretrained_path: Location passed to
                ``TFAutoModelForTokenClassification.from_pretrained``.
                The tokenizer is always ``allegro/herbert-base-cased``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained('allegro/herbert-base-cased')
        self.sentencer = TFAutoModelForTokenClassification.from_pretrained(pretrained_path, num_labels=2)

    def do_segmentation(self, texts):
        """Segment one text or a batch of texts into sentences.

        Each text is first normalised with the tokenizer's normaliser and
        stripped; the returned sentences are slices of that normalised text,
        with the gaps between tokens rebuilt as spaces from the token offsets.

        Args:
            texts: A single string or an iterable of strings.

        Returns:
            A list with one entry per input text, each entry being the list
            of non-empty sentences found in that text. An empty batch yields
            ``[]`` without invoking the tokenizer or the model.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            return []

        texts = [self.tokenizer._tokenizer.normalizer.normalize_str(t).strip() for t in texts]
        # Texts longer than max_length are expected to come back as several
        # rows; overflow_to_sample_mapping says which input text each row
        # belongs to.
        tokenised = self.tokenizer(texts, return_offsets_mapping=True, padding=True, truncation=True, return_overflowing_tokens=True, max_length=512)
        offsets_mapping = tokenised.pop('offset_mapping')
        overflow_mapping = tokenised.pop('overflow_to_sample_mapping')
        logits = self.sentencer(dict_to_tensors(tokenised), training=False).logits
        predicted = np.argmax(logits, axis=-1)

        # One bucket per input text, addressed by sample number so that rows
        # are attributed correctly regardless of how they are grouped.
        labeled_paragraphs = [[] for _ in texts]
        rows = zip(tokenised['input_ids'], predicted, offsets_mapping, overflow_mapping)
        for ids, labels, token_offsets, sample_no in rows:
            text = texts[sample_no]
            bucket = labeled_paragraphs[sample_no]
            tokens = self.tokenizer.convert_ids_to_tokens(ids)
            for token, is_eos, (start, end) in zip(tokens, labels, token_offsets):
                if token in self._SPECIAL_TOKENS:
                    continue
                # Use the surface text rather than the vocabulary token.
                bucket.append((text[start:end], is_eos, start))

        return [self._assemble_sentences(bucket) for bucket in labeled_paragraphs]

    @staticmethod
    def _assemble_sentences(labeled_tokens):
        """Join ``(token, is_eos, start_offset)`` triples into stripped sentences."""
        sentences = []
        pieces = []
        # Character offset reached so far within the current sentence; gaps
        # between it and the next token's start are filled with spaces.
        reached = 0
        for token, is_eos, start in labeled_tokens:
            if reached < start:
                # Padding before the first piece would only be stripped
                # again, so it is skipped.
                if pieces:
                    pieces.append(' ' * (start - reached))
                reached = start
            pieces.append(token)
            reached += len(token)
            if is_eos:
                sentence = ''.join(pieces).strip()
                if sentence:
                    sentences.append(sentence)
                pieces = []
                reached = 0
        # Trailing tokens that were never closed by an end-of-sentence label.
        sentence = ''.join(pieces).strip()
        if sentence:
            sentences.append(sentence)
        return sentences