# dataset_utils.py
from collections import Counter, defaultdict
from itertools import chain

from datasets import ClassLabel, Sequence

from .hybrid_tree_utils import tree_from_dataset_instance

from .constants import (
    FIRST,
    LAST,
    MASK_VALUE,
    EMPTY,
    TOKENS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)

def _do_collect_spines(tree):
    # Recursively split the tree into head paths ("spines"). Returns the
    # spine of this node (from `tree` down to its lexical head) and a list
    # of (anchor_category, h, spine) triples for all spines rooted below.
    if not tree.children:
        return [tree], []
    heads = [child for child in tree.children if child.is_head]
    assert len(heads) == 1
    head = heads[0]
    paths = []
    my_path = [tree]
    non_heads = []
    for child in tree.children:
        child_path, grandchildren_paths = _do_collect_spines(child)
        paths += grandchildren_paths
        if child == head:
            my_path += child_path
        else:
            non_heads.append(child_path)
    # h == which <tree.category> node, counting from the bottom of the
    # head spine, is the anchor for the non-head children.
    h = [n.category for n in my_path].count(tree.category)
    for child_path in non_heads:
        paths.append((tree.category, h, child_path))
    return my_path, paths

def _collect_spines(tree):
    try:
        path, paths = _do_collect_spines(tree)
    except Exception:
        print(tree.to_brackets())
        raise
    # Map each leaf to (anchor category, anchor height h, its spine);
    # the root spine gets the dummy anchor '<ROOT>'.
    return {p[-1]: (anchor, h, p[:-1]) for anchor, h, p in [('<ROOT>', '<ROOT>', path)] + paths}
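# Illustration (a sketch, assuming a tree whose bracket form is
# (S (NP he) (VP (V runs))) with VP, V and the words marked as heads):
# the head spine of S is S-VP-V-'runs', so _collect_spines would return
#   {'runs': ('<ROOT>', '<ROOT>', [S, VP, V]),
#    'he':   ('S', 1, [NP])}
# where h == 1 because S is the 1st 'S' on the head spine, counting
# from the bottom.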

def _compress_spine(spine):
    # Collapse runs of identical categories; the assert guarantees that a
    # repeated category only ever occurs in adjacent positions.
    compressed = []
    for category in spine:
        if category in compressed:
            assert category == compressed[-1]
        else:
            compressed.append(category)
    return compressed
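# E.g. _compress_spine(['S', 'S', 'VP']) == ['S', 'VP'], while a
# non-adjacent repeat such as ['S', 'VP', 'S'] fails the assert.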

def _add_spines_and_attachments(instance, dataset_features, compress):
    tree = tree_from_dataset_instance(instance, dataset_features)
    spines = _collect_spines(tree)
    leaves_linear = sorted(tree.get_yield(), key=lambda leaf: leaf.from_index)
    rows = []
    for leaf in leaves_linear:
        anchor, anchor_h, spine = spines[leaf]
        spine = [node.category for node in spine]
        if compress:
            spine = _compress_spine(spine)
        spine = '_'.join(spine) if spine else EMPTY
        rows.append((spine, anchor, str(anchor_h)))
    spines, anchors, anchor_hs = zip(*rows)
    return {
        SPINES: spines,
        ANCHORS: anchors,
        ANCHOR_HS: anchor_hs,
    }
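# With the example tree sketched above, the token 'runs' would get spine
# 'S_VP_V', anchor '<ROOT>' and anchor_h '<ROOT>'; 'he' would get spine
# 'NP', anchor 'S' and anchor_h '1'.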

def cast_labels(dataset, columns):
    # Collect the label vocabulary of each column across all splits and
    # cast those columns to sequences of ClassLabel features.
    vals = defaultdict(Counter)
    for d in dataset.values():
        for column in columns:
            vals[column].update(chain.from_iterable(s[column] for s in d))
    new_features = dataset['train'].features.copy()
    for column in columns:
        new_features[column] = Sequence(ClassLabel(names=sorted(vals[column].keys())))
    return dataset.cast(new_features)

def add_spines_and_attachments(dataset, compress=False):
    dataset_features = dataset['train'].features
    new_dataset = dataset.map(
        lambda instance: _add_spines_and_attachments(instance, dataset_features, compress=compress))
    return cast_labels(new_dataset, [SPINES, ANCHORS, ANCHOR_HS])
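# Usage sketch (assumes `dataset` is a datasets.DatasetDict with a 'train'
# split and the columns expected by tree_from_dataset_instance):
#
#   dataset = add_spines_and_attachments(dataset, compress=True)
#   dataset['train'].features[SPINES]  # Sequence(ClassLabel(names=[...]))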

# https://huggingface.co/docs/transformers/v4.23.1/en/tasks/token_classification

def masked_word_ids(word_ids, masking_strategy=FIRST):
    masked = []
    for i, word_idx in enumerate(word_ids):
        # Set the label for the first/last token of each word.
        # Mask the label for:
        #   * special tokens (word id = None)
        #   * other tokens in a word
        if word_idx is None:
            masked.append(None)
        elif masking_strategy == FIRST:
            # First token of a word: no predecessor, or a different word id.
            is_first = i == 0 or word_idx != word_ids[i - 1]
            masked.append(word_idx if is_first else None)
        elif masking_strategy == LAST:
            # Last token of a word: no successor, or a different word id.
            is_last = i + 1 == len(word_ids) or word_idx != word_ids[i + 1]
            masked.append(word_idx if is_last else None)
    return masked
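# E.g. for word_ids [None, 0, 1, 1, 2, None] (a [CLS] ... [SEP] encoding
# where word 1 is split into two subtokens):
#   FIRST -> [None, 0, 1, None, 2, None]
#   LAST  -> [None, 0, None, 1, 2, None]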

def _align_row(values, masked_ids):
    # Project word-level labels onto subtokens, masking the positions that
    # masked_word_ids() set to None.
    return [MASK_VALUE if idx is None else values[idx] for idx in masked_ids]

def _align_example(example, masked_ids):
    labels = defaultdict(list)
    for column_name in example:
        if column_name not in (SPINES, ANCHORS, ANCHOR_HS):
            continue
        labels[column_name] = _align_row(example[column_name], masked_ids)
    return labels

def bert_tokenize_and_align(example, tokenizer, masking_strategy=FIRST):
    if masking_strategy not in (FIRST, LAST):
        raise RuntimeError(f'Unknown masking strategy: {masking_strategy}')
    tokenized_inputs = tokenizer(example[TOKENS], truncation=True, is_split_into_words=True)
    word_ids = tokenized_inputs.word_ids()
    mask = masked_word_ids(word_ids, masking_strategy)
    labels = _align_example(example, mask)
    tokenized_inputs.update(labels)
    return tokenized_inputs
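# Usage sketch (assumes a fast Hugging Face tokenizer, which is required
# for .word_ids(), and a dataset produced by add_spines_and_attachments):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
#   tokenized = dataset.map(
#       lambda ex: bert_tokenize_and_align(ex, tokenizer, masking_strategy=FIRST))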