sentencer_utils.py
2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#! pip install lxml
from transformers import AutoTokenizer, TFAutoModelForTokenClassification
import numpy as np
import tensorflow as tf
def dict_to_tensors(d):
    """Return a new dict with the keys of *d* and every value converted to a tensor.

    Each value is passed through ``tf.convert_to_tensor``; *d* itself is not
    modified. Only ``d.items()`` is required of *d*.
    """
    tensors = {}
    for name, value in d.items():
        tensors[name] = tf.convert_to_tensor(value)
    return tensors
class Sentencer(object):
    """Sentence segmenter backed by a two-label token-classification model.

    Every token of the input is labelled by the model; a token whose predicted
    label is non-zero closes the current sentence.
    """

    # Tokenizer checkpoint loaded in ``__init__``.
    TOKENIZER_NAME = 'allegro/herbert-base-cased'
    # ``max_length`` handed to the tokenizer. Truncation with overflowing
    # tokens is enabled, so one text may come back as more than one row.
    MAX_LENGTH = 512
    # Token strings that are skipped when sentences are rebuilt.
    SPECIAL_TOKENS = frozenset(('<s>', '</s>', '<pad>'))

    def __init__(self, pretrained_path):
        """Load the tokenizer and the classification model.

        Args:
            pretrained_path: Passed to
                ``TFAutoModelForTokenClassification.from_pretrained`` (loaded
                with ``num_labels=2``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.TOKENIZER_NAME)
        self.sentencer = TFAutoModelForTokenClassification.from_pretrained(pretrained_path, num_labels=2)

    def do_segmentation(self, texts):
        """Split each text into sentences.

        Args:
            texts: A single string or an iterable of strings. A single string
                is treated as a one-element batch.

        Returns:
            A list with one entry per input text; each entry is the list of
            sentence strings found in that text (possibly empty). An empty
            batch yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]
        else:
            texts = list(texts)
        if not texts:
            # Nothing to segment: skip the tokenizer and the model entirely.
            return []

        # Normalise up front and keep the normalised strings: the token offsets
        # used below are applied to exactly these strings.
        normalize = self.tokenizer.backend_tokenizer.normalizer.normalize_str
        texts = [normalize(text).strip() for text in texts]

        tokenised = self.tokenizer(
            texts,
            return_offsets_mapping=True,
            padding=True,
            truncation=True,
            return_overflowing_tokens=True,
            max_length=self.MAX_LENGTH,
        )
        # These two entries are bookkeeping only and must not be fed to the model.
        offsets_mapping = tokenised.pop('offset_mapping')
        overflow_mapping = tokenised.pop('overflow_to_sample_mapping')

        logits = self.sentencer(dict_to_tensors(tokenised), training=False).logits
        predicted = np.argmax(logits, axis=-1)

        # One bucket per input text; rows are routed by the index of the text
        # they came from, so several rows of one text end up in one bucket.
        labeled_paragraphs = [[] for _ in texts]
        rows = zip(tokenised['input_ids'], predicted, offsets_mapping, overflow_mapping)
        for ids, labels, token_offsets, sample_no in rows:
            text = texts[sample_no]
            tokens = self.tokenizer.convert_ids_to_tokens(ids)
            for token, is_eos, (start, end) in zip(tokens, labels, token_offsets):
                if token in self.SPECIAL_TOKENS:
                    continue
                # Use the surface form from the text, not the vocabulary token.
                labeled_paragraphs[sample_no].append((text[start:end], is_eos, start))

        return [self._build_sentences(labeled) for labeled in labeled_paragraphs]

    @staticmethod
    def _build_sentences(labeled_tokens):
        """Join one text's ``(token, is_eos, start)`` triples into sentences.

        Tokens are placed at their ``start`` offset; any gap before a token is
        filled with spaces. A sentence is emitted (stripped, and only if
        non-empty) after every token whose ``is_eos`` is truthy, and once more
        for whatever remains at the end.
        """
        sentences = []
        pieces = []
        length = 0  # total length of the strings currently in `pieces`
        for token, is_eos, start in labeled_tokens:
            if length < start:
                pieces.append(' ' * (start - length))
                length = start
            pieces.append(token)
            length += len(token)
            if is_eos:
                sentence = ''.join(pieces).strip()
                if sentence:
                    sentences.append(sentence)
                pieces = []
                length = 0
        sentence = ''.join(pieces).strip()
        if sentence:
            sentences.append(sentence)
        return sentences