import json
import time

import morfeusz2

import tensorflow as tf
from transformers import AutoTokenizer

from datasets.features import ClassLabel, Sequence

from .data_utils import dict_to_tensors
from .dataset_utils import masked_word_ids, morf_tokenize
from .hybrid_tree_utils import make_lemma, correct_lemma, get_heads, make_tree, tree2dict
from .MultiTarget import TFBertForMultiTargetTokenClassification

from .constants import (
    SEG_BEGIN,
    SEGS,
    LEMMAS,
    LEMMA_CASES,
    LEMMA_RULES,
    TAGS,
    HEADS,
    ADJACENCY_MATRIX,
    DEPRELS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)

def maybe_int(s):
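    """Convert a string to ``int`` if it looks like an (optionally negative) integer, otherwise return it unchanged."""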
    if s and (s.isdigit() or (s[0] == '-' and s[1:].isdigit())):
        return int(s)
    return s

def keys_hook(d):
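    """``object_hook`` for ``json.load``: restore dict keys that were integers before JSON serialisation (JSON only allows string keys)."""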
    return { maybe_int(k) : v for k, v in d.items() }

def category_names(segmentation=True, lemmatisation=True, tagging=True, dependency=True, spines=True, nonterminal_features=[]):
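    """Return the list of prediction target categories enabled by the given flags, plus any nonterminal feature categories."""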
    categories = []
        
    if segmentation:
        categories.append(SEGS)
    
    if lemmatisation:
        categories.append(LEMMA_CASES)
        categories.append(LEMMA_RULES)
    
    if tagging:
        categories.append(TAGS)
    
    if dependency:
        categories.append(ADJACENCY_MATRIX)
        categories.append(DEPRELS)
        
    if spines:
        categories.append(SPINES)
        categories.append(ANCHORS)
        categories.append(ANCHOR_HS)
    
    if nonterminal_features:
        categories += nonterminal_features
    
    return categories

def get_labels(features, categories):
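    """Extract the list of label names for every categorical target from a ``datasets`` features object."""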
    labels = {}
    for cat in categories:
        feature = features[cat].feature
        if isinstance(feature, ClassLabel):
            labels[cat] = feature.names
    return labels

class ConstituencyParser(object):
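    """BERT-based multi-target parser producing segmentation, lemmata, tags, dependencies and constituency spines.

    Typical usage (illustrative; the path is a placeholder for a directory
    produced by ``save``)::

        parser = ConstituencyParser.load('path/to/saved_parser')
        trees = parser.parse(['Miałem kota.'])
    """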
    
    def __init__(
            self,
            bert_path,
            model,
            labels,
            segmentation=True,
            lemmatisation=True,
            tagging=True,
            dependency=True,
            spines=True,
            nonterminal_features=[],
            bert_tokenizer=None,
            ):
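        """Wrap an already instantiated model and label dictionary; use ``create`` to build a new parser or ``load`` to restore a saved one."""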
        self.bert_path = bert_path
        self.model = model
        self.segmentation = segmentation
        self.lemmatisation = lemmatisation
        self.tagging = tagging
        self.dependency = dependency
        self.spines = spines
        self.nonterminal_features = nonterminal_features
        self.categories = category_names(segmentation, lemmatisation, tagging, dependency, spines, nonterminal_features)
        self.labels = labels
        if bert_tokenizer is not None:
            self.bert_tokenizer = bert_tokenizer
        else:
            self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.morfeusz = morfeusz2.Morfeusz(generate=False, expand_tags=True)
    
    def save(self, path):
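        """Write the model (via ``save_pretrained``) and a JSON config with the enabled targets, label names and BERT path under ``path``."""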
        self.model.save_pretrained(f'{path}/model')
        config = {
            'segmentation' : self.segmentation,
            'lemmatisation' : self.lemmatisation,
            'tagging' : self.tagging,
            'dependency' : self.dependency,
            'spines' : self.spines,
            'nonterminal_features' : self.nonterminal_features,
            'labels' : self.labels,
            'bert_path' : self.bert_path,
        }
        with open(f'{path}/config.json', 'w') as f:
            json.dump(config, f, ensure_ascii=False)
    
    @staticmethod
    def create(
            bert_path,
            features,
            segmentation=True,
            lemmatisation=True,
            tagging=True,
            dependency=True,
            spines=True,
            nonterminal_features=[],
            bert_tokenizer=None,
            ):
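        """Create a parser, loading the underlying ``TFBertForMultiTargetTokenClassification`` model from the (PyTorch) BERT checkpoint at ``bert_path`` and reading label names from ``features``."""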
        categories = category_names(segmentation, lemmatisation, tagging, dependency, spines, nonterminal_features)
        labels = get_labels(features, categories)
        model = TFBertForMultiTargetTokenClassification.from_pretrained(
            bert_path,
            from_pt=True,
            categories=categories,
            labels=labels,
        )
        return ConstituencyParser(
            bert_path,
            model,
            labels,
            segmentation=segmentation,
            lemmatisation=lemmatisation,
            tagging=tagging,
            dependency=dependency,
            spines=spines,
            nonterminal_features=nonterminal_features,
            bert_tokenizer=bert_tokenizer
        )
    
    @staticmethod
    def load(path):
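        """Restore a parser previously written with ``save`` from ``path``."""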
        with open(f'{path}/config.json') as f:
            config = json.load(f, object_hook=keys_hook)
        labels = config['labels']
        segmentation = config['segmentation']
        lemmatisation = config['lemmatisation']
        tagging = config['tagging']
        dependency = config['dependency']
        spines = config['spines']
        nonterminal_features = config['nonterminal_features']
        bert_path = config['bert_path']
        categories = category_names(segmentation, lemmatisation, tagging, dependency, spines, nonterminal_features)
        model = TFBertForMultiTargetTokenClassification.from_pretrained(
            f'{path}/model',
            categories=categories,
            labels=labels,
        )
        return ConstituencyParser(
            bert_path,
            model,
            labels,
            segmentation=segmentation,
            lemmatisation=lemmatisation,
            tagging=tagging,
            dependency=dependency,
            spines=spines,
            nonterminal_features=nonterminal_features,
        )
    
    def retokenize_mask(self, tokens, seg, min_tokens):
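        """Merge tokens into segments according to the predicted segmentation labels.

        ``min_tokens`` is the original whitespace tokenisation; its token
        boundaries are always kept. Returns the merged tokens and a mask with
        1 at positions that start a merged token and None elsewhere.
        """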
        tok2 = ''
        index = 0
        assert len(tokens) == len(seg)
        tokens2, mask = [], []
        for token, seglabel in zip(tokens, seg):
            if tok2 == min_tokens[index]:
                tok2 = ''
                index += 1
            # start a new merged token at a predicted segment beginning or at an
            # original whitespace-token boundary (tok2 was just reset above)
            if seglabel == SEG_BEGIN or not tok2:
                tokens2.append(token)
                mask.append(1)
            else:
                tokens2[-1] += token
                mask.append(None)
            tok2 += token
        return tokens2, mask
    
    def align_with_mask(self, labels, mask):
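        """Keep only the label positions whose mask entry is not None, recursing into nested non-string sequences (e.g. adjacency matrix rows)."""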
        return [
            lbl if not hasattr(lbl, '__iter__') or isinstance(lbl, str) else self.align_with_mask(lbl, mask)
            for lbl, m in zip(labels, mask) if m is not None
        ]
    
    def process_labels(self, labels, tokens, correct_lemmata=False):
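        """Turn raw per-token labels into final annotations: build lemmata from the predicted case and rule labels, optionally correct them with Morfeusz, and derive dependency heads from the adjacency matrix."""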
        
        if self.lemmatisation:
            rules = labels.pop(LEMMA_RULES)
            cases = labels.pop(LEMMA_CASES)
            labels[LEMMAS] = [make_lemma(*x) for x in zip(tokens, cases, rules)]
        else:
            labels[LEMMAS] = ['_' for _ in tokens]
        
        if correct_lemmata:
            tags = labels[TAGS]
            lemmas = labels[LEMMAS]
            labels[LEMMAS] = [correct_lemma(*x, self.morfeusz) for x in zip(tokens, lemmas, tags)]
        
        if self.dependency:
            matrix = labels.pop(ADJACENCY_MATRIX)
            labels[HEADS] = get_heads(matrix)
        
        return labels
    
    def parse(
        self,
        sentences,
        correct_lemmata=False,
        return_jsons=False,
        return_labels=False,
        return_logits=False,
        root_label=None,
        force_root_label=False,
        force_long=False,
        is_tokenized=False,
        return_times=False
    ):
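        """Parse one sentence or a list of sentences.

        By default returns constituency trees; alternatively returns JSON-ready
        dicts (``return_jsons``), raw label dictionaries (``return_labels``) or
        logits (``return_logits``). With ``return_times`` the total and
        prediction-only processing times (``time.process_time_ns``) are
        returned as well.
        """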
        
        t1 = time.process_time_ns()
        
        if sum((return_jsons, return_labels, return_logits)) > 1:
            raise RuntimeError('At most one can be set to True: return_jsons, return_labels, return_logits.')
        if not is_tokenized and not self.segmentation:
            raise RuntimeError('This model can’t tokenize; please use is_tokenized=True and pass a space-separated tokenized sentence, e.g. ‘Miał em kota .’')
        if correct_lemmata and not (self.lemmatisation and self.tagging):
            print('This model can’t lemmatise and/or tag, setting correct_lemmata to False.')
            correct_lemmata = False
        return_trees = not (return_jsons or return_labels)
        if (return_trees or return_jsons) and not (self.dependency and self.spines):
            raise RuntimeError('This model can’t parse and won’t return trees/jsons, use return_labels=True.')
        
        if isinstance(sentences, str):
            sentences = [sentences]
        tokens = [s.split() for s in sentences]
        if self.segmentation and not is_tokenized:
            tokens = [morf_tokenize(' '.join(toks), self.morfeusz) for toks in tokens]
        tokenized = self.bert_tokenizer(
            tokens,
            is_split_into_words=True,
            return_offsets_mapping=True,
            padding=True,
        )
        
        M = len(tokenized['input_ids'][0])
        if M > self.bert_tokenizer.model_max_length and not force_long:
            raise RuntimeError(f'The BERT tokenizer produced a sequence of {M} tokens, which exceeds the model’s limit ({self.bert_tokenizer.model_max_length}). Parse shorter sentences or call parse() with force_long=True at your own risk.')
        x = dict_to_tensors(dict(tokenized))
        
        t2 = time.process_time_ns()
        predicted = self.model.predict(x)
        t3 = time.process_time_ns()
        
        labels = dict()
        for cat, pred in predicted.items():
            if return_logits and cat != SEGS:
                lbls = pred
            else:
                if cat == ADJACENCY_MATRIX:
                    lbls = tf.nn.softmax(pred, axis=-1).numpy()
                else:
                    label_ids = tf.argmax(pred, axis=-1).numpy()
                    lbls = [[self.labels[cat][i] for i in l_ids] for l_ids in label_ids]
            labels[cat] = lbls
        
        trees = []
        
        for i, (tkns, sentence) in enumerate(zip(tokens, sentences)):
            
            mask = masked_word_ids(tokenized.word_ids(i))
            lbls = {cat : self.align_with_mask(lbls[i], mask) for cat, lbls in labels.items()}
            
            if self.segmentation:
                # remove the seg labels
                seg = lbls.pop(SEGS)
                if not is_tokenized:
                    tkns, mask = self.retokenize_mask(tkns, seg, sentence.split())
                    lbls = {cat : self.align_with_mask(lbl, mask) for cat, lbl in lbls.items()}
            
            if return_logits:
                trees.append((tkns, lbls))
                continue
                
            lbls = self.process_labels(lbls, tkns, correct_lemmata=correct_lemmata)
            
            if return_trees:
                trees.append(make_tree(tkns, lbls, self.nonterminal_features,
                                       root_label=root_label,
                                       force_root_label=force_root_label))
            elif return_jsons:
                trees.append(tree2dict(make_tree(tkns, lbls, self.nonterminal_features,
                                                 root_label=root_label,
                                                 force_root_label=force_root_label)))
            else:
                trees.append((tkns, lbls))
        
        t4 = time.process_time_ns()
        
        if return_times:
            return trees, {'total' : t4 - t1, 'prediction' : t3 - t2}
        else:
            return trees