1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
(* src/packages/stats/augment.ml *)
open Ast
(** augment(data, model) — augments a dataset with model-based predictions and residuals. *)
(*
--# Augment Data with Model Calculations
--#
--# Appends model predictions (`fitted`), residuals (`resid`) and, when available, standardized residuals (`std_resid`) to a dataset.
--#
--# @name augment
--# @param data :: DataFrame The dataset to augment.
--# @param model :: Model The model object.
--# @return :: DataFrame The original DataFrame with `fitted` and `resid` columns appended, plus `std_resid` when the residuals are numeric.
--# @example
--# aug = augment(mtcars, model)
--# @family stats
--# @export
*)
(* Bind the [augment] builtin in [env] via [Env.add]. *)
let register env =
  (* Resolve the [data] and [model] arguments of an [augment] call.
     A [data=] keyword wins; otherwise [data] consumes the first positional
     argument. [model] comes from [model=] or else from the first positional
     argument that [data] did not consume. Positions are tracked explicitly
     rather than by comparing runtime values: polymorphic equality on values
     holding closures (e.g. builtins) raises [Invalid_argument]. *)
  let resolve_args args =
    let named =
      List.filter_map
        (fun (n, v) -> match n with Some name -> Some (name, v) | None -> None)
        args
    in
    let positional =
      List.filter_map
        (fun (n, v) -> match n with None -> Some v | Some _ -> None)
        args
    in
    let data_v, remaining =
      match List.assoc_opt "data" named with
      | Some v -> (Some v, positional)
      | None -> (
          match positional with
          | v :: rest -> (Some v, rest)
          | [] -> (None, []))
    in
    let model_v =
      match List.assoc_opt "model" named with
      | Some v -> Some v
      | None -> (match remaining with v :: _ -> Some v | [] -> None)
    in
    (data_v, model_v)
  in
  (* The float stored under [_model_data.sigma] in the model dict, if it is
     usable as a divisor (finite and strictly positive). *)
  let model_sigma model =
    match List.assoc_opt "_model_data" model with
    | Some (VDict d) -> (
        match List.assoc_opt "sigma" d with
        | Some (VFloat f) when Float.is_finite f && f > 0.0 -> Some f
        | _ -> None)
    | _ -> None
  in
  (* Build [df] extended with the [fitted] and [resid] columns of [res_df]
     and, when [resid] is a float column, [std_resid] = resid / sigma (no
     leverage adjustment). Columns of [df] with one of those names are
     replaced rather than duplicated. Assumes [res_df] has as many rows as
     [df]; the caller checks this. *)
  let augment_frame df model res_df =
    let nrows = Arrow_table.num_rows df.arrow_table in
    (* A missing column becomes an all-NA column of the frame's length,
       keeping every column the same length as the table. *)
    let column_or_na = function
      | Some c -> c
      | None -> Arrow_table.NAColumn nrows
    in
    let fitted = Arrow_table.get_column res_df.arrow_table "fitted" in
    let resid = Arrow_table.get_column res_df.arrow_table "resid" in
    let std_resid =
      match resid with
      | Some (Arrow_table.FloatColumn data) ->
        (* Without a usable sigma the standardized value is unknown: emit NA
           instead of relabelling raw residuals as standardized ones. *)
        let scale =
          match model_sigma model with
          | Some sigma -> Option.map (fun e -> e /. sigma)
          | None -> fun _ -> None
        in
        [ ("std_resid", Arrow_table.FloatColumn (Array.map scale data)) ]
      | _ -> []
    in
    let new_cols =
      ("fitted", column_or_na fitted) :: ("resid", column_or_na resid) :: std_resid
    in
    let is_new name = List.exists (fun (n, _) -> String.equal n name) new_cols in
    let orig_cols =
      Arrow_table.column_names df.arrow_table
      |> List.filter (fun name -> not (is_new name))
      |> List.map (fun name ->
             (name, column_or_na (Arrow_table.get_column df.arrow_table name)))
    in
    VDataFrame
      { arrow_table = Arrow_table.create (orig_cols @ new_cols) nrows;
        group_keys = df.group_keys }
  in
  Env.add "augment"
    (make_builtin_named ~name:"augment" ~variadic:true 0 (fun args call_env ->
         match resolve_args args with
         | (Some (VDataFrame df), Some (VDict model)) ->
           let nrows = Arrow_table.num_rows df.arrow_table in
           (* Predictions and residuals come from the `residuals` builtin
              visible in the calling environment. *)
           let res_v =
             match Env.find_opt "residuals" call_env with
             | Some (VBuiltin b) ->
               b.b_func [ (None, VDataFrame df); (None, VDict model) ] (ref call_env)
             | _ -> Error.type_error "Internal error: `residuals` not found."
           in
           (match res_v with
            | VDataFrame res_df ->
              (* A row-count mismatch (e.g. rows dropped upstream) would
                 produce a malformed table, so report it instead. *)
              if Arrow_table.num_rows res_df.arrow_table <> nrows then
                Error.type_error
                  "Function `residuals` returned a different number of rows than the input DataFrame."
              else augment_frame df model res_df
            | VError e -> VError e
            | _ -> Error.type_error "Function `residuals` did not return a DataFrame.")
         | (Some (VError _ as e), _) | (_, Some (VError _ as e)) -> e
         | _ -> Error.type_error "Function `augment` expects (DataFrame, Model)."))
    env