# add_domain_terms.py
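"""Add CURLICAT domain-term annotations to CoNLL-U / CoNLL-U Plus corpora.

Reads term definitions from a pipe-separated file, walks the input corpus
tree and writes annotated copies of every .conllu/.conllup file into a
per-publication-year subdirectory of the output directory.
"""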
import argparse
import os


def main():
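    """Read the term definitions and annotate every corpus file under the input directory."""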
    args = parse_arguments()
    dts = read_dts(args.terms)
    add_dts_by_year(args.input, args.output, dts)


def parse_arguments():
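    """Parse the command-line arguments: corpora root, output directory and terms file."""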
    parser = argparse.ArgumentParser(description='Add domain terms to the corpora.')
    required_arguments = parser.add_argument_group('required arguments')
    required_arguments.add_argument('-o', '--output', help='output directory', required=True)
    required_arguments.add_argument('-i', '--input', help='corpora root directory', required=True)
    required_arguments.add_argument('-t', '--terms', help='file with terms definitions', required=True)
    return parser.parse_args()


def read_dts(terms_file_path):
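    """Read domain-term definitions into a dict mapping term string -> list of term IDs.

    Each non-empty line of the terms file is pipe-separated: field 0 holds the
    term ID, field 2 the term string and field 4 a comma-separated list of
    domain codes.  Every domain code is prefixed to the term ID to form the
    full identifier stored in the dictionary.
    """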
    dts = {}
    with open(terms_file_path, encoding='utf-8') as terms_file:
        term_defs = terms_file.readlines()
    for term_def in term_defs:
        term_def = term_def.strip()
        if term_def:
            def_parts = term_def.split('|')
            term_id = def_parts[0]
            term_str = def_parts[2]
            domains = def_parts[4].split(',')
            for dom in domains:
                full_id = f'{dom}{term_id}'
                if term_str in dts:
                    dts[term_str].append(full_id)
                    print(f'DT {term_str} already included in {dts[term_str]}.')
                else:
                    dts[term_str] = [full_id]
    return dts


def add_dts_by_year(root_directory, out_corpora_directory, dts):
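    """Walk the corpora tree and annotate every .conllu/.conllup file.

    Each annotated file is written under a subdirectory of the output
    directory named after the document's publication year.
    """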
    for root, dirs, files in os.walk(root_directory):
        for filename in files:
            if filename.endswith('.conllup') or filename.endswith('.conllu'):
                src = os.path.join(root, filename)
                year = get_year(src)
                year_path = os.path.join(out_corpora_directory, year)
                os.makedirs(year_path, exist_ok=True)
                dst = os.path.join(year_path, filename)
                add_dts(src, dst, dts)


def get_year(filepath):
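    """Return the publication year taken from the '# PublicationDate' metadata, or '0' if absent."""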
    with open(filepath, 'r', encoding='utf-8') as conllup_file:
        for line in conllup_file:
            line = line.strip()
            if is_segment(line):
                continue
            elif is_metadata(line):
                name, value = get_metadata(line)
                if name == 'PublicationDate':
                    return value.split('-')[0]
    return '0'


def is_segment(line):
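    """Return True if the line is a token line (it starts with a digit)."""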
    return bool(line) and line[0].isdigit()


def is_metadata(line):
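    """Return True if the line is a comment/metadata line (it starts with '#')."""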
    return line.startswith('#')


def get_metadata(line):
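    """Split a metadata line into a (name, value) pair at the first '='."""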
    name_value_pair = line.split('=', 1)
    name = name_value_pair[0].lstrip('#').strip()
    value = name_value_pair[1].strip()
    return name, value


def add_dts(src, dst, dts):
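    """Copy one CoNLL-U(-P) file from src to dst, adding domain-term annotations.

    The '# global.columns' header is extended with a CURLICAT:DOMAINTERM column
    and the token lines of every sentence are annotated via add_dts_to_sentence().
    """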
    cleaned_lines = []
    paragraph = []
    sentence = []

    with open(src, 'r', encoding='utf-8') as conllup_file:
        for line in conllup_file:
            line = line.strip()
            if is_segment(line):
                sentence.append(line)
            elif is_metadata(line):
                name, value = get_metadata(line)
                if name == 'global.columns':
                    cleaned_lines.append(f'{line} CURLICAT:DOMAINTERM')
                elif name == 'newpar id':
                    if sentence:
                        add_dts_to_sentence(sentence, dts)
                        sentence.append('')
                        paragraph.extend(sentence)
                    sentence = []
                    if len(paragraph) > 1:
                        cleaned_lines.extend(paragraph)
                    paragraph = [line]
                elif name == 'sent_id':
                    if sentence:
                        add_dts_to_sentence(sentence, dts)
                        sentence.append('')
                        paragraph.extend(sentence)
                    sentence = [line]
                elif name == 'text':
                    sentence.append(line)
                else:
                    cleaned_lines.append(line)

        if sentence:
            add_dts_to_sentence(sentence, dts)
            sentence.append('')
            paragraph.extend(sentence)
        if len(paragraph) > 1:
            cleaned_lines.extend(paragraph)

    # End with exactly one blank line so the joined output terminates in a newline.
    while cleaned_lines and not cleaned_lines[-1].strip():
        cleaned_lines.pop()
    cleaned_lines.append('')

    with open(dst, 'w', encoding='utf-8') as dst_file:
        dst_file.write('\n'.join(cleaned_lines))


def add_dts_to_sentence(sentence, dts):
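    """Annotate the token lines of one sentence in place.

    The first two entries of `sentence` are assumed to be the sent_id and text
    metadata lines; the remaining entries are token lines, scanned with a
    sliding window of (at most) two tokens so that two-token terms are matched
    as well as single-token ones.
    """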
    sent_dts_count = 0
    enum2id = {}
    for i in range(2, len(sentence)):
        tokens = sentence[i:i+2]
        sentence[i:i+2], sent_dts_count = mark_dts(tokens, sent_dts_count, dts, enum2id)


def tokens_to_base(tokens):
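    """Join the lemma column (index 2) of the given token lines into one string,
    leaving out the separating space when the MISC column signals SpaceAfter=No.
    """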
    text = ''
    for tok in tokens:
        columns = tok.split('\t')
        base = columns[2]
        # MISC may contain several pipe-separated attributes, so look for
        # SpaceAfter=No among them instead of comparing the whole column.
        nsa = 'SpaceAfter=No' in columns[9].split('|')
        if nsa:
            text += base
        else:
            text += f'{base} '
    return text.strip()


def mark_dts(tokens, sent_dts_count, dts, enum2id):
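    """Mark domain terms on a window of one or two token lines.

    Matches are written into the 14th column (CURLICAT:DOMAINTERM): the first
    token of a match gets '<n>:<term id>', where n enumerates the terms found
    in the sentence, and the second token of a two-token match gets the bare
    number n as a cross-reference.  A single-token term is only added when no
    term of the same domain (first two characters of the ID) is already
    recorded on that token.  Returns the rewritten token lines and the updated
    enumeration counter.
    """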

    first_tok_base = tokens_to_base([tokens[0]])
    new_first_tok_dtids = dts.get(first_tok_base, [])
    first_tok_columns = tokens[0].split('\t')
    if len(first_tok_columns) < 14:
        first_tok_columns.append('_')

    if len(tokens) == 2:
        both_toks_base = tokens_to_base(tokens)
        both_toks_dtids = dts.get(both_toks_base, [])
        second_tok_columns = tokens[1].split('\t')
        if len(second_tok_columns) < 14:
            second_tok_columns.append('_')
        for dtid in both_toks_dtids:
            sent_dts_count += 1
            enum2id[sent_dts_count] = dtid

            dt_col_val = f'{sent_dts_count}:{dtid}'
            if first_tok_columns[13] == '_':
                first_tok_columns[13] = dt_col_val
            else:
                first_tok_columns[13] += f',{dt_col_val}'

            dt_col_val = str(sent_dts_count)
            if second_tok_columns[13] == '_':
                second_tok_columns[13] = dt_col_val
            else:
                second_tok_columns[13] += f',{dt_col_val}'

    actual_first_tok_domains = get_token_domains(first_tok_columns, enum2id)
    for ftid in new_first_tok_dtids:
        if ftid[:2] not in actual_first_tok_domains:
            sent_dts_count += 1
            enum2id[sent_dts_count] = ftid
            dt_col_val = f'{sent_dts_count}:{ftid}'
            if first_tok_columns[13] == '_':
                first_tok_columns[13] = dt_col_val
            else:
                first_tok_columns[13] += f',{dt_col_val}'

    if len(tokens) == 1:
        return ['\t'.join(first_tok_columns)], sent_dts_count

    return ['\t'.join(first_tok_columns), '\t'.join(second_tok_columns)], sent_dts_count


def get_token_domains(columns, enum2id):
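    """Return the two-character domain codes already recorded in the domain-term
    column of a token, resolving bare enumeration numbers through enum2id.
    """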
    domains = []
    if len(columns) >= 14 and columns[13] != '_':
        for dt in columns[13].split(','):
            if ':' in dt:
                domains.append(dt.split(':')[1][:2])
            else:
                domains.append(enum2id[int(dt)][:2])
    return domains


if __name__ == '__main__':
    main()