data_utils.py
1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from collections import defaultdict
from itertools import chain
import numpy as np
import tensorflow as tf
from datasets.features import ClassLabel
from .constants import MASK_VALUE
# Based on Hugging Face transformers.data.data_collator.DataCollatorForTokenClassification.
class DataCollator(object):
    """Collate token-classification examples into a batch of TensorFlow tensors.

    Token inputs are padded with ``tokenizer.pad``. Every label column listed in
    ``features`` is then padded with ``MASK_VALUE`` to the same width, on the
    same side the tokenizer pads, so labels stay aligned with ``input_ids``.

    Args:
        tokenizer: Object providing ``pad(...)`` and a ``padding_side``
            attribute (presumably a Hugging Face tokenizer — confirm against
            callers).
        features: Mapping from label column name to its dataset feature. Each
            feature must expose a ``.feature`` attribute that is a
            ``ClassLabel`` (e.g. ``datasets.Sequence(ClassLabel(...))``).
    """

    def __init__(self, tokenizer, features):
        self.tokenizer = tokenizer
        self.features = features

    def _pad_labels(self, labels, sequence_length):
        """Return ``labels`` as a list padded with ``MASK_VALUE`` to ``sequence_length``.

        The padding goes on the side named by ``self.tokenizer.padding_side``
        ('right' appends; anything else prepends). Label sequences already at
        least ``sequence_length`` long are returned as a list, unpadded and
        untruncated.
        """
        # A negative repeat count yields an empty list, so over-long labels
        # receive no padding.
        padding = [MASK_VALUE] * (sequence_length - len(labels))
        if self.tokenizer.padding_side == 'right':
            return list(labels) + padding
        return padding + list(labels)

    def __call__(self, instance):
        """Pad ``instance`` (presumably a list of per-example dicts) into one batch.

        Returns:
            dict mapping every key of the padded batch to a ``tf.Tensor``.

        Raises:
            ValueError: if an entry of ``self.features`` is not a sequence
                feature whose ``.feature`` is a ``ClassLabel``.
        """
        batch = self.tokenizer.pad(instance, padding=True)
        # With padding=True every row of input_ids has the same length, so the
        # first row gives the padded width without materialising a tensor.
        sequence_length = len(batch['input_ids'][0])
        for category, feat in self.features.items():
            # Features without a `.feature` attribute (non-sequence features)
            # are reported through the ValueError below, not an AttributeError.
            inner = getattr(feat, 'feature', None)
            if not isinstance(inner, ClassLabel):
                bad_type = type(feat) if inner is None else type(inner)
                raise ValueError(f'Unsupported feature type {bad_type} for "{category}".')
            batch[category] = [
                self._pad_labels(lbl, sequence_length) for lbl in batch[category]
            ]
        return {k: tf.convert_to_tensor(v) for k, v in batch.items()}
def dict_to_tensors(d):
    """Return a new dict with the keys of ``d`` and each value converted by
    ``tf.convert_to_tensor``.

    ``d`` itself is not modified.
    """
    tensors = {}
    for key, value in d.items():
        tensors[key] = tf.convert_to_tensor(value)
    return tensors