# dataset_utils.py
from collections import Counter, defaultdict
from itertools import chain

from datasets import ClassLabel, Sequence

from .hybrid_tree_utils import tree_from_dataset_instance
from .constants import (
    FIRST,
    LAST,
    MASK_VALUE,
    EMPTY,
    TOKENS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)


def _do_collect_spines(tree):
    # Recursively decompose the tree into head paths (spines). Returns
    # `my_path`, the spine from `tree` down to its lexical head, and `paths`,
    # a list of (anchor category, anchor height, spine) triples for the
    # spines attached below `tree`.
    if not tree.children:
        return [tree], []
    heads = [child for child in tree.children if child.is_head]
    assert len(heads) == 1
    head = heads[0]
    paths = []
    my_path = [tree]
    non_heads = []
    for child in tree.children:
        child_path, grandchildren_paths = _do_collect_spines(child)
        paths += grandchildren_paths
        if child == head:
            my_path += child_path
        else:
            non_heads.append(child_path)
    for child_path in non_heads:
        # h == which <tree.category> counting from the bottom is the anchor
        h = [n.category for n in my_path].count(tree.category)
        paths.append((tree.category, h, child_path))
    return my_path, paths
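
# A worked illustration (a sketch; the exact node API comes from
# hybrid_tree_utils, `^` marks heads here): for a tree like
#
#     (S (NP (D the) (N^ dog)) (VP^ (V^ barks)))
#
# `barks` gets spine S-VP-V (the root head path), `dog` gets spine NP-N
# anchored at category S with h=1 (the lowest S on the host spine), and
# `the` gets spine D anchored at NP with h=1.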


def _collect_spines(tree):
    # Map each leaf to its (anchor category, anchor height, spine) triple,
    # the spine given as the list of nodes above the leaf, topmost first.
    try:
        path, paths = _do_collect_spines(tree)
    except Exception:
        print(tree.to_brackets())
        raise
    return {
        p[-1]: (anchor, h, p[:-1])
        for anchor, h, p in [('<ROOT>', '<ROOT>', path)] + paths
    }


def _compress_spine(spine):
    # Collapse adjacent repeats of a category (e.g. unary VP-VP chains) into
    # a single occurrence; the assert guards against non-adjacent repeats.
    compressed = []
    for category in spine:
        if category in compressed:
            assert category == compressed[-1]
        else:
            compressed.append(category)
    return compressed
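
# For example, _compress_spine(['S', 'VP', 'VP', 'V']) == ['S', 'VP', 'V'],
# while a non-adjacent repeat such as ['VP', 'NP', 'VP'] trips the assert.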


def _add_spines_and_attachments(instance, dataset_features, compress):
    # Convert one tree instance into per-token SPINES / ANCHORS / ANCHOR_HS
    # columns, with tokens in surface order.
    tree = tree_from_dataset_instance(instance, dataset_features)
    spines = _collect_spines(tree)
    leafs_linear = sorted(tree.get_yield(), key=lambda leaf: leaf.from_index)
    rows = []
    for leaf in leafs_linear:
        anchor, anchor_h, spine = spines[leaf]
        spine = [node.category for node in spine]
        if compress:
            spine = _compress_spine(spine)
        spine = '_'.join(spine) if spine else EMPTY
        rows.append((spine, anchor, str(anchor_h)))
    spines, anchors, anchor_hs = zip(*rows)
    return {
        SPINES: spines,
        ANCHORS: anchors,
        ANCHOR_HS: anchor_hs,
    }


def cast_labels(dataset, columns):
    # Collect every value of `columns` across all splits and cast those
    # columns to ClassLabel features with a shared, sorted label inventory.
    vals = defaultdict(Counter)
    for d in dataset.values():
        for column in columns:
            vals[column].update(chain.from_iterable(s[column] for s in d))
    new_features = dataset['train'].features.copy()
    for column in columns:
        new_features[column] = Sequence(ClassLabel(names=sorted(vals[column].keys())))
    return dataset.cast(new_features)
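
# A sketch of the effect (assuming a DatasetDict with a SPINES column of
# strings): after cast_labels(dataset, [SPINES]) the column's feature is
# Sequence(ClassLabel(...)), values are stored as integer ids, and
# dataset['train'].features[SPINES].feature.names lists the label inventory.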


def add_spines_and_attachments(dataset, compress=False):
    # Derive the spine/attachment columns for every split, then cast them to
    # ClassLabel so models see integer label ids.
    dataset_features = dataset['train'].features
    new_dataset = dataset.map(
        lambda instance: _add_spines_and_attachments(
            instance, dataset_features, compress=compress
        )
    )
    return cast_labels(new_dataset, [SPINES, ANCHORS, ANCHOR_HS])


# https://huggingface.co/docs/transformers/v4.23.1/en/tasks/token_classification
def masked_word_ids(word_ids, masking_strategy=FIRST):
    masked = []
    for i, word_idx in enumerate(word_ids):
        # Keep the label on the first/last subword token of each word and
        # mask it for:
        #   * special tokens (word id = None),
        #   * the remaining subword tokens of a word.
        # Indexing with i - 1 / i + 1 is safe as long as the sequence starts
        # and ends with a special token, as BERT-style tokenizers guarantee
        # ([CLS] ... [SEP]).
        if word_idx is None:
            masked.append(None)
        elif masking_strategy == FIRST:
            masked.append(word_idx if word_idx != word_ids[i - 1] else None)
        elif masking_strategy == LAST:
            masked.append(word_idx if word_idx != word_ids[i + 1] else None)
    return masked
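
# A worked example (word ids as returned by a fast tokenizer, with None
# marking special tokens):
#
#     masked_word_ids([None, 0, 1, 1, 2, None], FIRST)
#     # -> [None, 0, 1, None, 2, None]
#     masked_word_ids([None, 0, 1, 1, 2, None], LAST)
#     # -> [None, 0, None, 1, 2, None]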


def _align_row(values, masked_ids):
    # Project word-level labels onto subword tokens: masked positions get
    # MASK_VALUE, kept positions get the label of their word.
    return [MASK_VALUE if idx is None else values[idx] for idx in masked_ids]
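
# E.g. _align_row(['NP_N', 'D'], [None, 0, None, 1, None])
# == [MASK_VALUE, 'NP_N', MASK_VALUE, 'D', MASK_VALUE]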


def _align_example(example, masked_ids):
    # Align each label column of the example to the subword tokenization;
    # other columns are left for the tokenizer output to provide.
    labels = {}
    for column_name in example.keys():
        if column_name not in (SPINES, ANCHORS, ANCHOR_HS):
            continue
        labels[column_name] = _align_row(example[column_name], masked_ids)
    return labels


def bert_tokenize_and_align(example, tokenizer, masking_strategy=FIRST):
    # Tokenize a pre-split example with a (fast) HuggingFace tokenizer and
    # attach the label columns aligned to the subword tokens.
    if masking_strategy not in (FIRST, LAST):
        raise RuntimeError(f'Unknown masking strategy: {masking_strategy}')
    tokenized_inputs = tokenizer(example[TOKENS], truncation=True, is_split_into_words=True)
    word_ids = tokenized_inputs.word_ids()
    mask = masked_word_ids(word_ids, masking_strategy)
    labels = _align_example(example, mask)
    tokenized_inputs.update(labels)
    return tokenized_inputs
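

# A minimal end-to-end sketch (assumptions: `raw_dataset` is a DatasetDict
# whose instances carry the tree columns that tree_from_dataset_instance
# expects, and the checkpoint name is illustrative; word_ids() requires a
# fast tokenizer):
#
#     from functools import partial
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
#     dataset = add_spines_and_attachments(raw_dataset, compress=True)
#     tokenized = dataset.map(
#         partial(bert_tokenize_and_align, tokenizer=tokenizer),
#     )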