# training.py
import os
import shutil

from itertools import chain

import tensorflow as tf
from tensorflow.keras import backend

from datasets import Dataset
from datasets.features import Features
from transformers import AutoTokenizer
from transformers.modeling_tf_utils import TFTokenClassificationLoss
from transformers.tf_utils import shape_list

from .constituency_parser import ConstituencyParser
from .data_utils import DataCollator
from .dataset_utils import bert_tokenize_and_align
from .constants import (
    FIRST,
    MASK_VALUE,
    TOKENS,
    SPINES,
    ANCHORS,
    ANCHOR_HS,
)

class AvgAccuracy(tf.keras.callbacks.Callback):
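    # Averages all per-output *_acc metrics into single avg_acc / val_avg_acc
    # entries in the epoch logs, so that one value can be monitored (e.g. by the
    # EarlyStopping callback configured in Trainer._make_callbacks).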
    def __init__(self):
        super().__init__()
    def on_epoch_begin(self, epoch, logs=None):
        return
    def on_epoch_end(self, epoch, logs=None):
        accs = []
        val_accs = []
        for k, v in logs.items():
            if k.endswith('_acc'):
                if k.startswith('val_'):
                    val_accs.append(v)
                else:
                    accs.append(v)
        logs['avg_acc'] = sum(accs) / len(accs)
        logs['val_avg_acc'] = sum(val_accs) / len(val_accs)

def _masked_sparse_categorical_accuracy(y_true, y_pred):
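    # Sparse categorical accuracy computed only over positions whose gold label
    # is not MASK_VALUE: both tensors are flattened and the masked positions are
    # dropped before comparing predictions with targets.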
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.convert_to_tensor(y_true)
    y_pred_rank = y_pred.shape.ndims
    y_true_rank = y_true.shape.ndims
    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (y_true_rank is not None) and (y_pred_rank is not None) and (len(
            backend.int_shape(y_true)) == len(backend.int_shape(y_pred))):
        y_true = tf.squeeze(y_true, [-1])
    y_pred = tf.compat.v1.argmax(y_pred, axis=-1)
    mask = tf.reshape(y_true, (-1,)) != MASK_VALUE
    y_true = tf.boolean_mask(tf.reshape(y_true, (-1,)), mask)
    y_pred = tf.boolean_mask(tf.reshape(y_pred, (-1,)), mask)
    # If the predicted output and actual output types don't match, force cast them
    # to match.
    if backend.dtype(y_pred) != backend.dtype(y_true):
        y_pred = tf.cast(y_pred, backend.dtype(y_true))
    ret = tf.cast(tf.equal(y_true, y_pred), backend.floatx())
    return ret

class Trainer:
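    # Fine-tunes a BERT-based ConstituencyParser on the SPINES, ANCHORS and
    # ANCHOR_HS label columns of a Hugging Face dataset with 'train' and
    # 'validation' splits.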
    
    def __init__(
            self,
            bert_path,
            dataset,
            batch_size=32,
            ):
        self.bert_path = bert_path
        self.dataset = dataset
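        # label alignment strategy passed to bert_tokenize_and_align below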
        self.masking_strategy = FIRST
        self.batch_size = batch_size
        
        self.categories = [SPINES, ANCHORS, ANCHOR_HS]
        
        self.features = Features({cat : self.dataset['train'].features[cat] for cat in self.categories})
        
        print('Loading BERT tokenizer...')
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.bert_path)
        
        print('Preprocessing the dataset for BERT...')
        self.dataset = self.dataset.map(lambda x: bert_tokenize_and_align(x, self.bert_tokenizer, self.masking_strategy))
        
        self.train_data = self._prepare_tf_data(self.dataset['train'], shuffle=True)
        self.dev_data = self._prepare_tf_data(self.dataset['validation'])
        
    def _prepare_tf_data(self, dataset, shuffle=False):
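        # Wrap a dataset split as a batched tf.data.Dataset: the BERT inputs are
        # the features, the three category columns are the labels, and batches
        # are assembled by DataCollator.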
        collator = DataCollator(self.bert_tokenizer, self.features)
        return dataset.to_tf_dataset(
            columns=['input_ids', 'token_type_ids', 'attention_mask'],
            label_cols=self.categories,
            batch_size=self.batch_size, shuffle=shuffle, collate_fn=collator
        )
    
    def _prepare_output_dir(self, path):
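        # Create the output directory if needed; if it already exists, check that
        # it really is a directory and empty it.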
        if not os.path.exists(path):
            os.makedirs(path)
        elif not os.path.isdir(path):
            raise ValueError(f'{path} is not a directory')
        elif os.listdir(path):
            print(f'emptying {path}')
            shutil.rmtree(path)
            os.makedirs(path)
    
    
    def _make_callbacks(self, log_dir):
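        # Early stopping monitors the averaged validation accuracy produced by
        # AvgAccuracy; TensorBoard logging is only added when a log_dir is given.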
        callbacks = [
            AvgAccuracy(),
            tf.keras.callbacks.EarlyStopping(monitor='val_avg_acc', patience=4, verbose=1, restore_best_weights=True)
        ]
        if log_dir is not None:
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=0, update_freq=50))
        return callbacks
    
    def train(self, epochs=10, lr=0.00001, log_dir=None, model_dir=None):
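        # Build a ConstituencyParser on top of the BERT encoder, compile it with
        # one loss and one masked-accuracy metric per output category, fit it on
        # the prepared data, optionally save it to model_dir, and return it.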
        
        if log_dir is not None:
            self._prepare_output_dir(log_dir)
        if model_dir is not None:
            self._prepare_output_dir(model_dir)
        
        parser = ConstituencyParser.create(
            self.bert_path,
            self.features,
            bert_tokenizer=self.bert_tokenizer
        )
        
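        # Use the model's own Hugging Face loss (with the legacy TF loss
        # behaviour enabled) as the per-output Keras loss.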
        parser.model.config.tf_legacy_loss = True
        hf_loss = parser.model.hf_compute_loss
        
        def _loss(labels, logits):
            print('LABELS:', labels)
            print('LOGITS:', logits)
            l = hf_loss(labels, logits)
            print('LOSS:', l)
            return l
        
        # Wrap _loss in a tf.py_function so it runs eagerly and the tensor values
        # above can actually be printed during training.
        def debug_loss(y_true, y_pred):
            return tf.py_function(func=_loss, inp=[tf.cast(y_true, tf.float32), y_pred], Tout=tf.float32)

        loss = hf_loss
        # loss = debug_loss  # switch to the eager wrapper to inspect labels/logits/loss
        
        # One masked-accuracy metric instance per output category; positions whose
        # label equals MASK_VALUE are excluded from the accuracy.
        def accuracy_metric():
            return tf.keras.metrics.MeanMetricWrapper(
                fn=_masked_sparse_categorical_accuracy, name='acc'
            )
        
        metrics = {cat : [accuracy_metric()] for cat in self.categories}
        
        losses = {cat : loss for cat in self.categories}
        
        initial_epoch = 0
        
        parser.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss=losses,
            metrics=metrics
        )
        parser.model.fit(
            x=self.train_data,
            validation_data=self.dev_data,
            epochs=epochs,
            callbacks=self._make_callbacks(log_dir),
            initial_epoch=initial_epoch,
        )
        
        if model_dir is not None:
            parser.save(model_dir)
            with open(f'{model_dir}/trainer.info', 'w') as f:
                print(f'dataset: {self.dataset}', file=f)
                print(f'batch_size: {self.batch_size}', file=f)
                print(f'epochs: {epochs}', file=f)
                print(f'lr: {lr}', file=f)
        
        return parser
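
# Example usage (a sketch, not part of the training code): assumes a DatasetDict
# with 'train' and 'validation' splits exposing the SPINES, ANCHORS and ANCHOR_HS
# features, and a BERT checkpoint name/path usable by AutoTokenizer; the paths
# and checkpoint below are placeholders.
#
#     from datasets import load_from_disk
#
#     dataset = load_from_disk('path/to/preprocessed_dataset')
#     trainer = Trainer('bert-base-cased', dataset, batch_size=32)
#     parser = trainer.train(epochs=10, lr=1e-5, log_dir='logs', model_dir='parser_model')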