1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
(* src/packages/stats/augment.ml *)
open Ast

(** augment(data, model) — augments a dataset with model-based predictions and residuals. *)
(*
--# Augment Data with Model Calculations
--#
--# Appends model predictions (`fitted`), residuals (`resid`), and, when the residuals are numeric, standardized residuals (`std_resid`) to a dataset.
--#
--# @name augment
--# @param data :: DataFrame The dataset to augment.
--# @param model :: Model The model object.
--# @return :: DataFrame The original DataFrame with `fitted` and `resid` columns appended, plus `std_resid` when the residuals are numeric.
--# @example
--#   aug = augment(mtcars, model)
--# @family stats
--# @export
*)
(** [register env] returns [env] extended with the [augment] builtin.

    The builtin accepts its two arguments either by name ([data], [model]) or
    positionally. It delegates to the [residuals] builtin looked up in the
    calling environment, then appends the resulting [fitted] and [resid]
    columns (plus [std_resid] when [resid] is a float column) to the columns
    of the input DataFrame.

    Result of the builtin:
    - a [VDataFrame] carrying the original columns followed by the new ones,
      with the input's [group_keys] preserved;
    - the [VError] itself when either argument, or the [residuals] result, is
      a [VError];
    - a type error for any other argument shapes. *)
let register env =
  Env.add "augment"
    (make_builtin_named ~name:"augment" ~variadic:true 0 (fun args call_env ->
      (* Split the call's arguments into named (name, value) pairs and bare
         positional values, each keeping call order. *)
      let named =
        List.filter_map (fun (n, v) -> Option.map (fun name -> (name, v)) n) args
      in
      let positional =
        List.filter_map
          (fun (n, v) -> match n with None -> Some v | Some _ -> None)
          args
      in
      let named_data = List.assoc_opt "data" named in
      let data_v = match named_data with
        | Some _ -> named_data
        | None -> (match positional with v :: _ -> Some v | [] -> None)
      in
      (* The model is the second positional argument when there is one. A lone
         positional argument is the model only if [data] was passed by name;
         otherwise that argument has already been consumed as [data].
         This is decided from how the call was written rather than by
         structurally comparing the values: polymorphic comparison raises on
         functional values and would traverse the whole DataFrame. *)
      let model_v = match List.assoc_opt "model" named with
        | Some v -> Some v
        | None ->
          (match positional, named_data with
           | _ :: v :: _, _ -> Some v
           | [v], Some _ -> Some v
           | _ -> None)
      in
      match (data_v, model_v) with
      | (Some (VDataFrame df), Some (VDict model)) ->
        (* 1. Use residuals() to get fitted and resid *)
        let residuals_fn = match Env.find_opt "residuals" call_env with
          | Some (VBuiltin b) -> b.b_func
          | _ -> fun _ _ -> Error.type_error "Internal error: `residuals` not found."
        in
        let res_v = residuals_fn [(None, VDataFrame df); (None, VDict model)] (ref call_env) in

        (match res_v with
         | VDataFrame res_df ->
            let fitted = Arrow_table.get_column res_df.arrow_table "fitted" in
            let resid  = Arrow_table.get_column res_df.arrow_table "resid" in

            (* Scale used for standardizing residuals: the float stored under
               "sigma" in the model's "_model_data" dict, or 1.0 when absent. *)
            let sigma = match List.assoc_opt "_model_data" model with
              | Some (VDict d) -> (match List.assoc_opt "sigma" d with Some (VFloat f) -> f | _ -> 1.0)
              | _ -> 1.0
            in

            (* Standardized residuals exist only for a float [resid] column;
               missing entries stay missing. *)
            let std_resid = match resid with
              | Some (Arrow_table.FloatColumn data) ->
                  Some (Arrow_table.FloatColumn
                          (Array.map (Option.map (fun e -> e /. sigma)) data))
              | _ -> None
            in

            (* NOTE(review): the NA fallback is built with size argument 0, not
               the table's row count — confirm [Arrow_table.create] accepts
               this when a column is missing. *)
            let column_or_na = function
              | Some c -> c
              | None -> Arrow_table.NAColumn 0
            in

            let new_cols = [
              ("fitted", column_or_na fitted);
              ("resid",  column_or_na resid);
            ] in
            let new_cols = match std_resid with
              | Some c -> new_cols @ [("std_resid", c)]
              | None -> new_cols
            in

            (* Combine with original columns *)
            let orig_names = Arrow_table.column_names df.arrow_table in
            let combined_cols = List.map (fun name ->
              (name, column_or_na (Arrow_table.get_column df.arrow_table name))
            ) orig_names in

            let final_table = Arrow_table.create (combined_cols @ new_cols) (Arrow_table.num_rows df.arrow_table) in
            VDataFrame { arrow_table = final_table; group_keys = df.group_keys }

         | VError e -> VError e
         | _ -> Error.type_error "Function `residuals` did not return a DataFrame.")
      | (Some (VError _ as e), _) | (_, Some (VError _ as e)) -> e
      | _ -> Error.type_error "Function `augment` expects (DataFrame, Model)."
    )) env