# MultiTarget.py
import tensorflow as tf

from transformers import TFBertPreTrainedModel, BertConfig, TFBertMainLayer
from transformers.modeling_tf_utils import get_initializer, input_processing, TFTokenClassificationLoss

class TFBertForMultiTargetTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
    # names with a '.' represent the authorized unexpected/missing layers when a TF model is loaded from a PT model
    _keys_to_ignore_on_load_unexpected = [
        r"pooler",
        r"mlm___cls",
        r"nsp___cls",
        r"cls.predictions",
        r"cls.seq_relationship",
    ]
    _keys_to_ignore_on_load_missing = [r"dropout"]

    def __init__(self, config: BertConfig, *inputs, **kwargs):
        # `categories` lists the target names; `labels` maps each category to
        # its label set, which sizes that category's classification head.
        categories = kwargs.pop('categories')
        labels = kwargs.pop('labels')

        super().__init__(config, *inputs, **kwargs)

        self.categories = categories
        
        self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
        # BertConfig only gained `classifier_dropout` in later transformers
        # releases, so fall back to `hidden_dropout_prob` when it is absent.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = tf.keras.layers.Dropout(rate=classifier_dropout)
        # one token-classification head per category
        self.classifiers = [
            tf.keras.layers.Dense(
                units=len(labels[cat]),
                kernel_initializer=get_initializer(config.initializer_range),
                name=f'classifier_{cat}',
            ) for cat in self.categories
        ]
        print(f'created {len(self.classifiers)} classifier(s)')

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
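        # input_processing normalizes the positional, keyword, and dict input
        # styles accepted by transformers TF models into a single dict.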
        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )
        outputs = self.bert(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            token_type_ids=inputs["token_type_ids"],
            position_ids=inputs["position_ids"],
            head_mask=inputs["head_mask"],
            inputs_embeds=inputs["inputs_embeds"],
            output_attentions=inputs["output_attentions"],
            output_hidden_states=inputs["output_hidden_states"],
            return_dict=inputs["return_dict"],
            training=inputs["training"],
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(inputs=sequence_output, training=inputs["training"])
        # One set of token-level logits per category; computing a loss is left
        # to the caller, since `labels` would need a per-category mapping.
        logits = [classifier(inputs=sequence_output) for classifier in self.classifiers]
        return dict(zip(self.categories, logits))
    
    def serving_output(self, output):
        # The per-category logits dict needs no post-processing for serving.
        return output
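
# Usage sketch, assuming a standard 'bert-base-uncased' checkpoint; the
# category and label names below are hypothetical placeholders.
if __name__ == '__main__':
    from transformers import BertTokenizerFast

    categories = ['pos', 'chunk']
    labels = {
        'pos': ['NOUN', 'VERB', 'ADJ'],
        'chunk': ['B', 'I', 'O'],
    }

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    model = TFBertForMultiTargetTokenClassification.from_pretrained(
        'bert-base-uncased', categories=categories, labels=labels
    )

    enc = tokenizer('A short example sentence', return_tensors='tf')
    logits = model(input_ids=enc['input_ids'], attention_mask=enc['attention_mask'])

    # `logits` maps each category to a (batch, seq_len, num_labels) tensor.
    for cat, tensor in logits.items():
        print(cat, tensor.shape)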