1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
open Ast
(** [is_na_predicate_error v] is [true] exactly when [v] is a [VError] whose
    [code] is [NAPredicateError]; every other value yields [false]. *)
let is_na_predicate_error value =
  match value with
  | VError { code; _ } ->
    (match code with
     | NAPredicateError -> true
     | _ -> false)
  | _ -> false
(** [plural_suffix count] is the English plural suffix for [count] items:
    the empty string when [count] is exactly 1, ["s"] for any other value
    (including 0 and negative numbers). *)
let plural_suffix count =
  match count with
  | 1 -> ""
  | _ -> "s"
(** [emit_na_filter_warning na_indices] reports the rows that [filter()]
    dropped because the predicate evaluated to NA.

    [na_indices] holds 1-based row numbers in the order the callers in this
    file accumulate them (each new row is prepended), so the list is reversed
    before it is reported. An empty list emits nothing.

    A structured warning is always sent through [Eval.emit_node_warning]; a
    line is additionally written to stderr when [Eval.show_warnings] is set.
    At most 50 row numbers are listed in either place. The structured warning
    carries the full count in [nw_na_count]; the stderr line states how many
    row numbers were left out so a truncated list is not mistaken for the
    complete one. *)
let emit_na_filter_warning na_indices =
  match na_indices with
  | [] -> ()
  | _ ->
    let indices = List.rev na_indices in
    let count = List.length indices in
    let capped_indices = Utils.list_take 50 indices in
    let message =
      Printf.sprintf
        "filter() excluded %d row%s because the predicate evaluated to NA"
        count
        (plural_suffix count)
    in
    (* Emit structured diagnostic for the pipeline engine *)
    Eval.emit_node_warning {
      nw_kind = "NAExcluded";
      nw_fn = "filter";
      nw_na_count = count;
      nw_na_indices = capped_indices;
      nw_message = message;
      nw_source = WarningOwn;
    };
    (* Optional interactive stderr reporting *)
    if !Eval.show_warnings then begin
      let listed =
        capped_indices |> List.map string_of_int |> String.concat ", "
      in
      (* Computed from the capped list itself so the figure is right whatever
         [Utils.list_take] actually returned. *)
      let omitted = count - List.length capped_indices in
      let rendered =
        if omitted > 0 then Printf.sprintf "%s, ... (%d more)" listed omitted
        else listed
      in
      Printf.eprintf
        "Warning: %s at row%s %s\n%!"
        message
        (plural_suffix count)
        rendered
    end
(** Outcome of evaluating a filter predicate over whole columns instead of
    row by row. Both arrays are indexed by 0-based row position. *)
type vectorized_predicate = {
(* Row mask that the filter builtin hands to [Arrow_compute.filter];
   presumably [true] means the row is retained (confirm against
   [Arrow_compute.filter]). *)
keep : bool array;
(* [true] marks a row whose predicate result is NA; the filter builtin turns
   these positions into 1-based row numbers for the NA-exclusion warning. *)
na : bool array;
}
(** [min_array_len a b] is the length of the shorter of the two arrays. The
    element types of [a] and [b] need not agree. *)
let min_array_len a b =
  let len_a = Array.length a and len_b = Array.length b in
  if len_a <= len_b then len_a else len_b
(** [take_bool_array len arr] is [arr] itself when it already has exactly
    [len] elements, otherwise a fresh array holding the first [len] elements
    of [arr].
    @raise Invalid_argument if [len] is negative or exceeds the length of
    [arr]. *)
let take_bool_array len arr =
  match Array.length arr = len with
  | true -> arr
  | false -> Array.init len (Array.get arr)
(** [false_mask keep na] marks the positions that are neither kept nor NA.
    The result is as long as the shorter of the two inputs. *)
let false_mask keep na =
  let len = min_array_len keep na in
  (* De Morgan form of "not kept and not NA". *)
  Array.init len (fun i -> not (keep.(i) || na.(i)))
(** [na_indices_of_mask mask] lists the 1-based positions at which [mask] is
    [true]. Positions are prepended as they are found, so the result is in
    descending order. *)
let na_indices_of_mask mask =
  let len = Array.length mask in
  let rec collect i found =
    if i >= len then found
    else if mask.(i) then collect (i + 1) ((i + 1) :: found)
    else collect (i + 1) found
  in
  collect 0 []
(** [vectorized_compare table field scalar op_s] compares column [field] of
    [table] against [scalar] through [Arrow_compute.compare_column_scalar],
    using the operator named by [op_s].

    Returns [None] when the comparison call yields [None]. Otherwise the
    comparison result becomes [keep], and [na] is the column's null mask cut
    to the length of [keep], or all [false] when
    [Arrow_compute.column_null_mask] yields [None]. *)
let vectorized_compare table field scalar op_s =
  (* The null mask is looked up only once a comparison result exists. *)
  let with_null_mask keep =
    let len = Array.length keep in
    let na =
      match Arrow_compute.column_null_mask table field with
      | None -> Array.make len false
      | Some mask -> take_bool_array len mask
    in
    { keep; na }
  in
  match Arrow_compute.compare_column_scalar table field scalar op_s with
  | Some keep -> Some (with_null_mask keep)
  | None -> None
(** [try_vectorize_filter table fn] attempts to evaluate a filter predicate
    over whole columns of [table] instead of row by row.

    [fn] must be a one-parameter lambda whose body is built only from:
    - [row.field op literal] or [literal op row.field], where [op] is one of
      [>], [<], [>=], [<=], [==] and the literal is an [VInt] or [VFloat];
    - [is_na(row.field)];
    - logical negation of a supported expression;
    - [&&] / [||] of two supported expressions.

    Returns [None] as soon as any part of the body falls outside these
    shapes, in which case the caller has to evaluate the predicate per row. *)
let try_vectorize_filter (table : Arrow_table.t) (fn : value)
  : vectorized_predicate option =
  match fn with
  | VLambda { params = [row_param]; body; _ } ->
    (* Only numeric literals are passed to the column/scalar comparison. *)
    let scalar_as_float = function
      | VInt n -> Some (float_of_int n)
      | VFloat x -> Some x
      | _ -> None
    in
    let compare_field field literal op_s =
      match scalar_as_float literal with
      | None -> None
      | Some threshold -> vectorized_compare table field threshold op_s
    in
    (* Comparison name for an operator, paired with the name to use when the
       operands appear as [literal op row.field] and must be swapped. *)
    let kernel_names = function
      | Gt -> Some ("gt", "lt")
      | Lt -> Some ("lt", "gt")
      | GtEq -> Some ("ge", "le")
      | LtEq -> Some ("le", "ge")
      | Eq -> Some ("eq", "eq")
      | _ -> None
    in
    let try_cmp op left right =
      match kernel_names op with
      | None -> None
      | Some (direct, mirrored) ->
        (match left.node, right.node with
         (* row.field op literal *)
         | DotAccess { target = { node = Var p; _ }; field }, Value literal
           when p = row_param ->
           compare_field field literal direct
         (* literal op row.field: same test with the operator mirrored *)
         | Value literal, DotAccess { target = { node = Var p; _ }; field }
           when p = row_param ->
           compare_field field literal mirrored
         | _ -> None)
    in
    (* Cut both operands of a logical connective to the shorter [keep]
       length so they can be combined position by position. *)
    let align lhs rhs =
      let n = min_array_len lhs.keep rhs.keep in
      let trim = take_bool_array n in
      (n, trim lhs.keep, trim lhs.na, trim rhs.keep, trim rhs.na)
    in
    let rec try_vectorize_expr expr =
      match expr.node with
      | UnOp { op = Not; operand } ->
        (match try_vectorize_expr operand with
         | None -> None
         | Some { keep; na } ->
           let n = min_array_len keep na in
           Some {
             (* Negation keeps a row only when the operand was definitely
                false; NA rows stay NA. *)
             keep = Array.init n (fun i -> not (keep.(i) || na.(i)));
             na = take_bool_array n na;
           })
      | Call { fn = { node = Var "is_na"; _ };
               args = [(None, { node = DotAccess { target = { node = Var p; _ }; field }; _ })] }
        when p = row_param ->
        (* is_na never yields NA itself, so the NA mask is all false. *)
        (match Arrow_compute.column_null_mask table field with
         | None -> None
         | Some nulls ->
           Some { keep = nulls; na = Array.make (Array.length nulls) false })
      | BinOp { op = And; left; right } ->
        (match try_vectorize_expr left, try_vectorize_expr right with
         | Some lhs, Some rhs ->
           let n, l_keep, l_na, r_keep, r_na = align lhs rhs in
           Some {
             keep = Array.init n (fun i -> l_keep.(i) && r_keep.(i));
             (* Mirrors the interpreter's short-circuit AND: NA on the left
                always propagates; NA on the right counts only where the
                left side is not definitely false, i.e. where the right side
                would have been evaluated. *)
             na =
               Array.init n (fun i ->
                 l_na.(i) || (r_na.(i) && (l_keep.(i) || l_na.(i))));
           }
         | _ -> None)
      | BinOp { op = Or; left; right } ->
        (match try_vectorize_expr left, try_vectorize_expr right with
         | Some lhs, Some rhs ->
           let n, l_keep, l_na, r_keep, r_na = align lhs rhs in
           (* The left side is definitely false: neither kept nor NA. *)
           let lhs_false i = not (l_keep.(i) || l_na.(i)) in
           Some {
             keep =
               Array.init n (fun i ->
                 l_keep.(i) || (lhs_false i && r_keep.(i)));
             (* Mirrors the interpreter's short-circuit OR: NA on the left
                always propagates; NA on the right counts only where the
                left side is false, i.e. where the right side would have
                been evaluated. *)
             na = Array.init n (fun i -> l_na.(i) || (r_na.(i) && lhs_false i));
           }
         | _ -> None)
      | BinOp { op; left; right } -> try_cmp op left right
      | _ -> None
    in
    try_vectorize_expr body
  | _ -> None
let register ~eval_call ~eval_expr:(_eval_expr : Ast.value Ast.Env.t -> Ast.expr -> Ast.value) ~uses_nse:(_uses_nse : Ast.expr -> bool) ~desugar_nse_expr:(_desugar_nse_expr : Ast.expr -> Ast.expr) env =
(*
--# Filter rows
--#
--# Retains rows that satisfy the predicate function.
--#
--# @name filter
--# @param df :: DataFrame The input DataFrame.
--# @param predicate :: Function A function returning Bool for each row.
--# @return :: DataFrame The filtered DataFrame.
--# @example
--# filter(mtcars, \(row) -> row.mpg > 20)
--# @family colcraft
--# @seealso select, arrange
--# @export
*)
  (* Shared tail of both evaluation paths: report the NA rows first, then
     build the filtered DataFrame while carrying the grouping keys over. *)
  let finish table group_keys keep na_rows =
    emit_na_filter_warning na_rows;
    VDataFrame { arrow_table = Arrow_compute.filter table keep; group_keys }
  in
  let filter_impl args call_env =
    match args with
    | [VDataFrame df; fn] ->
      let table = df.arrow_table in
      (* Column-wise evaluation is attempted first; it only succeeds for
         predicates that try_vectorize_filter recognises. *)
      (match try_vectorize_filter table fn with
       | Some { keep; na } ->
         finish table df.group_keys keep (na_indices_of_mask na)
       | None ->
         (* Per-row fallback: call the predicate on each row's dict. *)
         let nrows = Arrow_table.num_rows table in
         let keep = Array.make nrows false in
         (* Walks the rows in order and stops at the first failure. Yields
            the failure (if any) together with the 1-based numbers of the NA
            rows seen so far, most recent first. *)
         let rec scan i na_rows =
           if i >= nrows then (None, na_rows)
           else
             let row_dict = VDict (Arrow_bridge.row_to_dict table i) in
             match eval_call call_env fn [(None, Ast.mk_expr (Value row_dict))] with
             | VBool true ->
               begin
                 keep.(i) <- true;
                 scan (i + 1) na_rows
               end
             | VBool false -> scan (i + 1) na_rows
             | VNA _ -> scan (i + 1) ((i + 1) :: na_rows)
             (* An NA-predicate error is treated like an NA result. *)
             | VError _ as outcome when is_na_predicate_error outcome ->
               scan (i + 1) ((i + 1) :: na_rows)
             | VError _ as failure -> (Some failure, na_rows)
             | _ ->
               (Some (make_error TypeError "filter() predicate must return a Bool"),
                na_rows)
         in
         (match scan 0 [] with
          | Some failure, _ -> failure
          | None, na_rows -> finish table df.group_keys keep na_rows))
    | [VDataFrame _] -> make_error ArityError "Function `filter` requires a DataFrame and a predicate function."
    | [_; _] -> make_error TypeError "Function `filter` expects a DataFrame as first argument."
    | _ -> make_error ArityError "Function `filter` takes exactly 2 arguments."
  in
  Env.add "filter" (make_builtin ~name:"filter" 2 filter_impl) env