ENIAMmstModel.ml 2.18 KB
open Yojson
open Xstd


module MST_Model : sig
  (** A loaded MST parser model: string alphabets mapping feature and label
      names to integer indices, plus the weight vector indexed by feature
      (data-alphabet) index. *)
  type mst_model

  (** A feature vector, represented as the set of feature indices present. *)
  type feature_vector_t

  (** Raised by {!read_model} (and hence {!initialize}) when the model JSON
      does not have the expected structure. *)
  exception MalformedModelJson

  (** [read_model fname] parses the JSON model stored in [fname].
      @raise MalformedModelJson if the document is not an object whose
      ["dataAlphabet"] and ["typeAlphabet"] fields are arrays of strings and
      whose ["parameters"] field is an array of numbers with at least one
      entry per data-alphabet key. I/O and JSON syntax errors raised while
      reading the file propagate unchanged. *)
  val read_model: string -> mst_model

  (** [initialize fname] loads [fname] with {!read_model} and makes it the
      model used by {!add_feature} and {!score_fv}.
      @raise MalformedModelJson as {!read_model}. *)
  val initialize: string -> unit

  (** [add_feature name fv] adds the index of feature [name] to [fv]; returns
      [fv] unchanged when [name] is not in the current model's data
      alphabet. *)
  val add_feature: string -> feature_vector_t -> feature_vector_t

  (** [score_fv fv] sums the current model's weights over the feature
      indices in [fv]. *)
  val score_fv: feature_vector_t -> float

  (** The feature vector containing no features. *)
  val empty_fv: feature_vector_t
end
= struct
  type feature_vector_t = IntSet.t

  type mst_model = {
    typeAlphabet: int StringMap.t;
    dataAlphabet: int StringMap.t;
    parameters: float array}

  exception MalformedModelJson

  (* Model consulted by add_feature/score_fv and replaced by initialize.
     It starts empty, so add_feature ignores every feature until a model
     has been loaded. *)
  let model = ref {typeAlphabet = StringMap.empty;
                   dataAlphabet = StringMap.empty;
                   parameters = [||]}

  let empty_fv = IntSet.empty

  let add_feature str (fv : feature_vector_t) =
    let alphabet = !model.dataAlphabet in
    if StringMap.mem alphabet str then
      IntSet.add fv (StringMap.find alphabet str)
    else
      fv

  let score_fv (fv : feature_vector_t) =
    let parameters = !model.parameters in
    IntSet.fold fv 0.0 (fun score i -> score +. parameters.(i))

  (* Map each key to its position in [keys]. For a duplicated key the later
     position presumably wins, since StringMap.add is expected to replace an
     existing binding (as Map.add does). *)
  let construct_data_alphabet keys =
    let alphabet, _ =
      Array.fold_left
        (fun (alphabet, index) key ->
           (StringMap.add alphabet key index, index + 1))
        (StringMap.empty, 0) keys
    in
    alphabet

  let construct_type_alphabet = construct_data_alphabet

  let read_model fname =
    let data = Basic.from_file fname in
    let open Yojson.Basic.Util in
    (* Convert element by element with the raising [to_*] accessors instead
       of [filter_*]: silently dropping an element would shift every later
       index and misalign the alphabets with [parameters]. [to_number] also
       accepts JSON integers (e.g. a weight serialized as [0]). *)
    let field_array convert field =
      data |> member field |> to_list |> Array.of_list |> Array.map convert
    in
    let data_keys, type_keys, parameters =
      try
        ( field_array to_string "dataAlphabet",
          field_array to_string "typeAlphabet",
          field_array to_number "parameters" )
      with Type_error _ -> raise MalformedModelJson
    in
    (* Every data-alphabet index must address a weight; otherwise score_fv
       would fail later with an out-of-bounds access. *)
    if Array.length data_keys > Array.length parameters then
      raise MalformedModelJson;
    {typeAlphabet = construct_type_alphabet type_keys;
     dataAlphabet = construct_data_alphabet data_keys;
     parameters}

  let initialize fname =
    model := read_model fname
end