parser_server.py 2.25 KB
import os

# Hide every CUDA device from TensorFlow. This is done before TensorFlow is
# imported so the setting is already in place when TensorFlow sets up CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf

# Memory cap (in MB) for the single logical device created on the first GPU.
GPU_MEMORY_LIMIT_MB = 3 * 1024

# NOTE(review): with CUDA_VISIBLE_DEVICES='-1' above, TensorFlow is expected to
# report no GPUs, so this branch is normally skipped and DEVICE ends up 'CPU'.
# It only takes effect if the environment line above is removed or changed.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Limit how much memory TensorFlow may allocate on the first GPU by
        # exposing it as one logical device with a fixed cap. Any further
        # physical GPUs are left unconfigured.
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)],
        )
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Logical (virtual) devices must be configured before the GPUs have
        # been initialized; TensorFlow raises RuntimeError otherwise.
        print(e)

# Shown in the startup banner and included in every response sent to clients.
DEVICE = 'GPU' if gpus else 'CPU'

import errno
import json
import socketserver
import sys
import time

from neural_parser.constituency_parser import ConstituencyParser

# Command line: <script> PORT MODEL. argv[2] is handed to ConstituencyParser.load
# (presumably a model path — confirm). The parser is loaded once at import time
# and shared by every request handled by MyTCPHandler.
if len(sys.argv) < 3:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit(f'usage: {sys.argv[0]} PORT MODEL')
parser = ConstituencyParser.load(sys.argv[2])


class MyTCPHandler(socketserver.StreamRequestHandler):
    """Serve a single parse request per TCP connection.

    Protocol, as implemented in ``handle``:

    Request: the client sends one line of JSON with the keys ``'sentence'``
    and ``'correct_lemmata'``.

    Response: the server writes back one JSON object, with no trailing
    newline. It contains ``'tree'``, ``'device'`` and the entries of the
    ``times`` mapping returned by ``parser.parse``.
    """

    def handle(self):
        """Read one JSON request line, parse the sentence and write the result.

        Relies on the module-level ``parser`` and ``DEVICE``. An empty (or
        whitespace-only) request line is ignored and nothing is written.

        Raises:
            json.JSONDecodeError: if the request line is not valid JSON.
            KeyError: if ``'sentence'`` or ``'correct_lemmata'`` is missing.
            RuntimeError: if the parser does not return exactly one tree.
        """
        # NOTE(review): readline() is unbounded and the server listens on all
        # interfaces; consider a size limit if clients are untrusted.
        self.data = self.rfile.readline()
        # A client that connects and closes without sending anything yields an
        # empty line; there is nothing to parse, so return instead of raising.
        if not self.data.strip():
            return
        request = json.loads(self.data)
        sentence = request['sentence']
        correct_lemmata = request['correct_lemmata']
        trees, times = parser.parse(
            [sentence],
            return_jsons=True,
            correct_lemmata=correct_lemmata,
            return_times=True,
        )
        # One sentence in must mean one tree out. Check explicitly, because
        # `assert` statements are stripped when Python runs with -O.
        if len(trees) != 1:
            raise RuntimeError(f'expected exactly 1 parse tree, got {len(trees)}')
        response = {'tree': trees[0]['tree'], 'device': DEVICE}
        # NOTE(review): keys in `times` would overwrite 'tree'/'device' if
        # they collide — confirm what parser.parse puts in `times`.
        response.update(times)
        self.wfile.write(json.dumps(response).encode('utf-8'))

if __name__ == '__main__':
    # Listen on all interfaces; the port is the first command-line argument.
    HOST, PORT = '0.0.0.0', int(sys.argv[1])
    # Keep retrying the bind while the address is in use (presumably a port
    # still held by a previous run), waiting 3 seconds between attempts.
    while True:
        try:
            with socketserver.TCPServer((HOST, PORT), MyTCPHandler) as server:
                # Extra padding so the banner's borders line up with HOST:PORT.
                n = len(str(PORT)) + len(HOST)
                print('+--------------------' + '-' * n + '+', file=sys.stderr)
                print(f'|   device: {DEVICE}      ' + ' ' * n + '|', file=sys.stderr)
                print(f'|   listening on {HOST}:{PORT}   |', file=sys.stderr)
                print('+--------------------' + '-' * n + '+', file=sys.stderr)
                # TCPServer handles requests synchronously, one at a time.
                server.serve_forever()
        except OSError as e:
            # Compare against the symbolic constant. The literal 98 is only
            # EADDRINUSE on Linux; the number differs on other platforms.
            if e.errno == errno.EADDRINUSE:
                print(time.time(), e, file=sys.stderr)
                time.sleep(3)
            else:
                raise