# add_missing_has_nps_markers.py 3.38 KB
# -*- coding:utf-8 -*-

import re
import sys
import time

import jsonpickle
from django.core.management.base import BaseCommand
from django.db.models import Count

from multiservice.facade import Multiservice
from multiservice.facade.ttypes import *
from multiservice.types.ttypes import *
from thrift.transport import TSocket

from webapp.models import Expression, Source


# Port and host handed to getThriftTransportAndClient() when connecting to
# the Multiservice Thrift endpoint.
PORT = 20000
HOST = 'multiservice.nlp.ipipan.waw.pl'
# Service names sent with every request; createRequest() turns each entry
# into one RequestPart of the request chain.
PROCESS_CHAIN = ['Concraft']

# Key of the Source whose linked expressions are processed by the command.
SOURCE = 'wikidata'
# Lower bound (inclusive) on Expression.id; presumably raised by hand to
# resume after an aborted run -- TODO confirm.
START_ID = 0


class Command(BaseCommand):
    """Django management command that adds missing has_nps markers.

    The work is delegated to repair_markers() for the module-level SOURCE
    key; the command defines no options of its own.
    """

    help = 'Add missing has_nps markers.'

    def handle(self, *args, **options):
        """Run the repair; positional and keyword arguments are ignored."""
        repair_markers(source_key=SOURCE)


def repair_markers(source_key):
    """Re-parse every multi-segment expression linked to one source.

    Fetches the Source whose ``key`` equals ``source_key``, selects the
    expressions linked to it that have more than one segment and an id of
    at least START_ID, and passes each of them, in ascending id order, to
    parse_and_update_expression().
    """
    wanted_source = Source.objects.get(key=source_key)

    # The annotate() call must stay between the two filters: the segment
    # count is computed for the source-linked expressions and only then
    # used as a filter condition.
    candidates = (
        Expression.objects
        .filter(link__source=wanted_source)
        .annotate(num_segments=Count('segments'))
        .filter(num_segments__gt=1, id__gte=START_ID)
        .order_by('id')
    )

    for candidate in candidates:
        parse_and_update_expression(candidate)


def parse_and_update_expression(expression):
    """Send one expression through the Multiservice chain and apply the result.

    Opens a fresh Thrift connection, submits ``expression.orth_text`` with
    PROCESS_CHAIN, polls the request status until it is DONE or FAILED and,
    on DONE, hands the result object to update_expression(). The transport
    is closed on every exit path.

    Raises:
        SystemExit: when the service reports FAILED. The service-side
            exception is written to stderr first and the exit message
            carries the id of the expression being processed.
    """
    transport, client = getThriftTransportAndClient(HOST, PORT)
    try:
        # Built inside the try so that a failure while constructing the
        # request cannot leave the freshly opened transport unclosed.
        request = createRequest(expression.orth_text, PROCESS_CHAIN)
        token = client.putObjectRequest(request)

        terminal_statuses = (RequestStatus.DONE, RequestStatus.FAILED)
        # NOTE(review): there is no upper bound on the polling time; the loop
        # runs until the service reports a terminal status.
        while True:
            status = client.getRequestStatus(token)
            if status in terminal_statuses:
                break
            # Sleep only between polls, never after the final status.
            time.sleep(0.1)

        if status == RequestStatus.DONE:
            update_expression(expression, client.getResultObject(token))
        else:
            # sys.stderr.write works on both Python 2 and Python 3, unlike
            # the ``print >> sys.stderr`` statement form.
            sys.stderr.write('%s\n' % (client.getException(token),))
            sys.exit("Stopped loading data at %d expression!" % expression.id)
    finally:
        transport.close()


def getThriftTransportAndClient(host, port):
    """Open a buffered Thrift connection and return ``(transport, client)``.

    A TSocket for ``host``/``port`` is wrapped in a TBufferedTransport, a
    Multiservice client speaking TBinaryProtocol is built on top of it and
    the transport is opened. If any of these steps fails, the transport
    created so far is closed and the exception is re-raised. The caller is
    responsible for closing the returned transport.
    """
    # NOTE(review): TTransport and TBinaryProtocol are not imported by name
    # in this file; presumably they arrive through the star imports -- confirm.
    connection = TSocket.TSocket(host, port)
    try:
        connection = TTransport.TBufferedTransport(connection)
        service_client = Multiservice.Client(
            TBinaryProtocol.TBinaryProtocol(connection))
        connection.open()
    except:
        # Deliberately bare: clean up on any failure, then propagate it.
        connection.close()
        raise
    return connection, service_client


def createRequest(text, serviceNames):
    """Build the ObjectRequest for ``text`` and the given service names.

    ``text`` is split on runs of two or more newline characters and every
    piece becomes one TParagraph of the TText; every entry of
    ``serviceNames`` becomes one RequestPart of the chain.
    """
    pieces = re.split(r'\n\n+', text)
    document = TText(paragraphs=[TParagraph(text=piece) for piece in pieces])
    pipeline = [RequestPart(serviceName=service) for service in serviceNames]
    return ObjectRequest(document, pipeline)


def update_expression(expression, result):
    """Apply a service result to ``expression`` if the segment counts agree.

    The result object is round-tripped through jsonpickle (encoded with
    ``unpicklable=False``, then decoded) and its tokens are collected with
    get_json_segments(). When the number of tokens equals the number of
    segments stored for the expression, update_segments() is called;
    otherwise a message with the expression id is printed and nothing is
    changed.
    """
    plain_result = jsonpickle.decode(
        jsonpickle.encode(result, unpicklable=False))
    tokens = get_json_segments(plain_result)

    if len(tokens) == expression.segments.count():
        update_segments(expression, tokens)
        return

    print("Can't update %d expression!" % expression.id)


def get_json_segments(jsonObj):
    """Return all tokens of ``jsonObj`` as one flat list, in document order.

    Walks ``jsonObj['paragraphs']``, then each paragraph's ``'sentences'``,
    and concatenates every sentence's ``'tokens'``.
    """
    flat_tokens = []
    for paragraph in jsonObj['paragraphs']:
        for sentence in paragraph['sentences']:
            flat_tokens.extend(sentence['tokens'])
    return flat_tokens


def update_segments(expression, json_segments):
    """Set ``has_nps`` on segments whose parsed token has no preceding space.

    The expression's segments, ordered by ``position_in_expr``, are paired
    positionally with ``json_segments``. The first pair is always skipped;
    for every later pair whose token has a truthy ``'noPrecedingSpace'``
    value, the segment's ``has_nps`` is set to True and the segment is
    saved. Segments are never reset to False here.
    """
    ordered_segments = expression.segments.order_by('position_in_expr')
    pairs = iter(zip(ordered_segments, json_segments))
    # Drop the first pair: the leading segment never receives a marker.
    next(pairs, None)

    for segment, token in pairs:
        if not token['noPrecedingSpace']:
            continue
        segment.has_nps = True
        segment.save()