1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
open Ast

(*
--# Trimmed mean
--#
--# Compute the mean after trimming a fraction `trim` of the observations from each tail.
--#
--# @name trimmed_mean
--# @param x :: Vector | List Numeric input.
--# @param trim :: Float Trim proportion in [0, 0.5).
--# @param na_rm :: Bool = false Remove NA values first.
--# @param weights :: Vector[Float] | List[Float] = NA Optional non-negative observation weights.
--# @return :: Number | Vector Computed result (scalar or vectorized).
--# @family stats
--# @export
*)

(* [numeric_values ~label ~na_rm v] flattens [v] into a list of floats.

   - [v] must be a [VVector] or a [VList]; for a [VList] only the second
     component of each item is used.
   - [VInt] elements are converted with [float_of_int]; [VFloat] elements are
     taken as they are.
   - [VNA] elements are skipped when [na_rm] is true; otherwise the first one
     encountered produces an NA error.  A bare [VNA] passed as [v] always
     produces an NA error.
   - Any other element, or any other container, produces a type error.

   [label] is the function name inserted into the error messages.  The result
   is [Ok floats] (in input order) or [Error err]. *)
let numeric_values ~label ~na_rm v =
  (* Built lazily so the error value is only constructed when needed. *)
  let na_error () = Error (Error.na_value_error ~na_rm:true label) in
  let rec collect acc items =
    match items with
    | [] -> Ok (List.rev acc)
    | VInt n :: rest -> collect (float_of_int n :: acc) rest
    | VFloat f :: rest -> collect (f :: acc) rest
    | VNA _ :: rest -> if na_rm then collect acc rest else na_error ()
    | _ :: _ -> Error (Error.type_error (Printf.sprintf "Function `%s` requires numeric values." label))
  in
  match v with
  | VVector arr -> collect [] (Array.to_list arr)
  | VList items -> collect [] (List.map snd items)
  | VNA _ -> na_error ()
  | _ -> Error (Error.type_error (Printf.sprintf "Function `%s` expects a numeric List or Vector." label))

(* [register env] adds the `trimmed_mean` builtin to [env] via [Env.add] and
   returns the result.

   The builtin expects exactly two positional arguments, [x] and [trim], once
   the named arguments [na_rm] (boolean flag, default [false]) and [weights]
   (optional) have been removed from the argument list.

   - [trim] must be a [VInt] or [VFloat] in [0, 0.5).  A non-numeric [trim]
     gives a type error; an out-of-range or NaN [trim] gives a value error.
   - Without [weights]: the values are sorted, [floor (trim * n)] elements
     are dropped from each end and the arithmetic mean of the remainder is
     returned.  An empty input (after NA handling) yields [VNA NAFloat].
   - With [weights]: the weighted quantiles at [trim] and [1 - trim] are
     computed and the weighted mean of the observations lying between them
     (bounds included) is returned.  If either quantile or the weighted mean
     is unavailable, or no observation is kept, the result is [VNA NAFloat].

   Any other number of positional arguments gives an arity error. *)
let register env =
  Env.add "trimmed_mean" (make_builtin_named ~name:"trimmed_mean" ~variadic:true 2 (fun named_args _ ->
    match Math_common.get_bool_flag "na_rm" false named_args with
    | Error e -> e
    | Ok na_rm ->
        let weight_arg = Math_common.optional_named_arg "weights" named_args in
        let args = Math_common.positional_args_without ["na_rm"; "weights"] named_args in
        let trim_of = function VFloat f -> Some f | VInt i -> Some (float_of_int i) | _ -> None in
        match args with
        | [x; t] ->
            (match trim_of t with
             | None -> Error.type_error "Function `trimmed_mean` expects (x, trim) where trim is numeric."
             (* Negated in-range test: every comparison with NaN is false, so
                a NaN trim fails the range check and is rejected here instead
                of reaching [int_of_float] or the quantile helper. *)
             | Some trim when not (trim >= 0.0 && trim < 0.5) -> Error.value_error "Function `trimmed_mean` expects trim in [0, 0.5)."
             | Some trim ->
                 (match weight_arg with
                  | Some weight_v ->
                      (match Math_utils.extract_numeric_array_with_weights ~label:"trimmed_mean" ~na_rm x weight_v with
                       | Error e -> e
                       | Ok (xs, ws) ->
                           (* Lower and upper cut points of the weighted distribution. *)
                           (match Math_utils.weighted_quantile_array xs ws trim,
                                  Math_utils.weighted_quantile_array xs ws (1.0 -. trim) with
                            | Some lo, Some hi ->
                                (* Keep each observation within [lo, hi] together
                                   with its weight, preserving input order. *)
                                let kept = ref [] in
                                let kept_w = ref [] in
                                for i = 0 to Array.length xs - 1 do
                                  if xs.(i) >= lo && xs.(i) <= hi then begin
                                    kept := xs.(i) :: !kept;
                                    kept_w := ws.(i) :: !kept_w
                                  end
                                done;
                                let kept = Array.of_list (List.rev !kept) in
                                let kept_w = Array.of_list (List.rev !kept_w) in
                                if Array.length kept = 0 then VNA NAFloat
                                else
                                  (match Math_utils.weighted_mean_array kept kept_w with
                                   | Some v -> VFloat v
                                   | None -> VNA NAFloat)
                            | _ -> VNA NAFloat))
                  | None ->
                      (match numeric_values ~label:"trimmed_mean" ~na_rm x with
                       | Error e -> e
                       | Ok [] -> VNA NAFloat
                       | Ok xs ->
                           let arr = Array.of_list xs in
                           let n = Array.length arr in
                           (* Number of observations dropped from each tail.
                              Since trim < 0.5, 2 * k < n and at least one
                              value is always kept. *)
                           let k = int_of_float (Float.floor (trim *. float_of_int n)) in
                           Array.sort Float.compare arr;
                           let kept = Array.sub arr k (n - (2 * k)) in
                           VFloat (Array.fold_left ( +. ) 0.0 kept /. float_of_int (Array.length kept)))))
        | _ -> Error.arity_error_named "trimmed_mean" 2 (List.length args))) env