1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
open Ast
(** [get_bool_flag name default named_args] looks up the first entry of
    [named_args] labelled [Some name].
    Returns [Ok default] when there is no such entry, [Ok b] when its value is
    [VBool b], and [Error e] when it holds any other value, where [e] is the
    type-error value produced by [Error.type_error]. *)
let get_bool_flag name default named_args =
  let labelled_as_flag (label, _) = label = Some name in
  match List.find_opt labelled_as_flag named_args with
  | None -> Ok default
  | Some (_, VBool flag) -> Ok flag
  | Some (_, other) ->
      let msg =
        Printf.sprintf "Flag `%s` must be Bool, but received %s." name
          (Utils.type_name other)
      in
      Error (Error.type_error msg)
(** [positional_args_without names named_args] drops every entry whose label
    is [Some n] with [n] a member of [names], and returns the values of the
    entries that remain, in their original order.
    Unlabelled ([None]) entries are always kept. Labelled entries whose name
    is not listed in [names] are kept as well, so the result is not strictly
    positional-only. *)
let positional_args_without names named_args =
  let keep acc (label, value) =
    match label with
    | Some l when List.mem l names -> acc
    | Some _ | None -> value :: acc
  in
  (* Accumulate in reverse, then flip once to restore the input order. *)
  List.rev (List.fold_left keep [] named_args)
(** [optional_named_arg name named_args] returns [Some v], where [v] is the
    value of the first entry labelled [Some name].
    Returns [None] when there is no such entry, or when that entry's value is
    NA ([VNA _]). An explicit NA is therefore treated like an absent argument. *)
let optional_named_arg name named_args =
  let rec lookup = function
    | [] -> None
    | (label, value) :: rest ->
        if label = Some name then
          match value with VNA _ -> None | present -> Some present
        else lookup rest
  in
  lookup named_args
(** [map_numeric_unary ~fname ?expects ?na_ignore f args] applies the float
    transform [f] to a single numeric argument.
    - [fname] is the user-facing function name used in error messages.
    - [expects] (default ["numeric input"]) describes the accepted input in the
      type error raised for an unsupported single argument.
    - [na_ignore] (default [false]) keeps NA inputs/vector slots as NA instead
      of reporting an NA error.
    - [f] is the transform applied to each concrete float.
    A lone [VInt]/[VFloat] yields a [VFloat]. A [VVector] is mapped
    element-wise. A [VNDArray] is mapped over its data. Inside a vector, a
    non-numeric element yields a type error, which takes precedence over an NA
    error. Any argument count other than one yields an arity error. *)
let map_numeric_unary ~fname ?(expects = "numeric input") ?(na_ignore = false) f
    args =
  match args with
  | [ VInt n ] -> VFloat (f (float_of_int n))
  | [ VFloat x ] -> VFloat (f x)
  | [ VVector items ] ->
      let len = Array.length items in
      let result = Array.make len (VNA NAGeneric) in
      (* Single tail-recursive pass. [first_err] is the first type error seen
         and [nas] counts the NA slots seen so far. Once either one means the
         final answer will be an error, numeric slots stop being transformed,
         because the partially built [result] will be discarded. *)
      let rec scan idx first_err nas =
        if idx >= len then (first_err, nas)
        else
          let still_clean =
            (match first_err with None -> true | Some _ -> false)
            && (nas = 0 || na_ignore)
          in
          match items.(idx) with
          | VInt n ->
              if still_clean then result.(idx) <- VFloat (f (float_of_int n));
              scan (idx + 1) first_err nas
          | VFloat x ->
              if still_clean then result.(idx) <- VFloat (f x);
              scan (idx + 1) first_err nas
          | VNA tag ->
              if na_ignore then result.(idx) <- VNA tag;
              scan (idx + 1) first_err (nas + 1)
          | _ -> (
              (* Only the first offending element builds an error value. *)
              match first_err with
              | Some _ -> scan (idx + 1) first_err nas
              | None ->
                  let msg =
                    Printf.sprintf "Function `%s` requires numeric values."
                      fname
                  in
                  scan (idx + 1) (Some (Error.type_error msg)) nas)
      in
      (match scan 0 None 0 with
       | Some err, _ -> err
       | None, nas when nas > 0 && not na_ignore ->
           Error.na_value_error ~na_count:nas fname
       | None, _ -> VVector result)
  | [ VNDArray nd ] -> VNDArray { shape = nd.shape; data = Array.map f nd.data }
  | [ VNA tag ] -> if na_ignore then VNA tag else Error.na_value_error fname
  | [ _ ] ->
      Error.type_error
        (Printf.sprintf "Function `%s` expects %s." fname expects)
  | other -> Error.arity_error_named fname 1 (List.length other)
(** Label-aware front end for [map_numeric_unary].
    It reads the optional [na_ignore] flag (default [false]). If the flag is
    present but not a Bool, the type error from [get_bool_flag] is returned.
    Otherwise every [na_ignore] entry is removed, and the remaining values are
    passed to [map_numeric_unary] with the same [fname], [expects] and [f]. *)
let map_numeric_unary_named ~fname ?expects f named_args =
  match get_bool_flag "na_ignore" false named_args with
  | Ok na_ignore ->
      positional_args_without [ "na_ignore" ] named_args
      |> map_numeric_unary ~fname ?expects ~na_ignore f
  | Error err -> err