# MultiTarget.py (4.43 KB)
import tensorflow as tf

from transformers import TFBertPreTrainedModel, BertConfig, TFBertMainLayer
from transformers.modeling_tf_utils import get_initializer, input_processing, TFTokenClassificationLoss

class TFBertForMultiTargetTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    """BERT encoder with one output head per target category.

    Categories that have an entry in ``labels`` get a ``Dense`` classifier
    with ``len(labels[cat])`` units applied to the encoder output.  Every
    other category is treated as a "2D" target: it gets a ``head`` and a
    ``dependent`` ``Dense`` projection (128 units each), and the two
    projections are combined by a shared ``Dot`` layer over the last axis.
    """

    # Entries containing a '.' name the layers that are allowed to be
    # unexpected/missing when a TF model is loaded from a PT model.
    _keys_to_ignore_on_load_unexpected = [
        r"pooler",
        r"mlm___cls",
        r"nsp___cls",
        r"cls.predictions",
        r"cls.seq_relationship",
    ]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        """Create the encoder, the dropout layer and all per-category heads.

        Args:
            config: Model configuration.  ``initializer_range``,
                ``classifier_dropout`` and ``hidden_dropout_prob`` are read
                here; the whole object is also handed to the parent class
                and to ``TFBertMainLayer``.
            *inputs: Forwarded unchanged to the parent constructor.
            **kwargs: Must contain ``categories`` (an iterable of category
                names) and ``labels`` (a container supporting ``in`` and
                indexing by category, whose values have a length).  Both
                entries are removed before the remaining keyword arguments
                are forwarded to the parent constructor.

        Raises:
            KeyError: If ``categories`` or ``labels`` is missing from
                ``kwargs``.
        """
        # Remove the model-specific options so the parent never sees them.
        categories = kwargs.pop('categories')
        labels = kwargs.pop('labels')

        super().__init__(config, *inputs, **kwargs)

        # Single pass over ``categories``: those known to ``labels`` become
        # plain classification targets, the rest become 2D targets.
        with_labels, without_labels = [], []
        for category in categories:
            bucket = with_labels if category in labels else without_labels
            bucket.append(category)
        self.categories = with_labels
        self.categories2d = without_labels

        def make_dense(units, name):
            # Each layer receives its own freshly created initializer.
            return tf.keras.layers.Dense(
                units=units,
                kernel_initializer=get_initializer(config.initializer_range),
                name=name,
            )

        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")

        # Prefer the dedicated classifier dropout; fall back to the hidden
        # dropout probability when it is not set.
        if config.classifier_dropout is not None:
            dropout_rate = config.classifier_dropout
        else:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = tf.keras.layers.Dropout(rate=dropout_rate)

        self.classifiers = [
            make_dense(len(labels[category]), f'classifier_{category}')
            for category in self.categories
        ]
        print(f'created {len(self.classifiers)} classifier(s)')

        self.mappings2d = [
            {
                'head': make_dense(128, f'head_mapping_{category}'),
                'dependent': make_dense(128, f'dependent_mapping_{category}'),
            }
            for category in self.categories2d
        ]
        print(f'created {len(self.mappings2d)} head/dependent mapping(s)')

        # One Dot layer is shared by all head/dependent pairs.
        self.dot = tf.keras.layers.Dot(axes=-1)

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
        """Run the encoder and apply every category head to its output.

        All arguments are normalised through ``input_processing`` and the
        encoder-related ones are then passed to ``self.bert``.  ``labels`` is
        handed to ``input_processing`` but is not used otherwise in this
        method; no loss is computed here.

        Returns:
            dict: Maps each category name to the output of its head, with the
            entries of ``self.categories`` first and those of
            ``self.categories2d`` after them.
        """
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )

        # Only these processed entries are forwarded to the encoder.
        encoder_keys = (
            "input_ids",
            "attention_mask",
            "token_type_ids",
            "position_ids",
            "head_mask",
            "inputs_embeds",
            "output_attentions",
            "output_hidden_states",
            "return_dict",
            "training",
        )
        outputs = self.bert(**{key: inputs[key] for key in encoder_keys})

        # The first encoder output feeds every head, after dropout.
        hidden = self.dropout(inputs=outputs[0], training=inputs["training"])

        predictions = {}
        for category, classifier in zip(self.categories, self.classifiers):
            predictions[category] = classifier(inputs=hidden)
        for category, mapping in zip(self.categories2d, self.mappings2d):
            head_projection = mapping['head'](inputs=hidden)
            dependent_projection = mapping['dependent'](inputs=hidden)
            predictions[category] = self.dot([head_projection, dependent_projection])
        return predictions

    def serving_output(self, output):
        """Return ``output`` unchanged."""
        return output