# training.py (12.6 KB)
import os
import shutil
import subprocess

from itertools import chain

import tensorflow as tf
from tensorflow.keras import backend

from datasets import Dataset
from datasets.features import Features
from transformers import AutoTokenizer
from transformers.modeling_tf_utils import TFTokenClassificationLoss
from transformers.tf_utils import shape_list

from .constituency_parser import ConstituencyParser, category_names
from .data_utils import DataCollator
from .dataset_utils import (
    add_lemma_rules,
    add_spines_and_attachments,
    add_adjacency_matrix,
    morfeusz_retokenize,
    bert_tokenize_and_align
)
from .constants import (
    FIRST,
    MASK_VALUE,
    TOKENS,
    SEG_BEGIN,
    SEG_INSIDE,
    SEGS,
    LEMMAS,
    LEMMA_CASES,
    LEMMA_RULES,
    TAGS,
    NONTERMINALS,
    HEADS,
    DEPRELS,
    ADJACENCY_MATRIX,
)

# Strategies for sequences that are too long; currently referenced only by
# commented-out long-sequence handling in Trainer.
SPLIT, DROP = 'split', 'drop'

class AvgAccuracy(tf.keras.callbacks.Callback):
    """Keras callback that adds averaged accuracies to the epoch logs.

    At the end of each epoch, every logged value whose key ends with ``_acc``
    is averaged. Keys starting with ``val_`` go into ``val_avg_acc``; all
    others go into ``avg_acc``. An average is written only when at least one
    matching key exists. Without validation data there are no ``val_*_acc``
    keys, so ``val_avg_acc`` is simply not added (previously this raised
    ``ZeroDivisionError``).
    """

    def __init__(self):
        super().__init__()

    def on_epoch_begin(self, epoch, logs=None):
        return

    def on_epoch_end(self, epoch, logs=None):
        """Write ``avg_acc`` / ``val_avg_acc`` into ``logs`` in place."""
        if logs is None:
            # Nothing to average and nowhere to publish the result.
            return
        accs = []
        val_accs = []
        for key, value in logs.items():
            # Skip our own output keys in case they are already present.
            if not key.endswith('_acc') or key in ('avg_acc', 'val_avg_acc'):
                continue
            if key.startswith('val_'):
                val_accs.append(value)
            else:
                accs.append(value)
        # Keys are added only after iteration, so the dict is not mutated
        # while it is being iterated.
        if accs:
            logs['avg_acc'] = sum(accs) / len(accs)
        if val_accs:
            logs['val_avg_acc'] = sum(val_accs) / len(val_accs)

def _masked_sparse_categorical_accuracy(y_true, y_pred):
    """Per-position accuracy of sparse labels, skipping ``MASK_VALUE`` labels.

    Follows the shape handling of Keras' sparse categorical accuracy. If the
    labels have the same rank as the predictions, their trailing axis is
    squeezed away. Predictions are reduced with argmax over the last axis.
    Both are flattened, and positions whose label equals ``MASK_VALUE`` are
    dropped.

    Returns:
        A 1-D ``floatx`` tensor: 1.0 for each correct unmasked position,
        0.0 for each incorrect one.
    """
    scores = tf.convert_to_tensor(y_pred)
    labels = tf.convert_to_tensor(y_true)
    pred_rank = scores.shape.ndims
    true_rank = labels.shape.ndims

    # Equal (known) ranks mean labels carry an extra trailing axis - drop it.
    if true_rank is not None and pred_rank is not None and true_rank == pred_rank:
        labels = tf.squeeze(labels, [-1])

    guesses = tf.argmax(scores, axis=-1)
    flat_labels = tf.reshape(labels, (-1,))
    flat_guesses = tf.reshape(guesses, (-1,))

    keep = flat_labels != MASK_VALUE
    kept_labels = tf.boolean_mask(flat_labels, keep)
    kept_guesses = tf.boolean_mask(flat_guesses, keep)

    # Align dtypes so tf.equal can compare them.
    label_dtype = backend.dtype(kept_labels)
    if backend.dtype(kept_guesses) != label_dtype:
        kept_guesses = tf.cast(kept_guesses, label_dtype)

    return tf.cast(tf.equal(kept_labels, kept_guesses), backend.floatx())
    
def _matrix_accuracy(y_true, y_pred):
    """Row-wise argmax accuracy for matrix-shaped outputs.

    A row is scored only if at least one of its label entries differs from
    ``MASK_VALUE``. For each scored row, the index of the largest label value
    is compared with the index of the largest prediction.

    Returns:
        A 1-D ``floatx`` tensor with one 0.0/1.0 entry per scored row.
    """
    gold = tf.convert_to_tensor(y_true)
    scores = tf.convert_to_tensor(y_pred)

    scored_rows = tf.math.reduce_any(gold != MASK_VALUE, axis=-1)
    gold_rows = tf.cast(tf.boolean_mask(gold, scored_rows), backend.floatx())
    score_rows = tf.boolean_mask(scores, scored_rows)

    gold_cols = tf.argmax(gold_rows, axis=-1)
    pred_cols = tf.argmax(score_rows, axis=-1)
    return tf.cast(tf.equal(gold_cols, pred_cols), backend.floatx())
    
def _matrix_loss(labels, logits):
    """Per-row categorical cross-entropy over the unmasked rows of a matrix.

    Rows whose labels are all ``MASK_VALUE`` are dropped. Inside the
    remaining rows, individual ``MASK_VALUE`` cells are replaced with 0 so
    they contribute no target mass. ``logits`` are treated as unnormalised
    scores (``from_logits=True``).

    Returns:
        A tensor of unreduced losses, one per kept row.
    """
    row_is_used = tf.math.reduce_any(labels != MASK_VALUE, axis=-1)
    used_logits = tf.boolean_mask(logits, row_is_used)
    used_labels = tf.cast(tf.boolean_mask(labels, row_is_used), backend.floatx())
    # Zero out masked cells that survive inside partially-masked rows.
    used_labels = tf.where(
        used_labels == MASK_VALUE, tf.zeros_like(used_labels), used_labels
    )

    cross_entropy = tf.keras.losses.CategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    return cross_entropy(used_labels, used_logits)

class Trainer:
    """Prepare a dataset and train a ``ConstituencyParser`` on it.

    The constructor checks that the dataset's ``'train'`` split has the
    columns each enabled task needs. It then runs the task-specific
    preprocessing, drops unused columns, tokenizes with the BERT tokenizer
    found at ``bert_path``, and builds batched TF datasets for the
    ``'train'`` and ``'validation'`` splits. ``train`` builds, compiles and
    fits the model and can optionally save it.
    """

    def __init__(
            self,
            bert_path,
            dataset,
            segmentation=True,
            lemmatisation=True,
            tagging=True,
            dependency=True,
            spines=True,
            nonterminal_features=None,
            #masking_strategy=FIRST,
            #TODO?: long sequence handling
            #long_sequences=SPLIT,
            #max_seq_len=80,
            batch_size=32,
            ):
        """
        Args:
            bert_path: name or path given to ``AutoTokenizer.from_pretrained``
                and ``ConstituencyParser.create``.
            dataset: split-indexable dataset with ``'train'`` and
                ``'validation'`` splits (presumably a ``datasets.DatasetDict``
                - confirm against callers).
            segmentation, lemmatisation, tagging, dependency, spines: flags
                selecting which tasks to train.
            nonterminal_features: iterable of extra feature column names to
                train; requires ``spines`` and ``dependency``. ``None`` means
                no extra features.
            batch_size: batch size of the generated TF datasets.

        Raises:
            RuntimeError: if a required column is missing, or nonterminal
                features are requested without spines and dependency.
        """
        #if long_sequences not in (SPLIT, DROP):
        #    raise ValueError(f'<long_sequences> must be one of: {SPLIT}, {DROP}')
        self.bert_path = bert_path
        self.dataset = dataset
        self.segmentation = segmentation
        self.lemmatisation = lemmatisation
        self.tagging = tagging
        self.dependency = dependency
        self.spines = spines
        # Fresh list per instance (the old ``[]`` default was shared).
        self.nonterminal_features = [] if nonterminal_features is None else list(nonterminal_features)
        self.masking_strategy = FIRST
        #TODO?
        #self._handle_long_sequences(long_sequences, max_seq_len)
        self.batch_size = batch_size

        # Fail fast on an invalid flag combination before any preprocessing.
        if self.nonterminal_features and (not self.spines or not self.dependency):
            raise RuntimeError('Can’t train nonterminal features without training spines and dependency!')

        self.categories = category_names(self.segmentation, self.lemmatisation, self.tagging, self.dependency, self.spines, self.nonterminal_features)

        if self.lemmatisation:
            if LEMMAS not in self.dataset['train'].features:
                raise RuntimeError(f'Can’t train lemmatisation without "{LEMMAS}" column in the dataset!')
            print('Preprocessing the dataset for lemmatisation...')
            self.dataset = add_lemma_rules(self.dataset)

        if self.tagging:
            if TAGS not in self.dataset['train'].features:
                raise RuntimeError(f'Can’t train tagging without "{TAGS}" column in the dataset!')

        if self.spines:
            if NONTERMINALS not in self.dataset['train'].features:
                raise RuntimeError(f'Can’t train spines without "{NONTERMINALS}" column in the dataset!')
            print('Preprocessing the dataset for spines and attachments...')
            self.dataset = add_spines_and_attachments(self.dataset)

        if self.dependency:
            train_features = self.dataset['train'].features
            if HEADS not in train_features or DEPRELS not in train_features:
                raise RuntimeError(f'Can’t train dependency without "{HEADS}" and "{DEPRELS}" columns in the dataset!')
            print('Preprocessing the dataset for dependency...')
            self.dataset = add_adjacency_matrix(self.dataset)

        # Checked after preprocessing, which replaces self.dataset.
        for ntf in self.nonterminal_features:
            if ntf not in self.dataset['train'].features:
                raise RuntimeError(f'Can’t train "{ntf}" feature without "{ntf}" column in the dataset!')

        columns_to_keep = set(self.categories) | {TOKENS, LEMMAS, 'sent_id'}
        cols_to_remove = [c for c in self.dataset['train'].column_names if c not in columns_to_keep]
        print('Removing columns:', ', '.join(cols_to_remove))
        self.dataset = self.dataset.remove_columns(cols_to_remove)

        if self.segmentation:
            print('Preprocessing the dataset for segmentation...')
            self.dataset = morfeusz_retokenize(self.dataset)

        self.features = Features({cat: self.dataset['train'].features[cat] for cat in self.categories})

        print('Loading BERT tokenizer...')
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.bert_path)

        print('Preprocessing the dataset for BERT...')
        self.dataset = self.dataset.map(lambda x: bert_tokenize_and_align(x, self.bert_tokenizer, self.masking_strategy))

        self.train_data = self._prepare_tf_data(self.dataset['train'], shuffle=True)
        self.dev_data = self._prepare_tf_data(self.dataset['validation'])

    def _prepare_tf_data(self, dataset, shuffle=False):
        """Convert one dataset split into a batched TF dataset.

        Model inputs are the BERT columns; labels are ``self.categories``.
        Batching is done by ``DataCollator``.
        """
        collator = DataCollator(self.bert_tokenizer, self.features)
        return Dataset.to_tf_dataset(
            dataset,
            columns=['input_ids', 'token_type_ids', 'attention_mask'],
            label_cols=self.categories,
            batch_size=self.batch_size, shuffle=shuffle, collate_fn=collator
        )

    # Disabled long-sequence handling, kept for reference (see the
    # commented-out constructor arguments).
    '''
    def _do_handle_long_sequences(self, tokens, tags, long_sequences, max_seq_len):
        if long_sequences == SPLIT:
            # TODO? add overlap
            #cut = (max_seq_len - overlap) if overlap else max_seq_len
            cut = max_seq_len
            new_tokens, new_tags = [], []
            for x, y in zip(tokens, tags):
                while x:
                    assert(len(x) == len(y))
                    new_tokens.append(x[:max_seq_len])
                    new_tags.append(y[:max_seq_len])
                    x, y = x[cut:], y[cut:]
            return new_tokens, new_tags
        elif long_sequences == DROP:
            return [x for x in tokens if len(x) <= max_seq_len], [y for y in tags if len(y) <= max_seq_len]
    
    def _handle_long_sequences(self, long_sequences, max_seq_len):
        print(f'\nhandling sequences longer than {max_seq_len} ({long_sequences})...')
        self.train_tokens, self.train_tags = self._do_handle_long_sequences(
            self.train_tokens, self.train_tags, long_sequences, max_seq_len)
        print(f'train sentences: {len(self.train_tokens)}')
    '''

    def _prepare_output_dir(self, path):
        """Ensure ``path`` exists as an empty directory.

        Missing directories, including missing parents, are created. An
        existing non-empty directory is emptied, hidden entries included.

        Raises:
            ValueError: if ``path`` exists but is not a directory.
            OSError: if creating or emptying the directory fails.
        """
        # Filesystem API instead of shell strings. ``mkdir {path}`` and
        # ``rm -r {path}/*`` with shell=True were injection-prone and broke
        # on spaces. They also failed silently for nested paths and skipped
        # dotfiles.
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
            return
        if not os.path.isdir(path):
            raise ValueError(f'{path} is not a directory')
        entries = os.listdir(path)
        if entries:
            print(f'emptying {path}')
            for entry in entries:
                entry_path = os.path.join(path, entry)
                # Remove symlinks themselves; never follow them into targets.
                if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                    shutil.rmtree(entry_path)
                else:
                    os.remove(entry_path)

    def _make_callbacks(self, log_dir):
        """Return averaging + early-stopping callbacks, plus TensorBoard if ``log_dir`` is set."""
        callbacks = [
            AvgAccuracy(),
            tf.keras.callbacks.EarlyStopping(monitor='val_avg_acc', patience=4, verbose=1, restore_best_weights=True)
        ]
        if log_dir is not None:
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0, update_freq=50))
        return callbacks

    def train(self, epochs=10, lr=0.00001, log_dir=None, model_dir=None, tree_weight=1):
        """Build, compile and fit a ``ConstituencyParser``.

        Args:
            epochs: maximum number of epochs (early stopping may stop sooner).
            lr: Adam learning rate.
            log_dir: if given, it is created or emptied and used for
                TensorBoard logs.
            model_dir: if given, it is created or emptied; the trained parser
                and a ``trainer.info`` summary are written there.
            tree_weight: loss weight for every category other than
                ``SEGS``, ``LEMMA_CASES``, ``LEMMA_RULES`` and ``TAGS`` (those
                keep weight 1). A value of 1 means no loss weighting.

        Returns:
            The trained ``ConstituencyParser``.
        """
        if log_dir is not None:
            self._prepare_output_dir(log_dir)
        if model_dir is not None:
            self._prepare_output_dir(model_dir)

        parser = ConstituencyParser.create(
            self.bert_path,
            self.features,
            segmentation=self.segmentation,
            lemmatisation=self.lemmatisation,
            tagging=self.tagging,
            dependency=self.dependency,
            spines=self.spines,
            nonterminal_features=self.nonterminal_features,
            bert_tokenizer=self.bert_tokenizer
        )

        # NOTE(review): presumably selects HF's legacy token-classification
        # loss path (flattened, masked per-token loss) - confirm against the
        # installed transformers version.
        parser.model.config.tf_legacy_loss = True
        hf_loss = parser.model.hf_compute_loss

        def _accuracy_metric(cat):
            # New metric object per output; matrix outputs use row-wise argmax.
            fn = _matrix_accuracy if cat == ADJACENCY_MATRIX else _masked_sparse_categorical_accuracy
            return tf.keras.metrics.MeanMetricWrapper(fn=fn, name='acc')

        metrics = {cat: [_accuracy_metric(cat)] for cat in self.categories}
        losses = {cat: (_matrix_loss if cat == ADJACENCY_MATRIX else hf_loss) for cat in self.categories}

        loss_weights = None
        if tree_weight != 1:
            loss_weights = {
                cat: (1 if cat in (SEGS, LEMMA_CASES, LEMMA_RULES, TAGS) else tree_weight)
                for cat in self.categories
            }

        parser.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss=losses,
            loss_weights=loss_weights,
            metrics=metrics
        )
        parser.model.fit(
            x=self.train_data,
            validation_data=self.dev_data,
            epochs=epochs,
            callbacks=self._make_callbacks(log_dir),
            initial_epoch=0,
        )

        if model_dir is not None:
            parser.save(model_dir)
            info_path = os.path.join(model_dir, 'trainer.info')
            with open(info_path, 'w', encoding='utf-8') as f:
                print(f'dataset: {self.dataset}', file=f)
                print(f'tree_weight: {tree_weight}', file=f)
                print(f'batch_size: {self.batch_size}', file=f)
                print(f'epochs: {epochs}', file=f)
                print(f'lr: {lr}', file=f)

        return parser