1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
open Ast

(** Format a single cell value for table display.

    Strings are shown as-is; NA values as [NA], or [NA(tag)] when
    [Utils.na_type_to_string] yields a non-empty tag. Integral floats of
    moderate magnitude get one decimal, non-zero floats below 0.001 in
    magnitude use scientific notation, and all other floats use 4
    significant digits. Every other value defers to [Utils.value_to_string]. *)
let cell_to_string v =
  match v with
  | VString s -> s
  | VNA na_t -> (
      match Utils.na_type_to_string na_t with
      | "" -> "NA"
      | tag -> "NA(" ^ tag ^ ")")
  | VFloat f ->
      (* The fixed-point form prints every integer digit, so restrict it to
         magnitudes below 1e15: an integral value such as 1e300 would
         otherwise render as hundreds of characters and blow up the column
         width. Larger values (and inf/nan) fall through to %.4g. *)
      if Float.abs f < 1e15 && f = Float.floor f then Printf.sprintf "%.1f" f
      else if Float.abs f < 0.001 then Printf.sprintf "%.4e" f
      else Printf.sprintf "%.4g" f
  | other -> Utils.value_to_string other

(** Pretty-print a DataFrame as a fixed-width text table.

    Columns are left-aligned and separated by two spaces. Each column is as
    wide as the longest of its displayed header and its displayed cells. At
    most 20 rows are shown; the remainder is summarised as a
    "... (N more rows)" line. The footer reports the full row and column
    counts plus any group keys.

    @param headers optional association list from original column name to
      display name; columns it does not mention keep their original name. *)
let pretty_print_dataframe ?headers { arrow_table; group_keys } =
  let max_display_rows = 20 in
  let nrows = Arrow_table.num_rows arrow_table in
  match Arrow_bridge.table_to_value_columns arrow_table with
  | [] -> "Empty DataFrame (0 rows x 0 cols)\n"
  | value_columns ->
      let display_rows = min nrows max_display_rows in
      let label_of name =
        match headers with
        | None -> name
        | Some h -> Option.value (List.assoc_opt name h) ~default:name
      in
      (* Stringify only the cells that are actually shown, and size each
         column from its display label. Measuring every cell of the frame
         made printing proportional to the full row count and let hidden rows
         widen the table; measuring the original name instead of the label let
         longer renamed headers (e.g. std_error shown as Std. Error) overflow
         their column and misalign the rows below. *)
      let columns =
        List.map
          (fun (name, col_data) ->
            let label = label_of name in
            let cells =
              Array.init display_rows (fun row -> cell_to_string col_data.(row))
            in
            let width =
              Array.fold_left
                (fun acc cell -> max acc (String.length cell))
                (String.length label) cells
            in
            (label, cells, width))
          value_columns
      in
      let buf = Buffer.create 256 in
      let add_row parts =
        Buffer.add_string buf "  ";
        Buffer.add_string buf (String.concat "  " parts);
        Buffer.add_char buf '\n'
      in
      let pad width text = Printf.sprintf "%-*s" width text in
      (* Header, then a dashed separator, then the visible data rows. *)
      add_row (List.map (fun (label, _, width) -> pad width label) columns);
      add_row (List.map (fun (_, _, width) -> String.make width '-') columns);
      for row = 0 to display_rows - 1 do
        add_row
          (List.map (fun (_, cells, width) -> pad width cells.(row)) columns)
      done;
      if nrows > display_rows then
        Buffer.add_string buf
          (Printf.sprintf "  ... (%d more rows)\n" (nrows - display_rows));
      let group_info =
        match group_keys with
        | [] -> ""
        | keys -> Printf.sprintf " grouped by [%s]" (String.concat ", " keys)
      in
      Buffer.add_string buf
        (Printf.sprintf "DataFrame: %d rows x %d cols%s\n" nrows
           (List.length value_columns) group_info);
      Buffer.contents buf

(** Pretty-print an error value.

    The first line reads "Error(CODE): MESSAGE". When a source location is
    present, the message is prefixed with [[file:Lline:Ccolumn]] (or
    [[Lline:Ccolumn]] when no file is recorded). A non-empty context is then
    listed under a "Context:" heading, one key/value pair per line. The
    [na_count] field is not shown. *)
let pretty_print_error { code; message; context; location; na_count = _ } =
  let located_message =
    match location with
    | None -> message
    | Some { file = None; line; column } ->
        Printf.sprintf "[L%d:C%d] %s" line column message
    | Some { file = Some filename; line; column } ->
        Printf.sprintf "[%s:L%d:C%d] %s" filename line column message
  in
  let headline =
    Printf.sprintf "Error(%s): %s\n"
      (Utils.error_code_to_string code)
      located_message
  in
  let context_block =
    match context with
    | [] -> ""
    | entries ->
        let render_entry (key, value) =
          Printf.sprintf "    %s: %s\n" key (Utils.value_to_string value)
        in
        "  Context:\n" ^ String.concat "" (List.map render_entry entries)
  in
  headline ^ context_block

(** Pretty-print a pipeline.

    Emits a "Pipeline (N nodes):" heading and then one line per node:
    its name, its value (DataFrames are abbreviated to their dimensions),
    a [[runtime]] tag when the node runtime is anything other than "T", and
    a [[depends: ...]] tag when the node has dependencies. *)
let pretty_print_pipeline { p_nodes; p_deps; p_runtimes; _ } =
  let describe_node (name, node_value) =
    let shown =
      match node_value with
      | VDataFrame { arrow_table; _ } ->
          Printf.sprintf "DataFrame(%d rows x %d cols)"
            (Arrow_table.num_rows arrow_table)
            (Arrow_table.num_columns arrow_table)
      | other -> Utils.value_to_string other
    in
    let runtime_tag =
      match List.assoc_opt name p_runtimes with
      | None | Some "T" -> ""
      | Some runtime -> " [" ^ runtime ^ "]"
    in
    let deps_tag =
      match List.assoc_opt name p_deps with
      | None | Some [] -> ""
      | Some deps -> " [depends: " ^ String.concat ", " deps ^ "]"
    in
    Printf.sprintf "  %s = %s%s%s\n" name shown runtime_tag deps_tag
  in
  let heading = Printf.sprintf "Pipeline (%d nodes):\n" (List.length p_nodes) in
  String.concat "" (heading :: List.map describe_node p_nodes)

(** Pretty-print a model summary dict.

    GLM summaries ([model_class] equal to "glm") begin with Family and Link
    lines. The title is "Model metrics:" when [summary_type] is "fit_stats"
    and "Coefficients:" otherwise. The table comes from the [_tidy_df]
    entry; coefficient tables get R-style column headers, with the statistic
    and p-value headers switching between the t and z forms for lm and glm.
    When no [_tidy_df] DataFrame is present, a fallback message is printed. *)
let pretty_print_summary pairs =
  (* Read a string field, ignoring non-string values. *)
  let string_or key ~default =
    match List.assoc_opt key pairs with
    | Some (VString s) -> s
    | _ -> default
  in
  (* Read a field, rendering non-string values instead of ignoring them. *)
  let rendered_or key ~default =
    match List.assoc_opt key pairs with
    | Some (VString s) -> s
    | Some other -> Utils.value_to_string other
    | None -> default
  in
  let is_glm = string_or "model_class" ~default:"lm" = "glm" in
  let is_fit_stats =
    string_or "summary_type" ~default:"coefficients" = "fit_stats"
  in
  let preamble =
    if is_glm then
      Printf.sprintf "Family:   %s\nLink:     %s\n\n"
        (rendered_or "family" ~default:"Gaussian")
        (rendered_or "link" ~default:"identity")
    else ""
  in
  let title = if is_fit_stats then "Model metrics:\n" else "Coefficients:\n" in
  let table =
    match List.assoc_opt "_tidy_df" pairs with
    | Some (VDataFrame df) ->
        if is_fit_stats then pretty_print_dataframe df
        else begin
          let stat_header, p_header =
            if is_glm then ("z value", "Pr(>|z|)") else ("t value", "Pr(>|t|)")
          in
          pretty_print_dataframe
            ~headers:
              [
                ("term", "");
                ("estimate", "Estimate");
                ("std_error", "Std. Error");
                ("statistic", stat_header);
                ("p_value", p_header);
              ]
            df
        end
    | _ -> "No coefficient data available.\n"
  in
  String.concat "" [ preamble; title; table ]

(** [is_visual_metadata_class cls] is true when [cls] is one of the string
    class tags this module renders as visual metadata: ggplot, matplotlib,
    plotnine, seaborn, plotly or altair. Non-string values are never
    visual metadata. *)
let is_visual_metadata_class cls =
  let visual_classes =
    [ "ggplot"; "matplotlib"; "plotnine"; "seaborn"; "plotly"; "altair" ]
  in
  match cls with
  | VString name -> List.mem name visual_classes
  | _ -> false

(** Extract the [_display_keys] list from a dict's pairs.

    Returns [None] when there is no [_display_keys] entry holding a list.
    If several such entries exist, the last one wins. Non-string items in
    that list are skipped. *)
let display_keys_from_pairs pairs =
  let string_items items =
    List.filter_map (function (_, VString s) -> Some s | _ -> None) items
  in
  let rec scan found = function
    | [] -> found
    | ("_display_keys", VList items) :: rest -> scan (Some (string_items items)) rest
    | _ :: rest -> scan found rest
  in
  scan None pairs

(** The pairs of a dict that should be displayed, in their original order.

    The [_display_keys] metadata entry itself is always dropped. When such
    metadata is present, only keys it lists are kept. *)
let visible_pairs_from_dict pairs =
  let is_visible =
    match display_keys_from_pairs pairs with
    | None -> fun key -> key <> "_display_keys"
    | Some keys -> fun key -> key <> "_display_keys" && List.mem key keys
  in
  List.filter (fun (key, _) -> is_visible key) pairs

(** [list_items_are_simple items] holds when a list is short (at most five
    items) and flat (no dict, list, vector, DataFrame or pipeline items), so
    that it can be rendered on a single line. *)
let list_items_are_simple items =
  let is_container = function
    | VDict _ | VList _ | VVector _ | VDataFrame _ | VPipeline _ -> true
    | _ -> false
  in
  List.length items <= 5
  && not (List.exists (fun (_, item) -> is_container item) items)

(** Render one labelled entry of a tree as a list of output lines.

    [prefix] is the indentation drawn by the ancestors, and [is_last] selects
    the closing or continuing branch glyph. Dicts with visible pairs and
    non-simple lists expand into child lines (list children are labelled
    [[0]], [[1]], ...). Empty containers, simple lists and scalars stay on
    the entry line as "label: value". *)
let rec render_tree_entry prefix is_last (label, value) =
  let branch = if is_last then "└── " else "├── " in
  let child_prefix = prefix ^ (if is_last then "    " else "│   ") in
  let head = prefix ^ branch ^ label in
  match value with
  | VDict pairs -> (
      match visible_pairs_from_dict pairs with
      | [] -> [ head ^ ": {}" ]
      | visible -> head :: render_tree_children child_prefix visible)
  | VList [] -> [ head ^ ": []" ]
  | VList items when list_items_are_simple items ->
      [ head ^ ": " ^ Utils.value_to_string value ]
  | VList items ->
      let indexed =
        List.mapi (fun i (_, item) -> (Printf.sprintf "[%d]" i, item)) items
      in
      head :: render_tree_children child_prefix indexed
  | _ -> [ head ^ ": " ^ Utils.value_to_string value ]

(** Render sibling entries under a common [prefix], marking only the final
    sibling as last. Each entry is rendered once and the results are
    concatenated in order. The previous version appended to an accumulator
    with [@], which was quadratic in the number of output lines. *)
and render_tree_children prefix entries =
  let last_index = List.length entries - 1 in
  entries
  |> List.mapi (fun i entry -> render_tree_entry prefix (i = last_index) entry)
  |> List.concat

(** Render a dict as a box-drawn tree, one line per entry, joined with
    newlines (no trailing newline).

    If a visible [kind] string is present, it becomes the title line and is
    removed from the children; otherwise the title is "dict". A tree with no
    children is just its title. *)
let pretty_print_tree pairs =
  let visible = visible_pairs_from_dict pairs in
  let title, children =
    match List.assoc_opt "kind" visible with
    | Some (VString kind) ->
        (kind, List.filter (fun (key, _) -> key <> "kind") visible)
    | Some _ | None -> ("dict", visible)
  in
  match children with
  | _ :: _ -> String.concat "\n" (title :: render_tree_children "" children)
  | [] -> title

(** Recursively format a value as indented, multi-line text.

    Dicts print one line per visible entry, with the key in backticks.
    Entries hidden by [_display_keys] are skipped, and a dict with nothing
    visible prints as an empty pair of braces. Lists print one item per line
    unless [list_items_are_simple] holds, in which case they stay on a single
    line. Below [max_depth] levels (default 5), containers fall back to
    [Utils.value_to_string]. [indent] is the prefix of the current level. *)
let rec pretty_format ?(max_depth = 5) ?(indent = "") v =
  let nested child_indent x =
    pretty_format ~max_depth:(max_depth - 1) ~indent:child_indent x
  in
  match v with
  | (VDict _ | VList _) when max_depth <= 0 -> Utils.value_to_string v
  | VDict [] -> "{}"
  | VDict pairs -> (
      match visible_pairs_from_dict pairs with
      | [] -> "{}"
      | visible ->
          let next_indent = indent ^ "  " in
          let lines =
            List.map
              (fun (k, x) ->
                Printf.sprintf "%s`%s`: %s" next_indent k (nested next_indent x))
              visible
          in
          "{\n" ^ String.concat ",\n" lines ^ "\n" ^ indent ^ "}")
  | VList [] -> "[]"
  | VList items when list_items_are_simple items ->
      (* Decide before recursing: the previous version formatted every child
         and then discarded the result for simple lists. *)
      Utils.value_to_string v
  | VList items ->
      let next_indent = indent ^ "  " in
      let lines = List.map (fun (_, x) -> nested next_indent x) items in
      "[\n" ^ next_indent
      ^ String.concat (",\n" ^ next_indent) lines
      ^ "\n" ^ indent ^ "]"
  | other -> Utils.value_to_string other

(** Render a visual-metadata dict (e.g. a plot description) as
    "CLASS BODY" followed by a newline.

    CLASS is the [class] string, defaulting to "plot". BODY is [pretty_format]
    of the remaining pairs, restricted to the keys listed in [_display_keys]
    when that metadata is present. When no pairs remain besides [class] and
    [_display_keys], BODY is an empty pair of braces. This function is not
    recursive, so it is a plain [let] rather than part of a [let rec]. *)
let pretty_print_visual_metadata pairs =
  let visible = List.filter (fun (k, _) -> k <> "_display_keys") pairs in
  let class_name =
    match List.assoc_opt "class" visible with
    | Some (VString s) -> s
    | _ -> "plot"
  in
  let body_pairs = List.filter (fun (k, _) -> k <> "class") visible in
  match body_pairs with
  | [] -> Printf.sprintf "%s {}\n" class_name
  | _ ->
      (* Filter here directly. Re-attaching a synthetic _display_keys entry
         for pretty_format to filter again produced the same result. *)
      let shown =
        match display_keys_from_pairs pairs with
        | None -> body_pairs
        | Some keys ->
            let wanted =
              List.fold_left
                (fun acc key -> String_set.add key acc)
                String_set.empty keys
            in
            List.filter (fun (k, _) -> String_set.mem k wanted) body_pairs
      in
      Printf.sprintf "%s %s\n" class_name (pretty_format (VDict shown))

(** Pretty-print any value for REPL display.

    DataFrames, errors and pipelines use their dedicated printers. Dicts are
    dispatched in this order:
    - a [class] of "summary" goes to the model-summary printer;
    - a visual-metadata [class] goes to the visual-metadata printer;
    - a [kind] key together with [_display_keys] renders as a tree;
    - a [kind] key, more than five pairs, or a nested dict/list/vector value
      uses the indented multi-line format;
    - anything else is printed on one line.

    Lists always use the indented format. NA values print as the empty
    string. Every other output ends with a newline. *)
let pretty_print_value v =
  match v with
  | VDataFrame df -> pretty_print_dataframe df
  | VError err -> pretty_print_error err
  | VPipeline p -> pretty_print_pipeline p
  | VDict pairs -> begin
      let has_kind = List.mem_assoc "kind" pairs in
      let is_nested_value = function
        | VDict _ | VList _ | VVector _ -> true
        | _ -> false
      in
      match List.assoc_opt "class" pairs with
      | Some (VString "summary") -> pretty_print_summary pairs
      | Some cls when is_visual_metadata_class cls ->
          pretty_print_visual_metadata pairs
      | _ ->
          if has_kind && Option.is_some (display_keys_from_pairs pairs) then
            pretty_print_tree pairs ^ "\n"
          else if
            has_kind
            || List.length pairs > 5
            || List.exists (fun (_, x) -> is_nested_value x) pairs
          then pretty_format v ^ "\n"
          else Utils.value_to_string v ^ "\n"
    end
  | VList _ -> pretty_format v ^ "\n"
  | VNA _ -> ""
  | other -> Utils.value_to_string other ^ "\n"

(** Register pretty_print as a builtin function *)
(*
--# Pretty-print a value
--#
--# Prints a formatted representation of a value. DataFrames are printed as tables.
--#
--# @name pretty_print
--# @param x :: Any The value to print.
--# @return :: Null
--# @example
--#   pretty_print(df)
--# @family core
--# @seealso print
--# @export
*)
(* Bind the one-argument [pretty_print] builtin in [env]. When called, it
   writes the pretty-printed form of its argument to stdout and returns
   [VNA NAGeneric]. Any other argument count yields
   [Error.arity_error_named]. *)
let register env =
  let pretty_print_impl args _env =
    match args with
    | [ value ] ->
        print_string (pretty_print_value value);
        VNA NAGeneric
    | _ -> Error.arity_error_named "pretty_print" 1 (List.length args)
  in
  let builtin = make_builtin ~name:"pretty_print" 1 pretty_print_impl in
  Env.add "pretty_print" builtin env