1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
(* src/packages/stats/compare.ml *)
open Ast
(** compare(models) — aligns and compares multiple models in a wide DataFrame. *)
(*
--# Compare Models
--#
--# Align multiple model coefficient tables into a single wide DataFrame for comparison.
--#
--# @name compare
--# @param ... :: Variadic Models or a List of models to compare.
--# @return :: DataFrame A wide DataFrame with aligned terms and suffixed columns.
--# @example
--# m1 = lm(mpg ~ wt, data = mtcars)
--# m2 = lm(mpg ~ wt + hp, data = mtcars)
--# compare(m1, m2)
--# @family stats
--# @export
*)
let register env =
  (* Registers the variadic [compare] builtin in [env].

     Arguments are either several model Dicts or a single List of them. Each
     model must provide a [_tidy_df] DataFrame with a [term] string column and
     [estimate], [std_error], [statistic] and [p_value] float columns. The
     result has one row per distinct non-null term, in order of first
     appearance across the models, plus four columns per model suffixed with
     the model's [name] entry (or its 1-based position when [name] is not a
     String). Terms a model lacks are left null in that model's columns.
     Missing tidy tables and malformed columns are reported as type errors. *)

  (* Unwraps the Dict payloads in [values]. The first VError is returned as
     [Error] so it propagates unchanged; other non-Dict values are skipped.
     NOTE(review): silently skipping may hide user mistakes — confirm this is
     intended. *)
  let collect_models values =
    let rec go acc = function
      | [] -> Ok (List.rev acc)
      | VDict pairs :: rest -> go (pairs :: acc) rest
      | VError e :: _ -> Error (VError e)
      | _ :: rest -> go acc rest
    in
    go [] values
  in

  (* Pairs each model with a unique column suffix and its tidy table.
     Raises [Failure] if a model has no [_tidy_df] DataFrame. The suffix is
     the model's [name] String or its 1-based position; a suffix already
     taken by an earlier model gets "_2", "_3", ... appended so the result
     never contains duplicate column names. *)
  let model_infos models =
    let taken = Hashtbl.create 8 in
    let rec disambiguate base k =
      let candidate = Printf.sprintf "%s_%d" base k in
      if Hashtbl.mem taken candidate then disambiguate base (k + 1)
      else candidate
    in
    (* List.mapi applies its function left to right, so the first model
       lacking a tidy table is the one reported. *)
    List.mapi
      (fun i m ->
        let name =
          match List.assoc_opt "name" m with
          | Some (VString s) -> s
          | _ -> string_of_int (i + 1)
        in
        let tidy =
          match List.assoc_opt "_tidy_df" m with
          | Some (VDataFrame df) -> df
          | _ ->
            failwith
              (Printf.sprintf "Model %s has no tidy coefficient table." name)
        in
        let suffix =
          if Hashtbl.mem taken name then disambiguate name 2 else name
        in
        Hashtbl.replace taken suffix ();
        (suffix, tidy))
      models
  in

  (* Distinct non-null terms across all models, in order of first appearance. *)
  let union_of_terms infos =
    let seen = Hashtbl.create 16 in
    let ordered = ref [] in
    List.iter
      (fun (_, tidy) ->
        Arrow_table.get_string_column tidy.arrow_table "term"
        |> Array.iter (function
             | Some t when not (Hashtbl.mem seen t) ->
               Hashtbl.add seen t ();
               ordered := t :: !ordered
             | Some _ | None -> ()))
      infos;
    Array.of_list (List.rev !ordered)
  in

  (* Builds one model's estimate/std_error/statistic/p_value columns, aligned
     row-by-row to [terms]; a term the model lacks yields a null cell. If a
     term appears more than once in the model, its last row wins. Raises
     [Failure] when a statistic column is shorter than the model's [term]
     column, rather than letting an out-of-bounds array read escape. *)
  let model_columns terms (suffix, tidy) =
    let table = tidy.arrow_table in
    let model_terms = Arrow_table.get_string_column table "term" in
    let row_of_term = Hashtbl.create (Array.length model_terms) in
    Array.iteri
      (fun row -> function
        | Some t -> Hashtbl.replace row_of_term t row
        | None -> ())
      model_terms;
    let aligned stat =
      let values = Arrow_table.get_float_column table stat in
      if Array.length values < Array.length model_terms then
        failwith
          (Printf.sprintf
             "Model %s: column `%s` has fewer rows than its `term` column."
             suffix stat);
      Array.map
        (fun term ->
          match Hashtbl.find_opt row_of_term term with
          | Some row -> values.(row)
          | None -> None)
        terms
    in
    List.map
      (fun stat -> (stat ^ "_" ^ suffix, Arrow_table.FloatColumn (aligned stat)))
      [ "estimate"; "std_error"; "statistic"; "p_value" ]
  in

  Env.add "compare"
    (make_builtin ~name:"compare" ~variadic:true 0 (fun args _env ->
         (* A lone List argument is treated as the collection of models. *)
         let candidates =
           match args with
           | [ VList items ] -> List.map snd items
           | _ -> args
         in
         match collect_models candidates with
         | Error e -> e
         | Ok [] ->
           Error.type_error "Function `compare` expects one or more model Dicts."
         | Ok models -> (
           try
             let infos = model_infos models in
             let terms = union_of_terms infos in
             let term_col =
               ("term", Arrow_table.StringColumn (Array.map (fun t -> Some t) terms))
             in
             let stat_cols = List.concat (List.map (model_columns terms) infos) in
             let table =
               Arrow_table.create (term_col :: stat_cols) (Array.length terms)
             in
             VDataFrame { arrow_table = table; group_keys = [] }
           with Failure msg -> Error.type_error msg)))
    env