# data_utils.py
from collections import defaultdict
from itertools import chain

import numpy as np
import tensorflow as tf

from datasets.features import ClassLabel, Sequence

from .constants import MASK_VALUE

# based on tensorflow.data.data_collator.DataCollatorForTokenClassification
class DataCollator(object):
    """Pad a tokenized batch and its label columns to a common sequence length.

    Flat label columns (feature type ``ClassLabel``) are padded to the batch's
    sequence length with ``MASK_VALUE``; matrix label columns (feature type
    ``Sequence``) are padded to a square ``sequence_length x sequence_length``
    matrix. All columns are finally converted to tensors.
    """

    def __init__(self, tokenizer, features):
        # tokenizer: HuggingFace-style tokenizer exposing `pad` and `padding_side`.
        # features: mapping column name -> dataset feature; each entry's
        #   `.feature` must be a ClassLabel (flat labels) or a Sequence
        #   (per-token label matrix).
        self.tokenizer = tokenizer
        self.features = features

    def _pad_labels(self, labels, sequence_length):
        """Pad ``labels`` to ``sequence_length`` with MASK_VALUE.

        The mask padding is placed on the side the tokenizer pads on, so label
        positions stay aligned with the padded input ids.
        """
        padding = [MASK_VALUE] * (sequence_length - len(labels))
        if self.tokenizer.padding_side == 'right':
            return list(labels) + padding
        return padding + list(labels)

    def __call__(self, instance):
        """Pad the batch ``instance`` and return it with all columns as tensors."""
        batch = self.tokenizer.pad(
            instance,
            padding=True,
        )
        # Padded length of the batch, taken from the padded input ids.
        sequence_length = tf.convert_to_tensor(batch['input_ids']).shape[1]
        for category, feat in self.features.items():
            # isinstance instead of exact `type(...) ==` comparison: also
            # accepts subclasses of the feature types, and is the idiomatic
            # Python type check.
            if isinstance(feat.feature, ClassLabel):
                batch[category] = [
                    self._pad_labels(lbl, sequence_length) for lbl in batch[category]
                ]
            elif isinstance(feat.feature, Sequence):
                # pad the matrix rows to sequence_length
                # and add empty rows to pad to obtain a sequence_length x sequence_length matrix
                batch[category] = [
                    [self._pad_labels(row, sequence_length) for row in lbl]
                    + [
                        self._pad_labels([], sequence_length)
                        for _ in range(sequence_length - len(lbl))
                    ]
                    for lbl in batch[category]
                ]
            else:
                raise ValueError(f'Unsupported feature type {type(feat.feature)} for "{category}".')
        batch = {k: tf.convert_to_tensor(v) for k, v in batch.items()}
        return batch

def dict_to_tensors(d):
    """Return a copy of ``d`` with every value converted to a tensor."""
    converted = {}
    for key, value in d.items():
        converted[key] = tf.convert_to_tensor(value)
    return converted