(* ENIAMmstDisambiguation.ml *)
open Xstd
open ENIAM_LCGtypes
open ENIAMmstModel
open ENIAMmstFeatures


(* Initialises the MST dependency model by handing the model file name
   "dep.model.json" to [MST_Model.initialize].  The name is passed as-is, so
   how it is resolved is up to [MST_Model] — NOTE(review): presumably relative
   to the working directory; confirm. *)
let initialize () =
  let model_file = "dep.model.json" in
  ignore (MST_Model.initialize model_file)

(* Raised when a [linear_term] constructor that the disambiguation code does
   not handle is encountered; carries the offending term. *)
exception UnsupportedLinearTerm of linear_term
(* Signals a [Variant] that has no alternatives to choose from. *)
exception EmptyVariant

(* Memoised result of disambiguating the arguments of one tree node. *)
type disamb_output = {
  arg: linear_term;  (* the node's arguments with every Variant reduced to its chosen alternative *)
  score: float       (* score of [arg] *)
}

(* Walks the argument structure [term] of node [parent] and, for every [Ref i]
   it contains, binds key [i] in [scores] to
   [score_edge data parent data.tree.(i)].  [Tuple] elements and all [Variant]
   alternatives are visited left to right; [Dot] contributes nothing.
   Despite the name, the result is an [IntMap], not an array.

   @raise UnsupportedLinearTerm for any other constructor *)
let rec fill_dep_edges_array
    (data: disamb_info) parent (scores: float IntMap.t) term =
  let visit acc t = fill_dep_edges_array data parent acc t in
  match term with
  | Dot -> scores
  | Ref i ->
    let child = data.tree.(i) in
    IntMap.add scores i (score_edge data parent child)
  | Tuple elems -> List.fold_left visit scores elems
  | Variant (_, alternatives) ->
    List.fold_left (fun acc (_, t) -> visit acc t) scores alternatives
  | other -> raise (UnsupportedLinearTerm other)

(* Disambiguates an argument structure.  Returns the term with every
   [Variant] replaced by its best alternative, paired with the term's score.

   Scoring, as implemented here:
   - [Dot] scores 0.0;
   - [Ref i] scores its edge score from [edge_scores] plus the score of the
     recursively disambiguated node [i];
   - [Tuple l] scores the arithmetic mean of its elements, or 0.0 when empty;
   - [Variant (_, l)] evaluates every alternative and keeps the one with the
     strictly greatest score; on ties the earliest alternative wins.

   @raise EmptyVariant if a [Variant] has no alternatives
   @raise UnsupportedLinearTerm on any other constructor *)
let rec disambiguate_args edge_scores data cache =
  function
    Dot -> Dot, 0.0
  | Ref i ->
    Ref i, IntMap.find edge_scores i +. (disambiguate_node data cache i).score
  | Tuple [] ->
    (* Averaging over zero elements would give 0.0 /. 0.0 = nan, which then
       poisons every sum and comparison it takes part in. *)
    Tuple [], 0.0
  | Tuple l ->
    let (terms, scores) =
      List.map (disambiguate_args edge_scores data cache) l |> List.split in
    let num = List.length scores |> float_of_int in
    Tuple terms, (List.fold_left (+.) 0.0 scores) /. num
  | Variant (_, l) ->
    (* All alternatives are evaluated (not only the winner), so every node
       referenced from any alternative ends up in [cache]. *)
    let new_terms_scores =
      List.map (fun (_, t) -> disambiguate_args edge_scores data cache t) l in
    let select_best (term, score) (new_term, new_score) =
      if new_score > score then new_term, new_score else term, score in
    (match new_terms_scores with
       [] -> raise EmptyVariant
     | first :: rest -> List.fold_left select_best first rest)
  | _ as x -> raise (UnsupportedLinearTerm x)
(* Returns the disambiguated arguments of node [idx] together with their score.
   The result is memoised in [cache.(idx)].  It is computed from the node's own
   [args], using edge scores collected by [fill_dep_edges_array].
   NOTE(review): the cache entry is written only after the recursion returns,
   so a [Ref] cycle in the tree would recurse without terminating. *)
and disambiguate_node (data: disamb_info) (cache: disamb_output option array) idx =
  match cache.(idx) with
    Some cached -> cached
  | None ->
    let parent = data.tree.(idx) in
    let edge_scores = fill_dep_edges_array
        data parent IntMap.empty (parent.args) in
    let (new_term, new_score) = disambiguate_args edge_scores data cache (parent.args) in
    let res = {arg = new_term; score = new_score} in
    cache.(idx) <- Some res;
    res


(* Returns a fresh array holding only the [Node _] entries of [tree], in their
   original order; every other kind of term is dropped. *)
let fix_array (tree: linear_term array) =
  let keep_node term acc =
    match term with
    | Node _ -> term :: acc
    | _ -> acc in
  Array.fold_right keep_node tree [] |> Array.of_list

(* Indexes [paths] by their first component: each [(id, l, r)] becomes a
   binding [id -> (l, r)].  Entries are inserted left to right, so for a
   repeated [id] the last occurrence is kept, assuming [IntMap.add] replaces
   existing bindings (NOTE(review): confirm in Xstd). *)
let convert_paths (paths: (int * int * int) list) =
  List.fold_left
    (fun acc (id, l, r) -> IntMap.add acc id (l, r))
    IntMap.empty
    paths

(* Disambiguates a whole dependency tree.

   Every entry of [tree] must be a [Node _].  A [Ref i] refers to [tree.(i)],
   and index 0 is the root.  Each node reachable from the root has the
   Variants in its arguments replaced by their highest-scoring alternative
   (see [disambiguate_args]).  The resulting array is passed through
   [ENIAM_LCGreductions.normalize_variants] and filtered down to its [Node _]
   entries by [fix_array].  An empty [tree] yields an empty array.

   @raise UnsupportedLinearTerm if an entry of [tree] is not a [Node]
   @raise Failure if a node is not reachable from the root, because its
   disambiguated arguments are then missing from the cache *)
let disambiguate_tree
    (paths: (int * int * int) list)
    (tokens: ENIAMtokenizerTypes.token_env ExtArray.t)
    (tree: linear_term array) =
  if Array.length tree = 0 then [||] else
    let extract_node = function
        Node node -> node
      | x -> raise (UnsupportedLinearTerm x) in
    let data : disamb_info = {tree = Array.map extract_node tree;
                              tokens = tokens;
                              paths = convert_paths paths} in
    let cache : disamb_output option array =
      Array.make (Array.length tree) None in
    (* Only the side effect matters here: disambiguating the root fills the
       cache entry of every node reachable from it. *)
    let (_ : disamb_output) = disambiguate_node data cache 0 in
    let replace_args idx node =
      match cache.(idx) with
        Some result -> Node {node with args = result.arg}
      | None -> failwith "something went wrong (ENIAMmstDisambiguation)" in
    Array.mapi replace_args data.tree
    |> ENIAM_LCGreductions.normalize_variants
    |> fix_array