1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
open Ast

(*
--# Arrange rows
--#
--# Sorts a DataFrame by a column. Use "desc" for descending order.
--#
--# @name arrange
--# @param df :: DataFrame The input DataFrame.
--# @param col :: Symbol The column to sort by.
--# @param direction :: String (Optional) "asc" (default) or "desc".
--# @return :: DataFrame The sorted DataFrame.
--# @example
--#   arrange(mtcars, $mpg)
--#   arrange(mtcars, $mpg, "desc")
--# @family colcraft
--# @seealso filter, group_by
--# @export
*)
(* [register env] returns [env] extended with a binding for the `arrange`
   builtin, which sorts a DataFrame by a single column.

   Argument lists handled by the builtin:
   - [df; col] and [df; col; "asc"]  sort ascending
   - [df; col; "desc"]               sort descending
   - anything else                   the corresponding Error value
     (value error for an unknown direction string, type error for a
     non-DataFrame first argument, arity error for a wrong argument count).

   The [group_keys] of the input DataFrame are copied to the result. *)
let register env =
  (* Ordering used by the OCaml fallback sort.

     NA placement is decided before the direction is applied, so NA rows are
     placed last for both ascending and descending sorts. Two NA values compare
     equal: answering non-zero for that pair would violate the comparator
     contract and let [Array.stable_sort] reorder NA rows among themselves. *)
  let compare_values ~ascending a b =
    match (a, b) with
    | (VNA _, VNA _) -> 0
    | (VNA _, _) -> 1
    | (_, VNA _) -> -1
    | _ ->
      (* Descending order is obtained by swapping the operands. *)
      let (lhs, rhs) = if ascending then (a, b) else (b, a) in
      (match (lhs, rhs) with
       | (VInt x, VInt y) -> compare x y
       | (VFloat x, VFloat y) -> compare x y
       | (VString x, VString y) -> String.compare x y
       | (VBool x, VBool y) -> compare x y
       | (VDate x, VDate y) -> compare x y
       | (VDatetime (x, _), VDatetime (y, _)) -> Int64.compare x y
       | (VFactor (x, _, _), VFactor (y, _, _)) -> compare x y
       (* Mixed or unsupported value kinds compare equal, so the stable sort
          keeps them in input order. *)
       | _ -> 0)
  in
  (* Sorts [table] by the column designated by [col_val] and wraps the result
     in a DataFrame carrying [group_keys].

     [Arrow_compute.sort_by_column] is tried first; when it returns [None]
     (presumably no native sort is available for this table -- confirm), the
     rows are ordered here by stable-sorting a row-index permutation. *)
  let sort_frame ~ascending table group_keys col_val =
    match Utils.extract_column_name col_val with
    | None -> Error.type_error "Function `arrange` expects a $column reference."
    | Some col_name ->
      if not (Arrow_table.has_column table col_name) then
        Error.make_error KeyError (Printf.sprintf "Column `%s` not found in DataFrame." col_name)
      else
        let sorted =
          match Arrow_compute.sort_by_column table col_name ascending with
          | Some new_table -> new_table
          | None ->
            let nrows = Arrow_table.num_rows table in
            (* [has_column] succeeded above, so the NA column is only a
               defensive default. *)
            let col = match Arrow_table.get_column table col_name with
              | Some c -> c
              | None -> Arrow_table.NAColumn nrows
            in
            let col_values = Arrow_bridge.column_to_values col in
            let indices = Array.init nrows (fun i -> i) in
            Array.stable_sort
              (fun i j -> compare_values ~ascending col_values.(i) col_values.(j))
              indices;
            Arrow_compute.sort_by_indices table indices
        in
        VDataFrame { arrow_table = sorted; group_keys }
  in
  Env.add "arrange"
    (make_builtin ~name:"arrange" ~variadic:true 2 (fun args _env ->
      match args with
      | [VDataFrame df; col_val] | [VDataFrame df; col_val; VString "asc"] ->
          sort_frame ~ascending:true df.arrow_table df.group_keys col_val
      | [VDataFrame df; col_val; VString "desc"] ->
          sort_frame ~ascending:false df.arrow_table df.group_keys col_val
      | [VDataFrame _; _; VString dir] ->
          Error.value_error (Printf.sprintf "Function `arrange` direction must be \"asc\" or \"desc\", got \"%s\"." dir)
      | [VDataFrame _; _; _] ->
          Error.type_error "Function `arrange` expects a $column reference."
      | [_; _] | [_; _; _] -> Error.type_error "Function `arrange` expects a DataFrame as first argument."
      | _ -> Error.make_error ArityError "Function `arrange` takes 2 or 3 arguments."
     ))
     env