parser_server.py 3.99 KB
import argparse
import errno
import os
import sys

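# Uncomment to hide all GPUs from CUDA and force CPU-only inference; it must
# take effect before TensorFlow initializes CUDA (i.e. before it is imported).
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'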

import logging

# INFO-level root logging so the startup banner and model listings emitted via
# logger.info below are actually shown.
logging.basicConfig(level='INFO')
logger = logging.getLogger(name=__name__)

def parse_arguments(argv=None):
    """Parse the command-line options of the Hydra TCP service.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit list
            lets callers and tests parse arguments without touching sys.argv.

    Returns:
        argparse.Namespace with ``port`` (int), ``models_dir`` (str) and
        ``sentencer_models_dir`` (str, or None when ``-s`` is not given).

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value cannot be converted (e.g. a non-integer port).
    """
    parser = argparse.ArgumentParser(description='Run Hydra as a TCP service.')
    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-p', '--port', type=int, help='TCP port to listen on', required=True)
    required_arguments.add_argument('-m', '--models_dir', help='directory with Hydra model(s)', required=True)
    optional_arguments = parser.add_argument_group('optional arguments')
    optional_arguments.add_argument('-s', '--sentencer_models_dir', help='directory with sentencer model(s)', required=False)
    return parser.parse_args(argv)

# The command line is parsed once, at import time; everything below reads `args`.
args = parse_arguments()

logger.info(args)

# Each entry of the models directory becomes one Hydra parser further below.
# NOTE(review): os.listdir returns every entry (plain files and hidden entries
# included) in arbitrary order — confirm the directory only holds model dirs.
models_dir = args.models_dir
model_dirs = os.listdir(models_dir)

# The sentencer directory is optional; without it no sentencer models exist.
sentencer_models_dir = args.sentencer_models_dir
if sentencer_models_dir is None:
    sentencer_model_dirs = []
else:
    sentencer_model_dirs = os.listdir(sentencer_models_dir)

# GPU setup from the project's gpu_utils (not visible here). It is deliberately
# called before TensorFlow is imported below — presumably it configures GPU
# visibility/memory, which must precede TF initialization; keep this ordering.
# The budget grows with the number of Hydra and sentencer models to be loaded.
# NOTE(review): 3 * 1024 / 2 * 1024 look like 3 GiB per Hydra model and 2 GiB
# per sentencer expressed in MiB — confirm the unit against tf_setup.
from gpu_utils import tf_setup
tf_setup(3 * 1024 * len(model_dirs) + 2 * 1024 * len(sentencer_model_dirs))

# 'GPU' if TensorFlow sees at least one physical GPU, else 'CPU'; reported in
# the startup banner and echoed back in every response.
import tensorflow as tf
gpus = tf.config.list_physical_devices('GPU')
DEVICE = 'GPU' if gpus else 'CPU'

import json
import socketserver
import time

from hydra.hydra import Hydra
from sentencer_utils import Sentencer

# Load every model up front. Keys are the directory entry names, which is what
# clients must send as request['model'] / request['sentencer_model'].
parsers = {
    model_name: Hydra.load(os.path.join(models_dir, model_name))
    for model_name in model_dirs
}
sentencers = {
    sentencer_name: Sentencer(os.path.join(sentencer_models_dir, sentencer_name))
    for sentencer_name in sentencer_model_dirs
}

class MyTCPHandler(socketserver.StreamRequestHandler):
    """Serve one newline-terminated JSON parse request per TCP connection.

    Request keys read here:
      * ``text`` — a string (parsed as a single one-sentence paragraph) or any
        other JSON value, which is passed to the parser as one paragraph
        (presumably a list of sentences — not validated here);
      * ``correct_lemmata`` — required, forwarded to the parser;
      * ``model`` — key into the module-level ``parsers`` dict;
      * ``sentencer_model`` — optional key into ``sentencers``; when present the
        text is segmented into paragraphs by that sentencer;
      * ``format`` — ``'json'`` (default) or ``'conllu'``.

    The response is a single JSON line:
    ``{"device": DEVICE, "paragraphs": [{"sentences": [...], "times": ...}]}``.
    Exceptions (malformed JSON, missing keys, unknown model names) propagate to
    socketserver, which reports them; no response line is written in that case.
    """

    def handle(self):
        """Read one request line, parse its text and write one response line."""
        self.data = self.rfile.readline()
        # A client that connects and disconnects (e.g. a port probe) or sends a
        # blank line has no request; json.loads would raise and log a traceback.
        if not self.data.strip():
            return
        request = json.loads(self.data)
        correct_lemmata = request['correct_lemmata']
        paragraphs = self._split_paragraphs(request)
        output_format = request.get('format', 'json')
        response = {'device' : DEVICE, 'paragraphs' : []}
        for paragraph in paragraphs:
            trees, times = parsers[request['model']].parse(
                paragraph,
                return_jsons=(output_format == 'json'),
                return_conllu=(output_format == 'conllu'),
                root_label='ROOT',
                force_root_label=True,
                correct_lemmata=correct_lemmata,
                return_times=True
            )
            rendered = self._render_paragraph(trees, times, output_format)
            # NOTE(review): an unsupported format still runs the parser but
            # contributes no paragraph, so the response list stays empty.
            if rendered is not None:
                response['paragraphs'].append(rendered)
        self.wfile.write(bytes(json.dumps(response) + '\n', 'utf-8'))

    @staticmethod
    def _split_paragraphs(request):
        """Return the request's text as a list of paragraphs to parse."""
        text = request['text']
        if 'sentencer_model' in request:
            return sentencers[request['sentencer_model']].do_segmentation(text)
        if isinstance(text, str):
            # A bare string is one paragraph holding one sentence.
            return [[text]]
        return [text]

    @staticmethod
    def _render_paragraph(trees, times, output_format):
        """Build the response entry for one parsed paragraph.

        Returns None when ``output_format`` is neither 'json' nor 'conllu'.
        """
        if output_format == 'json':
            sentences = [{'sentence' : t['sentence'], 'tree' : t['json']['tree']} for t in trees]
        elif output_format == 'conllu':
            sentences = [{'sentence' : t['sentence'], 'conllu' : t['conllu']} for t in trees]
        else:
            return None
        return {'sentences' : sentences, 'times' : times}


# Listen on all interfaces on the requested port and serve requests one at a
# time, forever. If the port is still held (for example by another process),
# log a warning and retry every 3 seconds; any other OSError is fatal.
HOST, PORT = '0.0.0.0', args.port
while True:
    try:
        with socketserver.TCPServer((HOST, PORT), MyTCPHandler) as server:
            # Extra banner width so the borders line up with the host:port line.
            n = len(str(PORT)) + len(HOST)
            logger.info('+--------------------' + '-' * n + '+')
            logger.info(f'|   device: {DEVICE}      ' + ' ' * n + '|')
            logger.info(f'|   listening on {HOST}:{PORT}   |')
            logger.info('+--------------------' + '-' * n + '+')
            logger.info('  available parser models:')
            for model in parsers.keys():
                logger.info(f'   * {model}')
            if sentencers:
                logger.info('  available sentencer models:')
                for model in sentencers.keys():
                    logger.info(f'   * {model}')
            server.serve_forever()
    except OSError as e:
        # "Address already in use". The numeric errno is platform-specific
        # (98 on Linux, 48 on macOS), so compare against the symbolic constant.
        if e.errno == errno.EADDRINUSE:
            logger.warning(f'{time.time()} {e}')
            time.sleep(3)
        else:
            raise