# MultiTarget.py: multi-target token-classification head for TensorFlow transformers models.
import logging
logger = logging.getLogger(__name__)

import tensorflow as tf

from transformers.modeling_tf_utils import get_initializer, input_processing, TFTokenClassificationLoss
from transformers.modeling_tf_outputs import TFTokenClassifierOutput

class TFMultiTargetTokenClassification(TFTokenClassificationLoss):
    """Token-classification head producing several independent outputs.

    Intended to be combined (multiple inheritance) with a transformers TF
    pre-trained model class: ``super().__init__`` must resolve to a class that
    accepts ``config``, and ``self.config`` / ``self.config_class`` must exist
    (both are read below). The concrete class must define ``lm_layer_class``;
    it may define ``lm_layer_name`` (defaults to ``config.model_type``) and
    ``lm_layer_init_kwargs`` (defaults to no extra kwargs).

    Every name in ``categories`` becomes one output:

    * names that are keys of ``labels`` get a per-token ``Dense`` classifier
      with ``len(labels[name])`` units;
    * all other names get a pair of 128-unit "head" / "dependent" projections
      whose dot product over the last axis gives pairwise token scores.
    """

    # Layers created here that are absent from base-model checkpoints.
    _keys_to_ignore_on_load_missing = [
        r'dropout',
        r'dot'
    ]
    # Arguments of ``call`` that are forwarded to the underlying LM layer.
    lm_layer_input_keys = (
        'input_ids',
        'attention_mask',
        'token_type_ids',
        'position_ids',
        'head_mask',
        'inputs_embeds',
        'output_attentions',
        'output_hidden_states',
        'return_dict',
        'training',
    )

    def get_lm_layer(self, config):
        """Instantiate the underlying language-model (encoder) layer.

        Args:
            config: model configuration passed to ``lm_layer_class``.

        Returns:
            ``self.lm_layer_class(config, name=..., **init_kwargs)`` where the
            name is ``lm_layer_name`` when set, else ``config.model_type``, and
            ``init_kwargs`` is ``lm_layer_init_kwargs`` (empty when unset/None).
        """
        lm_layer_name = getattr(self, 'lm_layer_name', None)
        if lm_layer_name is None:
            lm_layer_name = config.model_type
        lm_layer_init_kwargs = getattr(self, 'lm_layer_init_kwargs', None) or {}
        return self.lm_layer_class(config, name=lm_layer_name, **lm_layer_init_kwargs)

    def __init__(self, config, *inputs, **kwargs):
        """Create the LM layer, dropout and all per-category output heads.

        Args:
            config: model configuration; ``model_type``, ``classifier_dropout``,
                ``hidden_dropout_prob`` and optionally ``initializer_range``
                are read here.
            *inputs: forwarded to the next ``__init__`` in the MRO.
            **kwargs: must contain ``categories`` (iterable of output names, in
                output order) and ``labels`` (mapping from category name to its
                label inventory; only membership and ``len(labels[cat])`` are
                used). Both are removed; the rest is forwarded up the MRO.

        Raises:
            KeyError: if ``categories`` or ``labels`` is missing from kwargs.
        """
        logger.debug(f'config.model_type: {config.model_type}')
        logger.debug(f'self.config_class: {self.config_class}')

        # Pop our own arguments so the base class never sees them.
        categories = kwargs.pop('categories')
        labels = kwargs.pop('labels')

        super().__init__(config, *inputs, **kwargs)

        # Split outputs: categories with a label inventory are per-token
        # classification; the rest are pairwise (head x dependent) scores.
        self.categories, self.categories2d = [], []
        for cat in categories:
            if cat in labels:
                self.categories.append(cat)
            else:
                self.categories2d.append(cat)

        self.lm_layer = self.get_lm_layer(config)
        logger.debug('self.lm_layer: %s', self.lm_layer)

        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = tf.keras.layers.Dropout(rate=classifier_dropout)

        # Only pass initializer_range when the config defines one, so
        # get_initializer's own default applies otherwise.
        initializer_kwargs = {}
        if getattr(config, 'initializer_range', None) is not None:
            initializer_kwargs['initializer_range'] = config.initializer_range
        self.classifiers = [
            tf.keras.layers.Dense(
                units=len(labels[cat]),
                kernel_initializer=get_initializer(**initializer_kwargs),
                name=f'classifier_{cat}',
            ) for cat in self.categories
        ]
        logger.info(f'created {len(self.classifiers)} classifier(s): {", ".join(self.categories)}')
        self.mappings2d = [
            {
                'head': tf.keras.layers.Dense(
                    units=128,
                    kernel_initializer=get_initializer(**initializer_kwargs),
                    name=f'head_mapping_{cat}'
                ),
                'dependent' : tf.keras.layers.Dense(
                    units=128,
                    kernel_initializer=get_initializer(**initializer_kwargs),
                    name=f'dependent_mapping_{cat}',
                )
            } for cat in self.categories2d
        ]
        logger.info(f'created {len(self.mappings2d)} head/dependent mapping(s): {", ".join(self.categories2d)}')
        # Parameter-free layer; shared by every head/dependent pair.
        self.dot = tf.keras.layers.Dot(axes=-1)

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
        **kwargs,
    ):
        """Run the LM layer and every output head.

        ``labels`` is accepted for API compatibility but not used: no loss is
        computed here.

        Returns:
            With more than one category, a dict mapping each category name to
            its logits (per-token classifiers first, then pairwise outputs).
            With exactly one category, a ``TFTokenClassifierOutput`` whose
            ``logits`` is that single output.

        Raises:
            ValueError: if the model was built with no categories at all.
        """
        output_names = self.categories + self.categories2d
        if not output_names:
            raise ValueError(
                'TFMultiTargetTokenClassification has no output categories; '
                'pass a non-empty `categories` when constructing the model.'
            )

        inputs = input_processing(
            func=self.call,
            config=self.config,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
            training=training,
            kwargs_call=kwargs,
        )
        # Forward only the arguments the LM layer understands (not `labels`).
        lm_layer_kwargs = {key: inputs[key] for key in self.lm_layer_input_keys}
        outputs = self.lm_layer(**lm_layer_kwargs)
        # presumably (batch, seq_len, hidden_size) — confirm for each lm_layer_class
        sequence_output = outputs[0]
        sequence_output = self.dropout(inputs=sequence_output, training=inputs['training'])

        logits = [classifier(inputs=sequence_output) for classifier in self.classifiers]
        logits2d = []
        for mapping in self.mappings2d:
            heads = mapping['head'](inputs=sequence_output)
            dependents = mapping['dependent'](inputs=sequence_output)
            # Dot over the last (projection) axis; presumably yields
            # (batch, seq_len, seq_len) head/dependent scores — confirm.
            logits2d.append(self.dot([heads, dependents]))

        all_logits = logits + logits2d
        if len(output_names) > 1:
            return dict(zip(output_names, all_logits))
        return TFTokenClassifierOutput(logits=all_logits[0])

    def serving_output(self, output):
        """Return the model output unchanged for serving."""
        return output

    def build(self, input_shape=None):
        """Build the LM layer and all output heads once, each in its own name scope.

        ``input_shape`` is ignored. Output heads are built for inputs whose last
        dimension is ``self.config.hidden_size``.
        """
        if self.built:
            return
        self.built = True
        if getattr(self, 'lm_layer', None) is not None:
            with tf.name_scope(self.lm_layer.name):
                self.lm_layer.build(None)
        if getattr(self, 'classifiers', None) is not None:
            for classifier in self.classifiers:
                with tf.name_scope(classifier.name):
                    classifier.build([None, None, self.config.hidden_size])
        if getattr(self, 'mappings2d', None) is not None:
            for mapping2d in self.mappings2d:
                for key in ['head', 'dependent']:
                    with tf.name_scope(mapping2d[key].name):
                        mapping2d[key].build([None, None, self.config.hidden_size])