(* ENIAMmstDisambiguation.ml *)
open Xstd
open ENIAM_LCGtypes
open ENIAMmstModel
open ENIAMmstFeatures

(* Initialises [MST_Model] with the model file "dep.model.json" and
   returns unit. *)
let initialize () =
  let model_file = "dep.model.json" in
  MST_Model.initialize model_file;
  ()

(* Raised when a linear term has a constructor this module does not handle;
   carries the offending term. *)
exception UnsupportedLinearTerm of linear_term
(* Presumably meant to signal a [Variant] with no alternatives
   — TODO confirm where this is raised. *)
exception EmptyVariant

(* Walks an argument term and records, for every [Ref i] reachable through
   [Tuple]s and [Variant] alternatives, the score of the edge from [parent]
   to node [i] of [data.tree] (as computed by [score_edge]).
   Returns [scores] extended with those bindings; [Dot] adds nothing.
   @raise UnsupportedLinearTerm on any other constructor. *)
let rec fill_dep_edges_array
    (data: disamb_info) parent (scores: float IntMap.t) term =
  match term with
  | Dot -> scores
  | Ref i ->
    let child = data.tree.(i) in
    IntMap.add scores i (score_edge data parent child)
  | Tuple components ->
    List.fold_left
      (fun acc component -> fill_dep_edges_array data parent acc component)
      scores components
  | Variant (_, alternatives) ->
    (* alternative labels are irrelevant for scoring; only the terms matter *)
    List.fold_left
      (fun acc (_label, alt) -> fill_dep_edges_array data parent acc alt)
      scores alternatives
  | unsupported -> raise (UnsupportedLinearTerm unsupported)

(* Greedily resolves every [Variant] inside an argument term and returns the
   resolved term together with its score.
   - [Dot] scores 0.0.
   - [Ref i] scores the value bound to [i] in [edge_scores].
   - [Tuple] scores the arithmetic mean of its components' scores; an empty
     tuple scores 0.0.
   - [Variant] is replaced by its highest-scoring alternative (ties go to the
     later alternative); the variant and alternative labels are dropped.
   @raise EmptyVariant on a [Variant] with no alternatives.
   @raise UnsupportedLinearTerm on any other constructor.
   NOTE(review): a [Ref] missing from [edge_scores] makes [IntMap.find] raise
   (presumably [Not_found]). *)
let rec disambiguate_args edge_scores =
  function
    Dot -> Dot, 0.0
  | Ref i -> Ref i, IntMap.find edge_scores i
  | Tuple [] ->
    (* Averaging zero scores would divide by zero and give nan, which then
       loses every [>=] comparison in enclosing variants; score it like [Dot]. *)
    Tuple [], 0.0
  | Tuple l ->
    let (terms, scores) =
      List.map (disambiguate_args edge_scores) l |> List.split in
    let num = List.length scores |> float_of_int in
    Tuple terms, (List.fold_left (+.) 0.0 scores) /. num
  | Variant (_, l) ->
    let select_best (term, score) (new_term, new_score) =
      if new_score >= score then new_term, new_score else term, score in
    (match List.map (fun (_, t) -> disambiguate_args edge_scores t) l with
     | [] -> raise EmptyVariant
     | first :: rest -> List.fold_left select_best first rest)
  | x -> raise (UnsupportedLinearTerm x)

(* Greedy disambiguation of the arguments of a single node: every [Ref]
   reachable in [parent.args] is scored with [fill_dep_edges_array], then
   [disambiguate_args] keeps the best-scoring alternative of each variant.
   Returns [parent] with [args] replaced; the aggregate score is discarded. *)
let disambiguate_node (data: disamb_info) parent =
  let args = parent.args in
  let scores = fill_dep_edges_array data parent IntMap.empty args in
  let best_args, _score = disambiguate_args scores args in
  { parent with args = best_args }

(* Disambiguates every node of a dependency tree independently, scoring
   edges against the original (un-disambiguated) nodes, and returns the result
   of [ENIAM_LCGreductions.normalize_variants] on the rewritten array.
   @raise UnsupportedLinearTerm if any element of [tree] is not a [Node]. *)
let disambiguate_tree (tree: linear_term array) =
  let as_node = function
    | Node node -> node
    | other -> raise (UnsupportedLinearTerm other) in
  let data : disamb_info = { tree = Array.map as_node tree } in
  let disambiguated =
    Array.map (fun term -> Node (disambiguate_node data (as_node term))) tree in
  (* Previously tried alternative, kept for reference:
     let extarray = ref (disambiguated, 0, Array.length disambiguated, Dot) in
     ENIAM_LCGreductions.reshape_dependency_tree extarray *)
  ENIAM_LCGreductions.normalize_variants disambiguated