# data_utils.py
from collections import defaultdict
from itertools import chain

import numpy as np
import tensorflow as tf

from datasets.features import ClassLabel

from .constants import MASK_VALUE

# based on transformers.data.data_collator.DataCollatorForTokenClassification
class DataCollator(object):
    """Collate tokenized instances into padded TensorFlow batches.

    Pads the token inputs via the tokenizer and pads each token-level
    label sequence with MASK_VALUE so downstream losses can ignore
    padding positions, honoring the tokenizer's padding side.
    """

    def __init__(self, tokenizer, features):
        # tokenizer: assumed to be a HuggingFace-style tokenizer exposing
        # `pad` and `padding_side` — confirm against callers.
        self.tokenizer = tokenizer
        # features: mapping of column name -> feature spec whose `.feature`
        # attribute must be a ClassLabel (enforced in __call__).
        self.features = features

    def _pad_labels(self, labels, sequence_length):
        """Pad ``labels`` to ``sequence_length`` with MASK_VALUE.

        Padding is appended or prepended to match the tokenizer's
        padding side, mirroring how input_ids are padded.
        """
        padding = [MASK_VALUE] * (sequence_length - len(labels))
        if self.tokenizer.padding_side == 'right':
            return list(labels) + padding
        return padding + list(labels)

    def __call__(self, instance):
        """Pad ``instance`` and return a dict of tensors.

        Raises:
            ValueError: if a configured feature is not backed by a
                ClassLabel, since only class-label sequences are padded.
        """
        batch = self.tokenizer.pad(
            instance,
            padding=True,
        )
        # After padding, every sequence in the batch shares this length.
        sequence_length = tf.convert_to_tensor(batch['input_ids']).shape[1]
        for category, feat in self.features.items():
            # isinstance is the idiomatic type check and also accepts
            # ClassLabel subclasses (type(...) == ClassLabel did not).
            if isinstance(feat.feature, ClassLabel):
                batch[category] = [
                    self._pad_labels(lbl, sequence_length) for lbl in batch[category]
                ]
            else:
                raise ValueError(f'Unsupported feature type {type(feat.feature)} for "{category}".')
        batch = {k: tf.convert_to_tensor(v) for k, v in batch.items()}
        return batch

def dict_to_tensors(d):
    """Return a copy of mapping ``d`` with every value converted to a
    TensorFlow tensor; keys are preserved unchanged."""
    converted = {}
    for key, value in d.items():
        converted[key] = tf.convert_to_tensor(value)
    return converted