# dataset_utils.py
from collections import Counter, defaultdict
from itertools import chain

from datasets import ClassLabel, Sequence

from morfeusz2 import Morfeusz

from .constants import (
    FIRST,
    LAST,
    MASK_VALUE,
    UPPERCASE,
    LOWERCASE,
    SEG_BEGIN,
    SEG_INSIDE,
    TOKENS,
    SEGS,
    LEMMAS,
    LEMMA_CASES,
    LEMMA_RULES,
    TAGS,
    HEADS,
    ADJACENCY_MATRIX,
)

def make_lemma_rule(token, lemma, tag):
    # Encode the token -> lemma transformation as a casing flag plus a rule string
    # of the form '<prefix_cut>_<suffix_cut>_<lemma_suffix>'.
    case = UPPERCASE if lemma[0].isupper() else LOWERCASE
    prefix_cut = 0
    token, lemma = token.lower(), lemma.lower()
    # Strip the 'nie'/'naj' prefix of negated and superlative forms before comparing.
    if (token.startswith('nie') and 'neg' in tag) or (token.startswith('naj') and 'sup' in tag):
        prefix_cut = 3
        token = token[3:]
    # Length of the longest common prefix of the (possibly shortened) token and the lemma.
    cut = 0
    while cut < len(token) and token[:(cut + 1)] == lemma[:(cut + 1)]:
        cut += 1
    suffix = lemma[cut:]
    cut = len(token) - cut
    return case, f'{prefix_cut}_{cut}_{suffix}'
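
# Illustrative examples (a sketch; the first element of the returned pair is the
# LOWERCASE/UPPERCASE constant, shown here by name):
#   make_lemma_rule('kotem', 'kot', 'subst:sg:inst:m2')                 -> (LOWERCASE, '0_2_')
#   make_lemma_rule('niepalenie', 'palenie', 'ger:sg:nom:n:imperf:neg') -> (LOWERCASE, '3_0_')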

def _add_lemma_rules(instance, tag_names):
    tokens = instance[TOKENS]
    lemmas = instance[LEMMAS]
    # Map tag indices back to tag strings so that make_lemma_rule can inspect them.
    tags = [tag_names[v] for v in instance[TAGS]]
    cases, rules = zip(*(make_lemma_rule(*x) for x in zip(tokens, lemmas, tags)))
    return {
        LEMMA_CASES : cases,
        LEMMA_RULES : rules,
    }

def cast_labels(dataset, columns):
    # Collect the values occurring in the given columns across all splits and cast
    # those columns to ClassLabel features with a consistent, sorted label inventory.
    vals = defaultdict(Counter)
    for d in dataset.values():
        for column in columns:
            vals[column].update(chain.from_iterable(s[column] for s in d))
    new_features = dataset['train'].features.copy()
    for column in columns:
        new_features[column] = Sequence(ClassLabel(names=sorted(vals[column].keys())))
    return dataset.cast(new_features)

def add_lemma_rules(dataset):
    tag_names = dataset['train'].features[TAGS].feature.names
    new_dataset = dataset.map(lambda instance: _add_lemma_rules(instance, tag_names))
    return cast_labels(new_dataset, [LEMMA_CASES, LEMMA_RULES])
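
# Usage sketch (assumes a DatasetDict whose TAGS column is already a ClassLabel
# feature, as add_lemma_rules expects):
#   dataset = add_lemma_rules(dataset)  # adds LEMMA_CASES and LEMMA_RULES columns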

EDGE, NO_EDGE = 1, 0

def _add_adjacency_matrix(instance):
    heads = instance[HEADS]
    # ROOT is its own head: map the -1 head index to a self-loop.
    heads = [x if x != -1 else i for i, x in enumerate(heads)]
    am = [[NO_EDGE for j in range(len(heads))] for i in range(len(heads))]
    for i, head in enumerate(heads):
        am[i][head] = EDGE
    return {ADJACENCY_MATRIX : am}
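
# Illustrative example: for heads [1, -1, 1] (token 1 is the ROOT), the heads become
# [1, 1, 1] and the matrix is
#   [[0, 1, 0],
#    [0, 1, 0],
#    [0, 1, 0]]
# i.e. row i marks the head of token i, and the ROOT gets a self-loop.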

def add_adjacency_matrix(dataset):
    return dataset.map(_add_adjacency_matrix)

# https://huggingface.co/docs/transformers/v4.23.1/en/tasks/token_classification

def masked_word_ids(word_ids, masking_strategy=FIRST):
    masked = []
    for i, word_idx in enumerate(word_ids):
        # Set the label for the first/last token of each word.
        # Mask the label for:
        #   * special tokens (word id = None)
        #   * other tokens in a word
        if word_idx is None:
            masked.append(None)
        elif masking_strategy == FIRST:
            first_of_word = i == 0 or word_idx != word_ids[i - 1]
            masked.append(word_idx if first_of_word else None)
        elif masking_strategy == LAST:
            last_of_word = i + 1 == len(word_ids) or word_idx != word_ids[i + 1]
            masked.append(word_idx if last_of_word else None)
    return masked
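
# Illustrative example for word_ids [None, 0, 0, 1, 2, 2, None] (special tokens at
# both ends, the first and the last word split into two sub-tokens each):
#   FIRST -> [None, 0, None, 1, 2, None, None]
#   LAST  -> [None, None, 0, 1, None, 2, None]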

def _align_row(values, masked_ids):
    return [MASK_VALUE if idx is None else values[idx] for idx in masked_ids]
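
# E.g. _align_row([10, 20, 30], [None, 0, None, 1, 2, None])
#   -> [MASK_VALUE, 10, MASK_VALUE, 20, 30, MASK_VALUE]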

def _align_example(example, masked_ids):
    # Re-align all label columns (1-D and 2-D) with the new tokenization described
    # by masked_ids; positions whose id is None are filled with MASK_VALUE.
    labels = {}
    masked_row = [MASK_VALUE for _ in masked_ids]

    for column_name, values in example.items():
        if column_name in (TOKENS, LEMMAS):
            continue
        if isinstance(values, str):
            continue
        # 2-D columns (e.g. the adjacency matrix) hold one row of values per token,
        # so both the rows and the values within each row need re-aligning.
        matrix = hasattr(values[0], '__iter__')
        if matrix:
            aligned_labels = [_align_row(values[idx], masked_ids) if idx is not None else masked_row
                              for idx in masked_ids]
        else:
            aligned_labels = _align_row(values, masked_ids)
        labels[column_name] = aligned_labels

    return labels

def morf_tokenize(text, m):
    # Tokenize text with Morfeusz and return the orths of the single-position
    # segments of the analysis graph, in surface order.
    segs = dict()
    max_j = 0
    for i, j, interp in m.analyse(text):
        orth = interp[0]
        if (i, j) in segs:
            assert orth == segs[(i, j)]
        else:
            segs[(i, j)] = orth
        max_j = max(max_j, j)
    return [segs[(i, i + 1)] for i in range(max_j)]
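
# Usage sketch (the exact segmentation depends on the Morfeusz dictionary in use;
# agglutinated forms such as 'miałem' are typically split, e.g. into 'miał' + 'em'):
#   m = Morfeusz(generate=False)
#   morf_tokenize('miałem', m)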

def _morf_tokenize_and_align(example, morfeusz, masking_strategy=FIRST):
    
    if masking_strategy not in (FIRST, LAST):
        raise RuntimeError(f'Unknown masking strategy: {masking_strategy}')
    if masking_strategy == LAST:
        raise RuntimeError(f'Can’t use the {masking_strategy} masking strategy with Morfeusz retokenization')
    
    labels = defaultdict(list)
    
    # Re-tokenize each word with Morfeusz: the first sub-token keeps the index of the
    # original word, the remaining sub-tokens are masked.
    mask = []
    for i, token in enumerate(example[TOKENS]):
        for j, morf_token in enumerate(morf_tokenize(token, morfeusz)):
            labels[TOKENS].append(morf_token)
            labels[SEGS].append(SEG_BEGIN if j == 0 else SEG_INSIDE)
            mask.append(i if j == 0 else None)
    
    labels.update(_align_example(example, mask))
    return labels

def morfeusz_retokenize(dataset, masking_strategy=FIRST):
    morfeusz = Morfeusz(generate=False)
    print(f'retokenizing using {morfeusz.dict_id()}')
    new_dataset = dataset.map(lambda x: _morf_tokenize_and_align(x, morfeusz, masking_strategy=masking_strategy))
    return cast_labels(new_dataset, [SEGS])

def bert_tokenize_and_align(example, tokenizer, masking_strategy=FIRST):
    
    if masking_strategy not in (FIRST, LAST):
        raise RuntimeError(f'Unknown masking strategy: {masking_strategy}')
    
    tokenized_inputs = tokenizer(example[TOKENS], truncation=True, is_split_into_words=True)
    word_ids = tokenized_inputs.word_ids()
    mask = masked_word_ids(word_ids, masking_strategy)
    labels = _align_example(example, mask)
    tokenized_inputs.update(labels)
    return tokenized_inputs
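
# Usage sketch (assumes a fast HuggingFace tokenizer, since word_ids() is needed;
# the model name is only an example):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('allegro/herbert-base-cased')
#   tokenized = dataset.map(
#       lambda x: bert_tokenize_and_align(x, tokenizer, masking_strategy=FIRST))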

# Unused dataset-joining code below, kept as a string literal:
'''
    def _remove_columns(self, dataset):
        to_keep = ['id', TOKENS] + self.categories + self.categories2d
        columns_to_remove = [col for col in dataset.column_names if col not in to_keep]
        return dataset.remove_columns(columns_to_remove)

# TODO for 2d categories!!!
    def _unify_signatures(self, datasets):
        if not datasets:
            return None
        datasets = [self._remove_columns(dataset) for dataset in datasets]
        if len(datasets) == 1:
            return datasets
        print('unifying datasets:')
        for dataset in datasets:
            print(len(dataset), 'examples')
        for category in self.categories: #TODO!!! + self.categories2d:
            values = set()
            for dataset in datasets:
                if category in dataset.features:
                    feature = dataset.features[category].feature
                    if type(feature) == ClassLabel:
                        values.add(tuple(dataset.features[category].feature.names))
                    else:
                        print(type(feature))
                        1/0
                        values.add('VALUE')
            if len(values) > 1:
                print(f'{category}: aligning labels')
                mapping = {value : i for i, value in enumerate(sorted(set(chain.from_iterable(values))))}
                datasets = [dataset.align_labels_with_mapping(mapping, category) for dataset in datasets]
        return datasets
    
    def _join_datasets(self, datasets):
        if not datasets:
            return None
        if self.segmentation:
            datasets = [self._retokenize_dataset(d) for d in datasets]
        datasets = self._unify_signatures(datasets)
        if len(datasets) == 1:
            return datasets[0]
        print('joining datasets:')
        for dataset in datasets:
            print(len(dataset), 'examples')
        joined = concatenate_datasets(datasets)
        print('result:', len(joined), 'examples')
        if self.segmentation:
            joined = self._retokenize_dataset(joined)
        return joined
'''