add_missing_has_nps_markers.py
3.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# -*- coding:utf-8 -*-
import re
import sys
import time
import jsonpickle
from django.core.management.base import BaseCommand
from django.db.models import Count
from multiservice.facade import Multiservice
from multiservice.facade.ttypes import *
from multiservice.types.ttypes import *
from thrift.transport import TSocket
from webapp.models import Expression, Source
# Thrift endpoint of the Multiservice NLP service (used with TSocket below).
HOST = 'multiservice.nlp.ipipan.waw.pl'
PORT = 20000

# Services requested for every expression, in this order (one RequestPart each).
PROCESS_CHAIN = ['Concraft']

# Key of the Source whose linked expressions are repaired.
SOURCE = 'plwn'

# Only expressions with id strictly greater than this are processed; raise it
# to skip expressions already handled by an earlier, aborted run.
START_ID = 0
class Command(BaseCommand):
    """Management command that triggers the has_nps marker repair.

    All work is delegated to ``repair_markers`` for the Source whose key is
    the module-level ``SOURCE`` constant. Positional arguments and options
    are accepted but ignored.
    """

    help = 'Add missing has_nps markers.'

    def handle(self, *args, **options):
        # Command-line args/options are intentionally unused.
        source_key = SOURCE
        repair_markers(source_key)
def repair_markers(source_key, start_id=None):
    """Re-parse multi-segment expressions of one Source and fix ``has_nps``.

    :param source_key: ``key`` of the Source whose linked expressions are
        processed; ``Source.objects.get`` raises if no such Source exists.
    :param start_id: only expressions with ``id`` strictly greater than this
        are processed. ``None`` (the default) means the module-level
        ``START_ID``, read at call time, so existing callers are unaffected.

    Expressions with more than one segment are handed one at a time, in
    ascending ``id`` order, to ``parse_and_update_expression``.
    """
    if start_id is None:
        start_id = START_ID
    source = Source.objects.get(key=source_key)
    expressions = Expression.objects.filter(link__source=source)
    # distinct=True: the join through ``link`` introduced by the filter above
    # multiplies rows when an expression is linked to this source more than
    # once, which would otherwise inflate the segment count and let
    # single-segment expressions through the ``> 1`` filter.
    expressions = expressions.annotate(
        num_segments=Count('segments', distinct=True))
    expressions = expressions.filter(
        num_segments__gt=1, id__gt=start_id).order_by('id')
    for expression in expressions:
        parse_and_update_expression(expression)
def parse_and_update_expression(expression):
    """Send one expression to Multiservice, wait for the result and apply it.

    Submits ``expression.orth_text`` to the services in ``PROCESS_CHAIN`` over
    a fresh Thrift connection and polls every 0.1 s until the request is
    DONE or FAILED. A DONE result is passed to ``update_expression``.

    On FAILED, the service's exception is written to stderr and the process
    exits via ``sys.exit`` (raising ``SystemExit``). The exit message names
    the expression id (presumably so ``START_ID`` can be set to resume).
    The connection is closed in every case.
    """
    # Build the request before connecting: if this raises, no socket has been
    # opened yet, so nothing can leak outside the try/finally below.
    request = createRequest(expression.orth_text, PROCESS_CHAIN)
    transport, client = getThriftTransportAndClient(HOST, PORT)
    try:
        token = client.putObjectRequest(request)
        status = client.getRequestStatus(token)
        # Sleep only while the request is still pending, not after it finished.
        while status not in (RequestStatus.DONE, RequestStatus.FAILED):
            time.sleep(0.1)
            status = client.getRequestStatus(token)
        if status == RequestStatus.DONE:
            result = client.getResultObject(token)
            update_expression(expression, result)
        else:
            # Same output as ``print >> sys.stderr, ...``, valid on Python 2 and 3.
            sys.stderr.write('%s\n' % (client.getException(token),))
            sys.exit("Stopped loading data at %d expression!" % expression.id)
    finally:
        transport.close()
def getThriftTransportAndClient(host, port):
    """Open a buffered, binary-protocol Thrift connection to Multiservice.

    :param host: server host name.
    :param port: server port.
    :returns: ``(transport, client)``. The caller owns ``transport`` and must
        close it.
    :raises: whatever the Thrift wrappers or ``transport.open()`` raise. The
        transport is closed before the exception propagates.
    """
    # TTransport and TBinaryProtocol are not imported by name at the top of
    # this file; they presumably only resolved via the ``ttypes`` star
    # imports. Import them explicitly so this function does not depend on that.
    from thrift.protocol import TBinaryProtocol
    from thrift.transport import TTransport

    transport = TSocket.TSocket(host, port)
    try:
        transport = TTransport.TBufferedTransport(transport)
        protocol = TBinaryProtocol.TBinaryProtocol(transport)
        client = Multiservice.Client(protocol)
        transport.open()
    except BaseException:
        # Deliberately broad: release the socket on any failure (including
        # KeyboardInterrupt), then re-raise the original exception.
        transport.close()
        raise
    return (transport, client)
def createRequest(text, serviceNames):
    """Build a Multiservice ``ObjectRequest`` for *text*.

    The text is split into paragraphs on runs of two or more newlines, and one
    ``RequestPart`` is created per entry of *serviceNames*, preserving order.
    """
    chunks = re.split(r'\n\n+', text)
    document = TText(paragraphs=[TParagraph(text=chunk) for chunk in chunks])
    parts = []
    for service in serviceNames:
        parts.append(RequestPart(serviceName=service))
    return ObjectRequest(document, parts)
def update_expression(expression, result):
    """Apply a Multiservice parse *result* to *expression*'s segments.

    The Thrift result object is round-tripped through jsonpickle with
    ``unpicklable=False`` to obtain plain decoded JSON. Its tokens are then
    flattened by ``get_json_segments``.

    If the token count differs from the number of stored segments, the process
    exits via ``sys.exit``. Otherwise the tokens go to ``update_segments``.
    """
    plain_result = jsonpickle.decode(
        jsonpickle.encode(result, unpicklable=False))
    tokens = get_json_segments(plain_result)
    if len(tokens) == expression.segments.count():
        update_segments(expression, tokens)
        return
    sys.exit("Can't update %d expression!" % expression.id)
def get_json_segments(jsonObj):
    """Flatten the token dicts of a decoded Multiservice result.

    Walks ``paragraphs`` -> ``sentences`` -> ``tokens`` and returns every token
    in document order as one new list.
    """
    return [token
            for paragraph in jsonObj['paragraphs']
            for sentence in paragraph['sentences']
            for token in sentence['tokens']]
def update_segments(expression, json_segments):
    """Set ``has_nps`` on segments the tagger reports as not space-preceded.

    Pairs the expression's segments (ordered by ``position_in_expr``) with
    *json_segments* positionally.

    The first pair is skipped: nothing precedes the first segment, and
    ``print_orth_by_segments`` strips leading whitespace anyway. A segment is
    saved only when its flag actually changes from unset to ``True``. If any
    segment changed, the rebuilt expression text is printed.

    Flags are only ever added here, never cleared.
    """
    expr_segments = expression.segments.order_by('position_in_expr')
    changed = False
    for expr_seg, json_seg in list(zip(expr_segments, json_segments))[1:]:
        # Skip segments that already carry the marker: re-saving them is a
        # wasted write and would falsely report the expression as updated.
        if json_seg['noPrecedingSpace'] and not expr_seg.has_nps:
            expr_seg.has_nps = True
            expr_seg.save()
            changed = True
    if changed:
        print_orth_by_segments(expression)
def print_orth_by_segments(expression):
    """Print the expression text rebuilt from its segments, and return it.

    Segments are concatenated in ``position_in_expr`` order. A single space is
    inserted before each segment unless its ``has_nps`` flag is set. Leading
    whitespace of the result is stripped.

    :returns: the printed string. Earlier versions returned ``None``, so
        existing callers that ignore the result are unaffected.
    """
    pieces = []
    for expr_seg in expression.segments.order_by('position_in_expr'):
        orth = expr_seg.orth
        pieces.append(orth if expr_seg.has_nps else ' %s' % orth)
    # One join instead of repeated ``+=`` avoids quadratic string building.
    text = ''.join(pieces).lstrip()
    # A parenthesised single argument prints identically on Python 2 and 3.
    print(text)
    return text