1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
open Ast

(*
--# Two-argument arctangent
--#
--# Compute `atan2(y, x)` with quadrant-aware angle.
--#
--# @name atan2
--# @param y :: Number | List | Vector | NDArray Y coordinate(s).
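--# @param x :: Number Scalar X coordinate.
--# @param na_ignore :: Bool (Optional) If true, NA inputs propagate as NA instead of raising an error.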
--# @return :: Number | Vector | NDArray Computed result (scalar or vectorized).
--# @family math
--# @export
*)

(** [register env] binds the [atan2] builtin under the name atan2 through
    [Env.add] and returns the result of that call.

    The builtin takes two positional arguments [y] and [x], plus the boolean
    named flag [na_ignore] (presumably defaulting to [false]; confirm against
    [Math_common.get_bool_flag]):
    - an int or float [y] with an int or float [x] yields a float;
    - a vector or list [y] is mapped element-wise against the scalar [x] and
      yields a vector;
    - an NDArray [y] yields an NDArray with the same shape.

    An NA input (either argument, or an element of a vector or list) yields
    NA when [na_ignore] is set and an NA-value error otherwise. Non-numeric
    inputs yield a type error, and any argument count other than two yields
    an arity error. *)
let register env =
  Env.add "atan2"
    (make_builtin_named ~name:"atan2" ~variadic:true 2 (fun named_args _env ->
      match Math_common.get_bool_flag "na_ignore" false named_args with
      | Error e -> e
      | Ok na_ignore ->
          let args = Math_common.positional_args_without [ "na_ignore" ] named_args in
          (* Single NA policy shared by every branch: propagate the NA when
             the flag is set, otherwise report an NA-value error. *)
          let na_result na_t =
            if na_ignore then VNA na_t else Error.na_value_error "atan2"
          in
          (* Element-wise atan2 of the boxed values [ys] against the float
             [x]. Stops at the first offending element and returns its error
             value; otherwise returns the filled vector. *)
          let vectorized_result ys x =
            let len = Array.length ys in
            let out = Array.make len (VNA NAGeneric) in
            let rec fill i =
              if i >= len then VVector out
              else begin
                let store v = out.(i) <- v; fill (i + 1) in
                match ys.(i) with
                | VInt y -> store (VFloat (Float.atan2 (float_of_int y) x))
                | VFloat y -> store (VFloat (Float.atan2 y x))
                | VNA na_t ->
                    if na_ignore then store (VNA na_t)
                    else Error.na_value_error "atan2"
                | _ -> Error.type_error "Function `atan2` requires numeric values."
              end
            in
            fill 0
          in
          (* Validates the scalar [x] operand of a collection call and hands
             it, as a float, to the continuation [k]. [k] only runs once [x]
             is known to be numeric. *)
          let with_scalar_x x_val k =
            match x_val with
            | VNA na_t -> na_result na_t
            | VInt n -> k (float_of_int n)
            | VFloat f -> k f
            | _ -> Error.type_error "Function `atan2` expects numeric arguments."
          in
          match args with
          | [VInt y; VInt x] -> VFloat (Float.atan2 (float_of_int y) (float_of_int x))
          | [VInt y; VFloat x] -> VFloat (Float.atan2 (float_of_int y) x)
          | [VFloat y; VInt x] -> VFloat (Float.atan2 y (float_of_int x))
          | [VFloat y; VFloat x] -> VFloat (Float.atan2 y x)
          | [VVector arr; x_val] -> with_scalar_x x_val (vectorized_result arr)
          | [VList items; x_val] ->
              (* List items are pairs; only the second component is used. *)
              with_scalar_x x_val (fun x ->
                vectorized_result (Array.of_list (List.map snd items)) x)
          | [VNDArray arr; x_val] ->
              with_scalar_x x_val (fun x ->
                VNDArray { shape = arr.shape; data = Array.map (fun y -> Float.atan2 y x) arr.data })
          (* Scalar call with an NA operand: an NA [y] wins over whatever [x]
             is; an NA [x] only counts when [y] is numeric. *)
          | [VNA na_t; _]
          | [(VInt _ | VFloat _); VNA na_t] -> na_result na_t
          | [_; _] -> Error.type_error "Function `atan2` expects numeric arguments."
          | _ -> Error.arity_error_named "atan2" 2 (List.length args)
    )) env