(* ENIAM_EdgeScore.ml 7.16 KB *)
open Xstd
open ENIAM_LCGtypes
open Yojson

(* Loader for an MST model serialized as JSON.

   The JSON document is expected to be an object; the keys typeAlphabet,
   dataAlphabet and parameters are read and every other key is ignored. *)
module MST_Model : sig
  (* typeAlphabet / dataAlphabet map each string to its position in the
     corresponding JSON list; parameters holds one weight per position. *)
  type mst_model = {
    typeAlphabet: int StringMap.t;
    dataAlphabet: int StringMap.t;
    parameters: float array}
  (* [read_model fname] parses the file [fname]. Returns [None] when the file
     cannot be read or parsed, when the top-level value is not an object, or
     when one of the recognised keys has an unexpected shape. A key that is
     absent leaves the corresponding field empty. *)
  val read_model: string -> mst_model option
end
= struct
  type mst_model = {
    typeAlphabet: int StringMap.t;
    dataAlphabet: int StringMap.t;
    parameters: float array}
  exception MalformedModelJson
  (* TODO: use more efficient data structures and replace JSON with some
     binary serialization format *)

  (* Turns a JSON list of strings into a map from each string to its index in
     the list. If a string occurs more than once, the last index wins.
     Raises MalformedModelJson for any other JSON shape. *)
  let construct_data_alphabet = function
    | `List entries ->
      let add_entry (map, next_index) = function
        | `String s -> (StringMap.add map s next_index, next_index + 1)
        | _ -> raise MalformedModelJson in
      fst (List.fold_left add_entry (StringMap.empty, 0) entries)
    | _ -> raise MalformedModelJson

  (* The type alphabet is serialized exactly like the data alphabet. *)
  let construct_type_alphabet = construct_data_alphabet

  (* Turns a JSON list of numbers into a float array.
     Integer literals are accepted as well: Yojson.Basic represents a JSON
     number written without a fraction or exponent as `Int, and rejecting it
     would discard the whole model because of a single value such as 0.
     The conversion goes through arrays rather than List.map so that a very
     long parameter list cannot exhaust the stack.
     Raises MalformedModelJson for any other JSON shape. *)
  let construct_parameters_array = function
    | `List values ->
      let to_float = function
        | `Float f -> f
        | `Int i -> float_of_int i
        | _ -> raise MalformedModelJson in
      Array.map to_float (Array.of_list values)
    | _ -> raise MalformedModelJson

  let read_model fname =
    try
      match Basic.from_file fname with
      | `Assoc fields ->
        (* Later occurrences of a key override earlier ones. *)
        let update (ta, da, p) = function
          | ("typeAlphabet", data) -> (construct_type_alphabet data, da, p)
          | ("dataAlphabet", data) -> (ta, construct_data_alphabet data, p)
          | ("parameters", data) -> (ta, da, construct_parameters_array data)
          | _ -> (ta, da, p) in
        let (ta, da, p) =
          List.fold_left update (StringMap.empty, StringMap.empty, [||]) fields in
        Some {typeAlphabet = ta; dataAlphabet = da; parameters = p}
      | _ -> None
    with
      (* Deliberately best-effort: any failure while reading or decoding the
         file is reported to the caller as None. *)
      _ -> None
end
open MST_Model

(* Raised when a term of a shape the functions below do not handle is
   encountered; carries the offending term. *)
exception UnsupportedLinearTerm of linear_term
(* Raised by disambiguate_args for a Variant that has no alternatives. *)
exception EmptyVariant

(* Adds the index of feature [str] to [fv] when [str] is present in the data
   alphabet of [model]; otherwise returns [fv] unchanged. *)
let add_feature model str (fv: IntSet.t) =
  let alphabet = model.dataAlphabet in
  match StringMap.mem alphabet str with
  | false -> fv
  | true -> IntSet.add fv (StringMap.find alphabet str)

(* Sums the model parameters stored at the indices contained in [fv]. *)
let score_fv model (fv:IntSet.t) =
  let weights = model.parameters in
  let add_weight total index = total +. weights.(index) in
  IntSet.fold fv 0.0 add_weight

(* Threads [fv] through every function of [features], from the first element
   of the list to the last, and returns the final value. *)
let apply_features features fv =
  let rec go acc = function
    | [] -> acc
    | f :: rest -> go (f acc) rest in
  go fv features


(* Placeholder: no linear features are produced yet, so the feature vector is
   returned untouched and every other argument is ignored. *)
let add_linear_features _model _f_type (_obs: string array) _first _second _distStr fv =
  fv

(* Adds the two-observation features built from the four item strings.
   Every feature name starts with [prefix]; each name is looked up twice,
   once as is and once with a star and [distStr] appended. Names missing from
   the data alphabet of [model] are skipped by add_feature. *)
let add_two_obs_features model prefix item1F1 item1F2 item2F1 item2F2 distStr fv =
  (* [feature tag parts]: prefix, then tag, then parts joined by one space. *)
  let feature tag parts = prefix ^ tag ^ String.concat " " parts in
  let names = [
    feature "2FF1=" [item1F1];
    feature "2FF1=" [item1F1; item1F2];
    feature "2FF1=" [item1F1; item1F2; item2F2];
    feature "2FF1=" [item1F1; item1F2; item2F2; item2F1];
    feature "2FF2=" [item1F1; item2F1];
    feature "2FF3=" [item1F1; item2F2];
    feature "2FF4=" [item1F2; item2F1; item2F2];
    feature "2FF5=" [item1F2; item2F2];
    feature "2FF6=" [item2F1; item2F2];
    feature "2FF7=" [item1F2];
    feature "2FF8=" [item2F1];
    feature "2FF9=" [item2F2];
  ] in
  (* For each name: plain variant first, then the distance-tagged variant. *)
  List.fold_left
    (fun acc name ->
       acc
       |> add_feature model name
       |> add_feature model (name ^ "*" ^ distStr))
    fv names

(*let add_core_features model tree attR small large (fv: IntSet.t) =
  let dist = 0 (*match abs (first.id - second.id) with
      x when x > 10 -> 10
    | x when x > 5 -> 5
                 | x -> x - 1 *) in
  let distStr = Printf.sprintf "&%s&%d" (if attR then "RA" else "LA") dist in
  let head_index =  if attR then small else large in
  let child_index = if attR then large else small in
  let nodes = Array.map (fun (Node node) -> node) tree in
  apply_features
    [add_two_obs_features model "HC"
       nodes.(head_index).orth nodes.(head_index).pos
       nodes.(child_index).orth nodes.(child_index).pos (*distStr*) "";
    ]
    fv *)

(* Scores the edge from [parent] to [child]: collects the HC two-observation
   features over the orth and pos fields of both nodes (with an empty distance
   string) and sums the corresponding model parameters. *)
let score_edge model (parent: node) (child: node) =
  IntSet.empty
  |> add_two_obs_features model "HC"
       parent.orth parent.pos child.orth child.pos ""
  |> score_fv model

(* Walks an argument term of [parent] and records, for every [Ref i] reached,
   the score of the edge from [parent] to the node stored at [tree.(i)].
   [Tuple] and [Variant] are traversed element by element; [Dot] adds nothing.
   Raises UnsupportedLinearTerm for any other term, and for a [Ref] whose
   target in [tree] is not a [Node]. *)
let rec fill_dep_edges_array
    model (tree: linear_term array) parent (scores: float IntMap.t) term =
  let visit = fill_dep_edges_array model tree parent in
  match term with
  | Dot -> scores
  | Ref i ->
    begin match tree.(i) with
      | Node child -> IntMap.add scores i (score_edge model parent child)
      | other -> raise (UnsupportedLinearTerm other)
    end
  | Tuple items -> List.fold_left visit scores items
  | Variant (_, alternatives) ->
    List.fold_left (fun acc (_, alt) -> visit acc alt) scores alternatives
  | other -> raise (UnsupportedLinearTerm other)

(* Disambiguates one argument term of [parent].

   Returns a triple (taken, term, score):
   - [Dot] is returned as is with score 0.0;
   - [Ref i] adds [i] to [taken] and scores as [edge_scores] says for [i];
   - [Variant] is replaced by its best-scoring alternative (ties keep the
     earlier alternative);
   - [Tuple] has its elements disambiguated by disambiguate_process_tuple and
     scores as the mean of the element scores; an empty tuple scores 0.0.

   Raises EmptyVariant for a [Variant] without alternatives,
   UnsupportedLinearTerm for any other term or for a [Ref] whose target in
   [tree] is not a [Node], and Not_found-style lookup failures from
   [IntMap.find] if [edge_scores] has no entry for a referenced index.

   NOTE(review): in the code visible here [taken] is only ever extended, never
   consulted, so it does not yet influence which alternative is chosen. *)
let rec disambiguate_args
    tree (edge_scores: float IntMap.t) (taken: IntSet.t) (parent: node) =
  let da = disambiguate_args tree edge_scores taken parent in
  (* Keeps the higher-scoring candidate; on a tie the first one wins. *)
  let find_best (a1,b1,s1) (a2,b2,s2) =
    if s1 >= s2 then (a1,b1,s1) else (a2,b2,s2) in
  function
    Dot -> (taken, Dot, 0.0)
  | Ref i -> (match tree.(i) with
        Node _ -> (IntSet.add taken i, Ref i, IntMap.find edge_scores i)
      | x -> raise (UnsupportedLinearTerm x))
  | Variant (_, h::t) -> List.fold_left
                           find_best (da (snd h))
                           (List.map da (List.map snd t))
  | Variant (_, []) -> raise EmptyVariant
  | Tuple l ->
    (* Pair every element with its position so the result can be rebuilt in
       the original order once the elements have been processed. *)
    let (_,to_do) = List.fold_left (fun (i, li) term -> (i+1, (i,term)::li)) (0, []) l in
    let (new_taken, output) = disambiguate_process_tuple
        tree edge_scores taken parent IntMap.empty to_do in
    let count = IntMap.size output in
    (* An empty tuple would otherwise yield 0.0 /. 0.0 = nan, and a nan score
       makes every later [>=] comparison in find_best false. *)
    let score =
      if count = 0 then 0.0
      else IntMap.fold output 0.0 (fun a _ (_,f) -> a +. f) /. float count in
    let res_arr: linear_term array = Array.make (List.length l) Dot in
    let out_tuple =
      IntMap.iter output (fun index (term, _) -> res_arr.(index) <- term);
      Array.to_list res_arr in
    (new_taken, Tuple out_tuple, score)
  | x -> raise (UnsupportedLinearTerm x)
(* Greedy processing of the (index, term) elements of a tuple.

   At each step every remaining element is disambiguated against the current
   [taken] set, the best-scoring one is stored in [cleared] under its index
   together with its score, its resulting taken set becomes the current one,
   and the element is removed from the work list. Returns the final taken set
   and the completed [cleared] map. *)
and disambiguate_process_tuple
    tree edge_scores taken parent (cleared: (linear_term * float) IntMap.t) =
  (* Compares the best candidate so far with element [id2]; elements already
     present in [cleared] are skipped. *)
  let find_best (id1,a1,b1,s1) (id2, term) =
  if IntMap.mem cleared id2 then (id1,a1,b1,s1) else
    let (a2,b2,s2) = disambiguate_args tree edge_scores taken parent term in
    if s1 >= s2 then (id1,a1,b1,s1) else (id2,a2,b2,s2) in
  function
    [] -> (taken, cleared)
  | (index, term)::t as l ->
    let (a,b,s) = disambiguate_args tree edge_scores taken parent term in
    let (id_best, taken_best, tree_best, s_best) =
      List.fold_left find_best (index,a,b,s) t in
    disambiguate_process_tuple
      tree edge_scores taken_best parent
      (IntMap.add cleared id_best (tree_best, s_best))
      (List.remove_assoc id_best l)


(* Disambiguation of the arguments of a single node with a greedy algorithm.
   Scores every edge reachable from the args of [parent], disambiguates those
   args, and returns the updated taken set, the node rebuilt with the chosen
   args, and the score of the choice. *)
(* TODO: the same algorithm for the whole array *)
let disambiguate_node model tree taken parent =
  let edge_scores =
    fill_dep_edges_array model tree parent IntMap.empty parent.args in
  match disambiguate_args tree edge_scores taken parent parent.args with
  | (taken_after, chosen_args, score) ->
    (taken_after, Node {parent with args = chosen_args}, score)
(* TODO: recursive disambiguation, as in the tuple case *)
(* Disambiguates the arguments of every node of [tree], in index order, and
   returns the result as a fresh array; [tree] itself is not modified. The set
   of taken indices produced for one node is passed on to the next.
   Raises UnsupportedLinearTerm if an element of [tree] is not a [Node]
   (previously this surfaced as an anonymous Match_failure). *)
let disambiguate_tree model tree =
  let result = Array.copy tree in
  let taken = ref IntSet.empty in
  Array.iteri
    (fun i term ->
       match term with
       | Node parent ->
         let (new_taken, new_term, _) =
           disambiguate_node model tree !taken parent in
         result.(i) <- new_term;
         taken := new_taken
       | x -> raise (UnsupportedLinearTerm x))
    tree;
  result