import os
import shutil

from itertools import chain

import tensorflow as tf
from tensorflow.keras import backend

from datasets import Dataset, concatenate_datasets, Sequence
from datasets.features import ClassLabel
from transformers import AutoTokenizer
from transformers.modeling_tf_utils import TFTokenClassificationLoss
from transformers.tf_utils import shape_list

from morfeusz2 import Morfeusz

from .constituency_parser import ConstituencyParser, category_names
from .data_utils import morf_tokenize_and_align, bert_tokenize_and_align, DataCollator, LabelDict, FIRST
from .constants import MASK_VALUE, TOKENS

SPLIT = 'split'
DROP = 'drop'

class AvgAccuracy(tf.keras.callbacks.Callback):
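    """Averages all per-head '*_acc' metrics at the end of each epoch and stores them in
    the logs as 'avg_acc' and 'val_avg_acc' (the latter is monitored for early stopping)."""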
    def __init__(self):
        super().__init__()

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        accs = []
        val_accs = []
        for k, v in logs.items():
            if k.endswith('_acc'):
                if k.startswith('val_'):
                    val_accs.append(v)
                else:
                    accs.append(v)
        if accs:
            logs['avg_acc'] = sum(accs) / len(accs)
        if val_accs:
            logs['val_avg_acc'] = sum(val_accs) / len(val_accs)

def _masked_sparse_categorical_accuracy(y_true, y_pred):
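    """Per-position sparse categorical accuracy that ignores positions whose label equals
    MASK_VALUE; otherwise mirrors Keras' sparse_categorical_accuracy."""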
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    y_pred_rank = y_pred.shape.ndims
    y_true_rank = y_true.shape.ndims
    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
            backend.int_shape(y_true)) == len(backend.int_shape(y_pred))):
        y_true = tf.squeeze(y_true, [-1])
    y_pred = tf.compat.v1.argmax(y_pred, axis=-1)
    mask = tf.reshape(y_true, (-1,)) != MASK_VALUE
    y_true = tf.boolean_mask(tf.reshape(y_true, (-1,)), mask)
    y_pred = tf.boolean_mask(tf.reshape(y_pred, (-1,)), mask)
    # If the predicted output and actual output types don't match, force cast them
    # to match.
    if backend.dtype(y_pred) != backend.dtype(y_true):
        y_pred = tf.cast(y_pred, backend.dtype(y_true))
    ret = tf.cast(tf.equal(y_true, y_pred), backend.floatx())
    return ret
    
def _matrix_accuracy(y_true, y_pred):
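    """Row-wise argmax accuracy for 2D targets; rows consisting entirely of MASK_VALUE
    are excluded."""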
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    row_mask = tf.math.reduce_any(y_true != MASK_VALUE, axis=-1)
    masked_true = tf.cast(tf.boolean_mask(y_true, row_mask), backend.floatx())
    masked_pred = tf.boolean_mask(y_pred, row_mask)
    argmax_true = tf.compat.v1.argmax(masked_true, axis=-1)
    argmax_pred = tf.compat.v1.argmax(masked_pred, axis=-1)
    ret = tf.cast(tf.equal(argmax_true, argmax_pred), backend.floatx())
    return ret
    
def _matrix_loss(labels, logits):
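    """Categorical cross-entropy for 2D targets: rows consisting entirely of MASK_VALUE
    are dropped and any remaining MASK_VALUE entries are zeroed out before the loss."""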
    loss_fn = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    row_mask = tf.math.reduce_any(labels != MASK_VALUE, axis=-1)
    masked_labels = tf.cast(tf.boolean_mask(labels, row_mask), backend.floatx())
    masked_logits = tf.boolean_mask(logits, row_mask)
    # zero out remaining MASK_VALUE entries in the kept rows (MASK_VALUE + (-MASK_VALUE) == 0)
    masked_labels = masked_labels + tf.cast(masked_labels == MASK_VALUE, backend.floatx()) * -MASK_VALUE
    loss = loss_fn(masked_labels, masked_logits)
    return loss

class Trainer(object):
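    """Prepares the (pre)training and dev datasets and trains a ConstituencyParser with
    the selected prediction heads (segmentation, lemmatisation, tagging, dependency,
    spines)."""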
    
    def __init__(
            self,
            bert_path,
            pretrain_datasets=(),
            train_datasets=(),
            dev_datasets=(),
            segmentation=True,
            lemmatisation=True,
            tagging=True,
            dependency=True,
            spines=True,
            #masking_strategy=FIRST,
            morfeusz_dict=None,
            #TODO?: long sequence handling
            #long_sequences=SPLIT,
            #max_seq_len=80,
            batch_size=32,
            ):
        if not train_datasets:
            raise ValueError('<train_datasets> not given, can’t train without training data!')
        #if long_sequences not in (SPLIT, DROP):
        #    raise ValueError(f'<long_sequences> must be one of: {SPLIT}, {DROP}')
        self.bert_path = bert_path
        self.pretrain_datasets = pretrain_datasets
        self.train_datasets = train_datasets
        self.dev_datasets = dev_datasets
        self.segmentation = segmentation
        self.lemmatisation = lemmatisation
        self.tagging = tagging
        self.dependency = dependency
        self.spines = spines
        self.masking_strategy = FIRST
        #TODO?
        #self._handle_long_sequences(long_sequences, max_seq_len)
        self.batch_size = batch_size
        
        self.categories, self.categories2d = category_names(self.segmentation, self.lemmatisation, self.tagging, self.dependency, self.spines)
        
        pretrain = self._join_datasets(self.pretrain_datasets)
        train = self._join_datasets(self.train_datasets)
        dev = self._join_datasets(self.dev_datasets)
        
        unified = list(self._unify_signatures(list(filter(None, [pretrain, train, dev]))))
        if pretrain:
            pretrain = unified[0]
            unified = unified[1:]
        if train:
            train = unified[0]
            unified = unified[1:]
        if dev:
            dev = unified[0]
            unified = unified[1:]
        assert(unified == [])
        assert(pretrain is None or train.features == pretrain.features)
        assert(dev is None or train.features == dev.features)
        
        self.label_dict = LabelDict(self.categories + self.categories2d, train.features)
        
        print('loading BERT tokenizer...')
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.bert_path)
        
        self.pretrain_data = self._prepare_tf_data(pretrain, shuffle=True)
        self.train_data = self._prepare_tf_data(train, shuffle=True)
        self.dev_data = self._prepare_tf_data(dev)
        
    def _remove_columns(self, dataset):
        to_keep = ['id', TOKENS] + self.categories + self.categories2d
        columns_to_remove = [col for col in dataset.column_names if col not in to_keep]
        return dataset.remove_columns(columns_to_remove)
    
    # TODO for 2d categories!!!
    def _unify_signatures(self, datasets):
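        """Restricts the datasets to the relevant columns and aligns their ClassLabel
        vocabularies so that the same label maps to the same integer id in every dataset."""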
        if not datasets:
            return None
        datasets = [self._remove_columns(dataset) for dataset in datasets]
        if len(datasets) == 1:
            return datasets
        print('unifying datasets:')
        for dataset in datasets:
            print(len(dataset), 'examples')
        for category in self.categories: # TODO: also align labels for self.categories2d
            values = set()
            for dataset in datasets:
                if category in dataset.features:
                    feature = dataset.features[category].feature
                    if isinstance(feature, ClassLabel):
                        values.add(tuple(feature.names))
                    else:
                        raise TypeError(
                            f'expected a ClassLabel feature for <{category}>, got {type(feature)}'
                        )
            if len(values) > 1:
                print(f'{category}: aligning labels')
                mapping = {value : i for i, value in enumerate(sorted(set(chain.from_iterable(values))))}
                datasets = [dataset.align_labels_with_mapping(mapping, category) for dataset in datasets]
        return datasets
    
    def _join_datasets(self, datasets):
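        """Optionally retokenizes the datasets with Morfeusz, unifies their label
        signatures and concatenates them into a single dataset."""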
        if not datasets:
            return None
        if self.segmentation:
            datasets = [self._retokenize_dataset(d) for d in datasets]
        datasets = self._unify_signatures(datasets)
        if len(datasets) == 1:
            return datasets[0]
        print('joining datasets:')
        for dataset in datasets:
            print(len(dataset), 'examples')
        joined = concatenate_datasets(datasets)
        print('result:', len(joined), 'examples')
        if self.segmentation:
            joined = self._retokenize_dataset(joined)
        return joined
    
    def _prepare_tf_data(self, dataset, shuffle=False):
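        """Tokenizes the examples with the BERT tokenizer, aligns the labels according to
        the masking strategy and wraps the result in a batched tf.data.Dataset using the
        custom DataCollator."""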
        if not dataset:
            return None
        dataset = dataset.map(lambda x: bert_tokenize_and_align(x, self.bert_tokenizer, self.masking_strategy))
        collator = DataCollator(self.bert_tokenizer, self.categories, self.categories2d)
        return Dataset.to_tf_dataset(
            dataset,
            columns=['input_ids', 'token_type_ids', 'attention_mask'],
            label_cols=self.categories + self.categories2d,
            batch_size=self.batch_size, shuffle=shuffle, collate_fn=collator
        )
    
    def _retokenize_dataset(self, dataset):
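        """Re-segments the examples with Morfeusz and casts the 'seg' column to B/I
        ClassLabel segmentation labels aligned with the new tokens."""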
        morfeusz = Morfeusz(generate=False)
        print(f'\nretokenizing using {morfeusz.dict_id()}')
        retokenized = dataset.map(lambda x: morf_tokenize_and_align(x, morfeusz, self.masking_strategy))
        new_features = retokenized.features.copy()
        new_features['seg'] = Sequence(ClassLabel(names=['B', 'I']))
        return retokenized.cast(new_features)
    
    def _do_handle_long_sequences(self, tokens, tags, long_sequences, max_seq_len):
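        """Either splits sequences longer than max_seq_len into chunks (SPLIT) or drops
        them entirely (DROP); only used by the currently disabled long-sequence handling."""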
        if long_sequences == SPLIT:
            # TODO? add overlap
            #cut = (max_seq_len - overlap) if overlap else max_seq_len
            cut = max_seq_len
            new_tokens, new_tags = [], []
            for x, y in zip(tokens, tags):
                while x:
                    assert(len(x) == len(y))
                    new_tokens.append(x[:max_seq_len])
                    new_tags.append(y[:max_seq_len])
                    x, y = x[cut:], y[cut:]
            return new_tokens, new_tags
        elif long_sequences == DROP:
            return [x for x in tokens if len(x) <= max_seq_len], [y for y in tags if len(y) <= max_seq_len]
    
    '''
    def _handle_long_sequences(self, long_sequences, max_seq_len):
        print(f'\nhandling sequences longer than {max_seq_len} ({long_sequences})...')
        self.pretrain_tokens, self.pretrain_tags = self._do_handle_long_sequences(
            self.pretrain_tokens, self.pretrain_tags, long_sequences, max_seq_len)
        self.train_tokens, self.train_tags = self._do_handle_long_sequences(
            self.train_tokens, self.train_tags, long_sequences, max_seq_len)
        print(f'pretrain sentences: {len(self.pretrain_tokens)}')
        print(f'train sentences: {len(self.train_tokens)}')
    '''
    
    
    def _prepare_output_dir(self, path):
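        """Creates the output directory if it does not exist; otherwise checks that it is
        a directory and empties it."""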
        if not os.path.exists(path):
            os.makedirs(path)
        else:
            if not os.path.isdir(path):
                raise ValueError(f'{path} is not a directory')
            elif os.listdir(path):
                print(f'emptying {path}')
                for entry in os.listdir(path):
                    entry_path = os.path.join(path, entry)
                    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                        shutil.rmtree(entry_path)
                    else:
                        os.remove(entry_path)
    
    
    def _make_callbacks(self, log_dir):
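        """Builds the training callbacks: accuracy averaging, early stopping on
        'val_avg_acc' and, if a log directory is given, TensorBoard logging."""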
        callbacks = [
            AvgAccuracy(),
            tf.keras.callbacks.EarlyStopping(monitor='val_avg_acc', patience=4, verbose=1, restore_best_weights=True)
        ]
        if log_dir is not None:
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0, update_freq=50))
        return callbacks
    
    def train(self, epochs=10, pretrain_epochs=0, lr=0.00001, log_dir=None, model_dir=None, tree_weight=1):
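        """Builds a ConstituencyParser, compiles it with per-head losses and metrics,
        optionally pretrains it on the pretraining data, trains it on the training data
        and, if model_dir is given, saves the model together with a summary of the
        training configuration."""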
        
        if log_dir is not None:
            self._prepare_output_dir(log_dir)
        if model_dir is not None:
            self._prepare_output_dir(model_dir)
        
        parser = ConstituencyParser.create(
            self.bert_path,
            self.label_dict.id2tag,
            segmentation=self.segmentation,
            lemmatisation=self.lemmatisation,
            tagging=self.tagging,
            dependency=self.dependency,
            spines=self.spines,
            bert_tokenizer=self.bert_tokenizer
        )
        
        parser.model.config.tf_legacy_loss = True
        hf_loss = parser.model.hf_compute_loss
        
        def _loss(labels, logits):
            print('LABELS:', labels)
            print('LOGITS:', logits)
            l = hf_loss(labels, logits)
            print('LOSS:', l)
            return l
        
        # wrap in eager execution so that tensor values can be printed
        def debug_loss(y_true, y_pred):
            return tf.py_function(func=_loss, inp=[tf.cast(y_true, tf.float32), y_pred], Tout=tf.float32)
        
        loss = hf_loss
        #loss = debug_loss
        matrix_loss = _matrix_loss
        
        accuracy_metric = lambda: tf.keras.metrics.MeanMetricWrapper(
            fn=_masked_sparse_categorical_accuracy, name='acc'
        )
        
        matrix_accuracy_metric = lambda: tf.keras.metrics.MeanMetricWrapper(
            fn=_matrix_accuracy, name='acc'
        )
        
        metrics = {cat : [accuracy_metric()] for cat in self.categories}
        metrics.update({cat : [matrix_accuracy_metric()] for cat in self.categories2d})
        
        losses = {cat : loss for cat in self.categories}
        losses.update({cat : matrix_loss for cat in self.categories2d})
        print(losses)
        
        loss_weights = None
        if tree_weight != 1:
            loss_weights = {
                cat : (1 if cat in ('seg', 'lemma_case', 'lemma_rule', 'tags') else tree_weight)
                for cat in self.categories + self.categories2d
            }
        
        initial_epoch = 0
        
        if self.pretrain_data:
            parser.model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                loss=losses,
                metrics=metrics
            )
            parser.model.fit(
                x=self.pretrain_data,
                validation_data=self.dev_data,
                epochs=pretrain_epochs,
                callbacks=self._make_callbacks(log_dir),
            )
            initial_epoch, epochs = pretrain_epochs, epochs + pretrain_epochs
        
        parser.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics
        )
        parser.model.fit(
            x=self.train_data,
            validation_data=self.dev_data,
            epochs=epochs,
            callbacks=self._make_callbacks(log_dir),
            initial_epoch=initial_epoch,
        )
        
        if model_dir is not None:
            parser.save(model_dir)
            with open(f'{model_dir}/trainer.info', 'w') as f:
                print(f'pretrain_datasets: {self.pretrain_datasets}', file=f)
                print(f'train_datasets: {self.train_datasets}', file=f)
                print(f'dev_datasets: {self.dev_datasets}', file=f)
                print(f'tree_weight: {tree_weight}', file=f)
                print(f'batch_size: {self.batch_size}', file=f)
                print(f'epochs: {epochs}', file=f)
                print(f'pretrain_epochs: {pretrain_epochs}', file=f)
                print(f'lr: {lr}', file=f)
        
        return parser
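
# Example usage (illustrative sketch only: the checkpoint name and data files below are
# placeholders, and the loaded datasets must already contain the token and label columns
# expected by data_utils.LabelDict / constituency_parser.category_names):
#
#     from datasets import load_dataset
#
#     train = load_dataset('json', data_files='train.json', split='train')
#     dev = load_dataset('json', data_files='dev.json', split='train')
#     trainer = Trainer('bert-base-multilingual-cased',
#                       train_datasets=[train], dev_datasets=[dev], batch_size=32)
#     parser = trainer.train(epochs=10, lr=1e-5, log_dir='logs', model_dir='model')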