# parser_server.py (2.57 KB)
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
# Fixed cap (in MB, per LogicalDeviceConfiguration.memory_limit) on what
# TensorFlow may allocate on the first GPU -- presumably so this server can
# share the card with other processes.  The alternative would be
# tf.config.experimental.set_memory_growth, which would have to be applied
# identically to every GPU.
GPU_MEMORY_LIMIT_MB = 3 * 1024

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Expose only gpus[0] as a single logical device limited to
        # GPU_MEMORY_LIMIT_MB; any further physical GPUs are left unconfigured.
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=GPU_MEMORY_LIMIT_MB)],
        )
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Logical devices must be configured before the GPUs have been
        # initialized; TensorFlow raises RuntimeError otherwise.  The error is
        # only reported -- the process keeps going with default GPU settings.
        print(e)
# 'GPU' whenever a physical GPU is visible, even if the cap above failed.
# Shown in the startup banner and echoed in every response.
DEVICE = 'GPU' if gpus else 'CPU'
import json
import socketserver
import sys
import time
from neural_parser.constituency_parser import ConstituencyParser
# sys.argv[1] is the TCP port (parsed in the __main__ block); sys.argv[2] is
# passed to ConstituencyParser.load -- presumably a saved-model path.
# Exit with a usage line instead of a bare IndexError when it is missing.
if len(sys.argv) < 3:
    sys.exit(f'usage: {sys.argv[0]} PORT MODEL')
parser = ConstituencyParser.load(sys.argv[2])
class MyTCPHandler(socketserver.StreamRequestHandler):
    """Serve one JSON parse request per TCP connection.

    The client sends a single newline-terminated JSON object with a
    ``correct_lemmata`` key and either ``sentence`` (one item) or
    ``sentences`` (a list).  The reply is a JSON object holding ``tree``
    (single mode) or ``trees`` (batch mode) plus ``device``, merged with the
    ``times`` mapping returned by ``parser.parse(..., return_times=True)``.
    The reply has no trailing newline; clients presumably read until the
    connection is closed.
    """

    def handle(self):
        self.data = self.rfile.readline()
        # A client that connects and disconnects without sending a request
        # (e.g. a port probe) yields b''.  json.loads would raise on it and
        # socketserver would print a traceback for every such connection.
        if not self.data.strip():
            return
        request = json.loads(self.data)
        correct_lemmata = request['correct_lemmata']
        single = 'sentence' in request
        sentences = [request['sentence']] if single else request['sentences']
        trees, times = parser.parse(
            sentences,
            return_jsons=True,
            root_label='ROOT',
            force_root_label=True,
            correct_lemmata=correct_lemmata,
            return_times=True,
        )
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(trees) != len(sentences):
            raise RuntimeError(
                f'parser returned {len(trees)} trees for {len(sentences)} sentences'
            )
        trees = [t['tree'] for t in trees]
        if single:
            response = {'tree': trees[0], 'device': DEVICE}
        else:
            response = {'trees': trees, 'device': DEVICE}
        response.update(times)
        self.wfile.write(json.dumps(response).encode('utf-8'))
if __name__ == '__main__':
HOST, PORT = '0.0.0.0', int(sys.argv[1])
while True:
try:
with socketserver.TCPServer((HOST, PORT), MyTCPHandler) as server:
n = len(str(PORT)) + len(HOST)
print('+--------------------' + '-' * n + '+', file=sys.stderr)
print(f'| device: {DEVICE} ' + ' ' * n + '|', file=sys.stderr)
print(f'| listening on {HOST}:{PORT} |', file=sys.stderr)
print('+--------------------' + '-' * n + '+', file=sys.stderr)
server.serve_forever()
except OSError as e:
if e.errno == 98:
# Address already in use
print(time.time(), e, file=sys.stderr)
time.sleep(3)
else:
raise