# constituency_parser.py
import json

import morfeusz2

import tensorflow as tf
from transformers import AutoTokenizer

from datasets.features import ClassLabel

from .data_utils import dict_to_tensors
from .dataset_utils import masked_word_ids
from .MultiTarget import TFBertForMultiTargetTokenClassification

from .constants import (
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)

def maybe_int(s):
    # Convert a string that looks like a (possibly negative) integer back to int.
    if s and (s.isdigit() or (s[0] == '-' and s[1:].isdigit())):
        return int(s)
    return s

def keys_hook(d):
    # JSON object keys are always strings; restore integer keys when loading the config.
    return { maybe_int(k) : v for k, v in d.items() }

def get_labels(features, categories):
    # Collect the label vocabulary (ClassLabel names) for each categorical feature.
    labels = {}
    for cat in categories:
        feature = features[cat].feature
        if isinstance(feature, ClassLabel):
            labels[cat] = feature.names
    return labels

class ConstituencyParser:
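    """Wraps a multi-target BERT token classifier that predicts spine and anchor
    labels for constituency parsing.

    Illustrative usage (the saved-parser directory below is hypothetical):

        parser = ConstituencyParser.load('path/to/saved_parser')
        predictions = parser.parse(['Ala ma kota .'])
    """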
    
    def __init__(
            self,
            bert_path,
            model,
            labels,
            bert_tokenizer=None,
            ):
        self.bert_path = bert_path
        self.model = model
        self.categories = [SPINES, ANCHORS, ANCHOR_HS]
        self.labels = labels
        if bert_tokenizer is not None:
            self.bert_tokenizer = bert_tokenizer
        else:
            self.bert_tokenizer = AutoTokenizer.from_pretrained(bert_path)
        self.morfeusz = morfeusz2.Morfeusz(generate=False, expand_tags=True)
    
    def save(self, path):
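        """Save the fine-tuned model and the config (labels, BERT path) needed by load()."""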
        self.model.save_pretrained(f'{path}/model')
        config = {
            'labels' : self.labels,
            'bert_path' : self.bert_path,
        }
        with open(f'{path}/config.json', 'w') as f:
            json.dump(config, f, ensure_ascii=False)
    
    @staticmethod
    def create(
            bert_path,
            features,
            bert_tokenizer=None,
            ):
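        """Build a new parser: read the label sets from dataset features and initialise
        the classifier from a pretrained (PyTorch) BERT checkpoint."""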
        categories = [SPINES, ANCHORS, ANCHOR_HS]
        labels = get_labels(features, categories)
        model = TFBertForMultiTargetTokenClassification.from_pretrained(
            bert_path,
            from_pt=True,
            categories=categories,
            labels=labels,
        )
        return ConstituencyParser(
            bert_path,
            model,
            labels,
            bert_tokenizer=bert_tokenizer
        )
    
    @staticmethod
    def load(path):
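        """Restore a parser previously saved with save(): read config.json and reload the model."""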
        with open(f'{path}/config.json') as f:
            config = json.load(f, object_hook=keys_hook)
        labels = config['labels']
        bert_path = config['bert_path']
        categories = [SPINES, ANCHORS, ANCHOR_HS]
        model = TFBertForMultiTargetTokenClassification.from_pretrained(
            f'{path}/model',
            categories=categories,
            labels=labels,
        )
        return ConstituencyParser(
            bert_path,
            model,
            labels,
        )
    
    def align_with_mask(self, labels, mask):
        # Keep labels only at positions whose mask entry is not None;
        # recurse into nested label sequences.
        return [
            lbl if not hasattr(lbl, '__iter__') or isinstance(lbl, str) else self.align_with_mask(lbl, mask)
            for lbl, m in zip(labels, mask) if m is not None
        ]
    
    def parse(self, sentences, force_long=False):
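        """Tag one or more whitespace-tokenised sentences.

        Returns, for each sentence, a pair (tokens, labels), where labels maps each
        category (SPINES, ANCHORS, ANCHOR_HS) to the predictions aligned with the tokens.
        """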
        
        if isinstance(sentences, str):
            sentences = [sentences]
        tokens = [s.split() for s in sentences]
        tokenized = self.bert_tokenizer(
            tokens,
            is_split_into_words=True,
            return_offsets_mapping=True,
            padding=True,
        )
        
        # All sequences share the padded length, so checking the first one suffices.
        M = len(tokenized['input_ids'][0])
        if M > self.bert_tokenizer.model_max_length and not force_long:
            raise RuntimeError(
                f'Bert tokenizer produced a sequence of {M} tokens, which exceeds '
                f'the model\'s limit ({self.bert_tokenizer.model_max_length}). '
                f'Parse shorter sentences or call parse with force_long=True at your own risk.'
            )
        x = dict_to_tensors(dict(tokenized))
        
        # Run the model and, for each category, map argmax label ids back to label strings.
        predicted = self.model.predict(x)
        labels = {}
        for cat, pred in predicted.items():
            label_ids = tf.argmax(pred, axis=-1).numpy()
            labels[cat] = [[self.labels[cat][i] for i in l_ids] for l_ids in label_ids]
        
        trees = []
        
        for i, tkns in enumerate(tokens):
            
            # Align the predicted labels with the sentence's words using its word-id mask.
            mask = masked_word_ids(tokenized.word_ids(i))
            lbls = {cat : self.align_with_mask(cat_lbls[i], mask) for cat, cat_lbls in labels.items()}
            trees.append((tkns, lbls))
        
        return trees