dataset_utils.py
from collections import Counter, defaultdict
from itertools import chain

import datasets

from .constants import (
    UPPERCASE,
    LOWERCASE,
    TOKENS,
    LEMMAS,
    LEMMA_CASES,
    LEMMA_RULES,
    TAGS,
    HEADS,
    ADJACENCY_MATRIX,
)

def make_lemma_rule(token, lemma, tag):
    # Encode the token -> lemma transformation as a casing flag plus a
    # 'prefix-cut_suffix-cut_suffix' rule string.
    case = UPPERCASE if lemma[0].isupper() else LOWERCASE
    prefix_cut = 0
    token, lemma = token.lower(), lemma.lower()
    # Strip the 'nie' prefix of negated forms and the 'naj' prefix of superlatives,
    # so the suffix rule is computed on the bare stem.
    if (token.startswith('nie') and 'neg' in tag) or (token.startswith('naj') and 'sup' in tag):
        prefix_cut = 3
        token = token[3:]
    # Length of the longest common prefix of the (trimmed) token and the lemma.
    cut = 0
    while token[:(cut + 1)] == lemma[:(cut + 1)] and cut < len(token):
        cut += 1
    suffix = lemma[cut:]
    # Number of trailing characters to drop from the (trimmed) token.
    cut = len(token) - cut
    return case, f'{prefix_cut}_{cut}_{suffix}'
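
# A hypothetical inverse helper (a sketch, not part of the original module): it shows how
# a (case, rule) pair produced by make_lemma_rule maps a token back to its lemma,
# e.g. apply_lemma_rule('Kotów', LOWERCASE, '0_2_') == 'kot'.
def apply_lemma_rule(token, case, rule):
    prefix_cut, cut, suffix = rule.split('_', 2)  # suffix may itself contain '_'
    prefix_cut, cut = int(prefix_cut), int(cut)
    stem = token.lower()[prefix_cut:]             # drop the 'nie'/'naj' prefix, if any
    if cut:
        stem = stem[:-cut]                        # drop the token's trailing characters
    lemma = stem + suffix
    if case == UPPERCASE:
        lemma = lemma[:1].upper() + lemma[1:]     # restore the lemma's initial capital
    return lemma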

def _add_lemma_rules(instance, tag_dict):
    tokens = instance[TOKENS]
    lemmas = instance[LEMMAS]
    # TAGS holds ClassLabel ids; map them back to tag strings before building the rules.
    tags = [tag_dict[v] for v in instance[TAGS]]
    cases, rules = zip(*(make_lemma_rule(*x) for x in zip(tokens, lemmas, tags)))
    return {
        LEMMA_CASES : cases,
        LEMMA_RULES : rules,
    }

def cast_labels(dataset, columns):
    vals = defaultdict(Counter)
    for d in dataset.values():
        for column in columns:
            vals[column].update(chain.from_iterable(s[column] for s in d))
    new_features = dataset['train'].features.copy()
    for column in columns:
        new_features[column] = datasets.Sequence(datasets.ClassLabel(names=sorted(vals[column].keys())))
    return dataset.cast(new_features)
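
# Usage note (illustrative): cast_labels(dataset, [TAGS]) replaces the string labels of
# the TAGS column with integer class ids across all splits; the sorted label vocabulary
# then becomes available as dataset['train'].features[TAGS].feature.names.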

def add_lemma_rules(dataset):
    tag_dict = dataset['train'].features[TAGS].feature.names
    new_dataset = dataset.map(lambda instance: _add_lemma_rules(instance, tag_dict))
    return cast_labels(new_dataset, [LEMMA_CASES, LEMMA_RULES])

EDGE, NO_EDGE = 1, 0

def _add_adjacency_matrix(instance):
    heads = instance[HEADS]
    # ROOT is its own head.
    heads = [x if x != -1 else i for i, x in enumerate(heads)]
    # am[i][head] == EDGE iff `head` indexes the head of token i.
    am = [[NO_EDGE for _ in range(len(heads))] for _ in range(len(heads))]
    for i, head in enumerate(heads):
        am[i][head] = EDGE
    return {ADJACENCY_MATRIX : am}
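
# Worked illustration (toy heads, not taken from any dataset): for heads == [1, -1, 1]
# the root-fixed heads become [1, 1, 1] and the resulting matrix is
#   [[0, 1, 0],
#    [0, 1, 0],
#    [0, 1, 0]]
# i.e. row i carries a single EDGE in the column of token i's head.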

def add_adjacency_matrix(dataset):
    return dataset.map(_add_adjacency_matrix)
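
# Minimal smoke-test sketch (illustrative only): the toy sentence, tags and heads below
# are made up, and the snippet assumes the column constants name the corresponding
# columns of a datasets.DatasetDict. It is not part of the original module's interface.
if __name__ == '__main__':
    toy = datasets.DatasetDict({
        'train': datasets.Dataset.from_dict({
            TOKENS: [['Najlepsze', 'koty', 'śpią']],
            LEMMAS: [['dobry', 'kot', 'spać']],
            TAGS: [['adj:sup', 'subst', 'fin']],
            HEADS: [[1, 2, -1]],   # -1 marks the root, as in _add_adjacency_matrix
        }),
    })
    toy = cast_labels(toy, [TAGS])            # TAGS strings -> ClassLabel ids
    toy = add_lemma_rules(toy)                # adds LEMMA_CASES / LEMMA_RULES
    toy = add_adjacency_matrix(toy)           # adds ADJACENCY_MATRIX
    print(toy['train'][0][LEMMA_RULES])       # lemma-rule class ids for the sentence
    print(toy['train'][0][ADJACENCY_MATRIX])  # 3x3 head-selection matrix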