(* ENIAMmstDisambiguation.ml 2.46 KB *)
open Xstd
open ENIAM_LCGtypes
open ENIAMmstModel
open ENIAMmstFeatures

(* Initializes the MST dependency model by passing the file name
   "dep.model.json" to MST_Model.initialize.
   NOTE(review): the path is relative, so presumably it is resolved against
   the process's current working directory — confirm in MST_Model. *)
let initialize () =
  let model_path = "dep.model.json" in
  MST_Model.initialize model_path;
  ()

(* Raised by the disambiguation functions of this module when they meet a
   linear term whose constructor they do not handle; carries that term. *)
exception UnsupportedLinearTerm of linear_term
(* NOTE(review): judging by the name, presumably meant to signal a Variant
   with no alternatives — confirm where (or whether) it is raised. *)
exception EmptyVariant

(* Scores the dependency edges from [parent] to every node referenced inside a
   linear term and accumulates them into [scores], keyed by node index.
   - [Dot] contributes nothing;
   - [Ref i] adds the score of the edge [parent] -> [data.tree.(i)];
   - [Tuple] components and every alternative of a [Variant] are visited
     recursively, left to right (a later binding for the same index replaces
     an earlier one).
   @raise UnsupportedLinearTerm on any other constructor. *)
let rec fill_dep_edges_array
    (data: disamb_info) parent (scores: float IntMap.t) =
  let recurse = fill_dep_edges_array data parent in
  function
  | Dot -> scores
  | Ref i ->
    let child = data.tree.(i) in
    IntMap.add scores i (score_edge data parent child)
  | Tuple components -> List.fold_left recurse scores components
  | Variant (_, alternatives) ->
    List.fold_left (fun acc (_, term) -> recurse acc term) scores alternatives
  | term -> raise (UnsupportedLinearTerm term)

(* Greedily disambiguates a linear term using precomputed edge scores
   ([edge_scores] maps a node index to the score of the edge to that node).
   Returns the chosen term together with its score:
   - [Dot] scores 0.0;
   - [Ref i] scores the [edge_scores] entry for [i];
   - [Tuple l] keeps every component and scores their arithmetic mean; an
     empty tuple scores 0.0 (like [Dot]) instead of the NaN that 0.0 /. 0.0
     would give, since NaN would silently lose every comparison below;
   - [Variant (_, l)] is replaced by its best-scoring alternative; on ties the
     earliest alternative wins.
   @raise EmptyVariant if a [Variant] has no alternatives.
   @raise UnsupportedLinearTerm on any other constructor. *)
let rec disambiguate_args edge_scores =
  function
    Dot -> Dot, 0.0
  | Ref i -> Ref i, IntMap.find edge_scores i
  | Tuple [] -> Tuple [], 0.0
  | Tuple l ->
    let (terms, scores) =
      List.map (disambiguate_args edge_scores) l |> List.split in
    let num = List.length scores |> float_of_int in
    Tuple terms, (List.fold_left (+.) 0.0 scores) /. num
  | Variant (_, []) -> raise EmptyVariant
  | Variant (_, (_, first) :: rest) ->
    (* Strict [>] keeps the earlier alternative on ties. *)
    let select_best (best_term, best_score) (_, term) =
      let (new_term, new_score) = disambiguate_args edge_scores term in
      if new_score > best_score then new_term, new_score
      else best_term, best_score in
    List.fold_left select_best (disambiguate_args edge_scores first) rest
  | _ as x -> raise (UnsupportedLinearTerm x)

(* Greedy disambiguation of the arguments of a single node: first every edge
   from [parent] to a node referenced in [parent.args] is scored, then
   [parent.args] is rebuilt by disambiguate_args using those scores. Returns
   [parent] with the new [args]; all other fields are left unchanged. *)
let disambiguate_node (data: disamb_info) parent =
  let args = parent.args in
  let edge_scores = fill_dep_edges_array data parent IntMap.empty args in
  let new_term = fst (disambiguate_args edge_scores args) in
  {parent with args = new_term}


(* Keeps only the [Node] entries of [tree], preserving their relative order.
   Uses List.filter directly rather than List.partition |> fst, which built
   and discarded a list of the rejected entries.
   NOTE(review): dropping entries shifts array positions, so a [Ref i] that
   indexed the input array may no longer point at the same node — confirm
   callers do not rely on those indices. *)
let fix_array (tree: linear_term array) =
  let is_node = function
    | Node _ -> true
    | _ -> false in
  tree |> Array.to_list |> List.filter is_node |> Array.of_list

(* Disambiguates every node of a dependency tree given as an array of [Node]
   terms. Each node's arguments are disambiguated greedily, with edges scored
   against the original (not yet disambiguated) nodes. The result is passed
   through ENIAM_LCGreductions.normalize_variants, and then fix_array drops
   the non-[Node] entries.
   @raise UnsupportedLinearTerm if an entry of [tree] is not a [Node]. *)
let disambiguate_tree (tree: linear_term array) =
  let extract_node term =
    match term with
    | Node node -> node
    | other -> raise (UnsupportedLinearTerm other) in
  let data : disamb_info = {tree = Array.map extract_node tree} in
  let disambiguate_entry term =
    let node = extract_node term in
    Node (disambiguate_node data node) in
  let disambiguated = Array.map disambiguate_entry tree in
  (* let extarray = ref (disambiguated, 0, Array.length disambiguated, Dot) in
     ENIAM_LCGreductions.reshape_dependency_tree extarray *)
  fix_array (ENIAM_LCGreductions.normalize_variants disambiguated)
  (* disambiguated *)