main.py 2.87 KB
import os
import sys

from argparse import ArgumentParser
from natsort import natsorted

sys.path.append(os.path.abspath(os.path.join('..')))

from inout import mmax
from inout.constants import INPUT_FORMATS
from resolvers import resolve
from resolvers.constants import RESOLVERS
from utils import eprint


def main():
    """Command-line entry point: validate the arguments and run resolution.

    The input path, the resolver name (must be in ``RESOLVERS``) and the
    input format (must be in ``INPUT_FORMATS``) are checked in that order.
    The first failing check is reported through ``eprint`` and the process
    exits with status 1, so shell scripts can detect the failure. On
    success, processing is delegated to ``process_texts``.
    """
    args = parse_arguments()
    if not args.input:
        error = "Error: Input file(s) not specified!"
    elif args.resolver not in RESOLVERS:
        error = "Error: Unknown resolve algorithm!"
    elif args.format not in INPUT_FORMATS:
        error = "Error: Unknown input file format!"
    else:
        process_texts(args.input, args.output, args.format, args.resolver)
        return
    eprint(error)
    # Non-zero exit status signals invalid usage to the calling shell.
    sys.exit(1)


def parse_arguments():
    """Define the command-line options and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the string attributes ``input``,
        ``output``, ``format`` (default ``'mmax'``) and ``resolver``
        (default ``'incremental'``).
    """
    resolver_help = ('resolve algorithm; default: incremental; possibilities: %s'
                     % ', '.join(RESOLVERS))

    parser = ArgumentParser(
        description='Corneferencer: coreference resolver using neural nets.')
    parser.add_argument('-i', '--input', dest='input', action='store',
                        type=str, default='',
                        help='input file or dir path')
    parser.add_argument('-o', '--output', dest='output', action='store',
                        type=str, default='',
                        help='output path; if not specified writes output to standard output')
    parser.add_argument('-f', '--format', dest='format', action='store',
                        type=str, default='mmax',
                        help='input format; default: mmax')
    parser.add_argument('-r', '--resolver', dest='resolver', action='store',
                        type=str, default='incremental',
                        help=resolver_help)
    return parser.parse_args()


def process_texts(inpath, outpath, informat, resolver):
    """Dispatch ``inpath`` to directory or single-file processing.

    A directory goes to ``process_directory`` and a regular file goes to
    ``process_file``. Any other path is reported via ``eprint`` and
    nothing is processed.
    """
    if os.path.isdir(inpath):
        handler = process_directory
    elif os.path.isfile(inpath):
        handler = process_file
    else:
        eprint("Error: Specified input does not exist!")
        return
    handler(inpath, outpath, informat, resolver)


def process_directory(inpath, outpath, informat, resolver):
    """Run ``process_file`` on every entry of directory ``inpath``.

    Entries are visited in natural sort order (``natsorted``). For an entry
    ``name.ext`` the output path passed on is ``<outpath>/name``, i.e. the
    extension is stripped. Both paths are made absolute first; note that an
    empty ``outpath`` therefore resolves to the current working directory.
    Which entries are actually processed is decided by ``process_file``.
    """
    source_dir = os.path.abspath(inpath)
    target_dir = os.path.abspath(outpath)

    for entry in natsorted(os.listdir(source_dir)):
        # listdir yields bare names, so splitext alone gives the stem.
        stem = os.path.splitext(entry)[0]
        process_file(os.path.join(source_dir, entry),
                     os.path.join(target_dir, stem),
                     informat, resolver)


def process_file(inpath, outpath, informat, resolver):
    """Resolve coreference in a single input file and write the result.

    Only files in MMAX format are handled: ``informat`` must be ``'mmax'``
    and the file name must end in ``.mmax``. Any other file is skipped
    without a message.

    Args:
        inpath: path of the input file.
        outpath: output path passed to ``mmax.write``. The CLI help states
            that when no output path is given, output goes to standard
            output.
        informat: input format name; only ``'mmax'`` is handled here.
        resolver: ``'incremental'`` or ``'entity_based'``. Any other value
            leaves the text unresolved, but it is still written out.
    """
    basename = os.path.basename(inpath)
    if informat != 'mmax' or not basename.endswith('.mmax'):
        return

    # Progress is reported on stderr. Without -o the resolved text itself is
    # written to stdout, and the file name would otherwise corrupt it.
    print(basename, file=sys.stderr)

    text = mmax.read(inpath)
    if resolver == 'incremental':
        resolve.incremental(text)
    elif resolver == 'entity_based':
        resolve.entity_based(text)
    mmax.write(inpath, outpath, text)


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()