1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
open Ast

(** Validation mode selected by the caller: [validate_program] performs no
    checks under [Repl] and checks top-level lambda assignments under
    [Strict]. *)
type mode =
  | Repl
  | Strict

(** [mode_of_string s] parses a mode name.

    Matching is exact and case-sensitive: ["repl"] gives [Some Repl],
    ["strict"] gives [Some Strict], and every other string gives [None]. *)
let mode_of_string s =
  List.assoc_opt s [ ("repl", Repl); ("strict", Strict) ]

(** [contains_type_var t] is [true] when a [TVar] occurs anywhere inside [t].

    The search descends into list element types, dictionary key and value
    types, tuple components and data-frame schemas. Every other constructor,
    and every missing ([None]) annotation, is treated as containing no type
    variable.

    The two sides of a [TDict] are inspected independently, so a dictionary
    whose key or value annotation is missing is still searched on the side
    that is present. *)
let rec contains_type_var = function
  | TVar _ -> true
  | TList (Some t) | TDataFrame (Some t) -> contains_type_var t
  | TDict (k, v) ->
      (* A missing side contributes [false]; it must not hide the other. *)
      let side = Option.fold ~none:false ~some:contains_type_var in
      side k || side v
  | TTuple ts -> List.exists contains_type_var ts
  | _ -> false

(** [collect_type_vars t] lists the names of the [TVar]s occurring in [t],
    in left-to-right order and with duplicates preserved.

    The traversal descends into list element types, dictionary key and value
    types (key first), tuple components and data-frame schemas. Every other
    constructor, and every missing ([None]) annotation, contributes no names.

    The two sides of a [TDict] are collected independently, so a dictionary
    whose key or value annotation is missing still reports the variables of
    the side that is present. *)
let rec collect_type_vars = function
  | TVar name -> [ name ]
  | TList (Some t) | TDataFrame (Some t) -> collect_type_vars t
  | TDict (k, v) ->
      (* A missing side contributes no names; it must not hide the other. *)
      let side = Option.fold ~none:[] ~some:collect_type_vars in
      side k @ side v
  | TTuple ts -> List.concat_map collect_type_vars ts
  | _ -> []

(** [make_type_error ?location message] builds an error record whose code is
    [Ast.TypeError], carrying [message] and the optional [location], with an
    empty context list and an [na_count] of zero. *)
let make_type_error ?location message =
  let code = Ast.TypeError in
  let context = [] in
  { code; message; location; context; na_count = 0 }

(** [validate_lambda ?location name l] checks the strict-mode annotation rules
    for the lambda [l] bound to [name].

    The checks run in this order and the first failure is returned:
    - every parameter has a type annotation;
    - a return type annotation is present;
    - if the annotations mention type variables, the lambda declares generic
      parameters;
    - every mentioned type variable is among the declared generic parameters.

    Returns [Ok ()] when all checks pass, otherwise [Error e] where [e] is
    built with [make_type_error] and carries [location]. Variable names in
    messages are sorted and de-duplicated. *)
let validate_lambda ?location name (l : lambda) =
  (* Every failure is reported the same way, at the same location. *)
  let reject message = Error (make_type_error ?location message) in
  if List.exists Option.is_none l.param_types then
    reject (Printf.sprintf "Strict mode: top-level function '%s' must annotate all parameter types." name)
  else
    match l.return_type with
    | None ->
        reject (Printf.sprintf "Strict mode: top-level function '%s' must annotate a return type." name)
    | Some return_type ->
        let vars_of_param = function
          | Some t -> collect_type_vars t
          | None -> []
        in
        let from_params = List.concat_map vars_of_param l.param_types in
        let from_return = collect_type_vars return_type in
        (* Sorted, duplicate-free names mentioned by the whole signature. *)
        let mentioned = List.sort_uniq String.compare (from_params @ from_return) in
        (match (mentioned, l.generic_params) with
        | [], _ -> Ok ()
        | _ :: _, [] ->
            reject (Printf.sprintf "Strict mode: top-level function '%s' uses generic type variables %s but declares none. Use syntax like \\<T, U>(...) -> ..."
              name (String.concat ", " mentioned))
        | _ :: _, declared ->
            let is_missing v = not (List.mem v declared) in
            (match List.filter is_missing mentioned with
            | [] -> Ok ()
            | missing ->
                reject (Printf.sprintf "Strict mode: top-level function '%s' uses undeclared type variables: %s"
                  name (String.concat ", " missing))))

(** [validate_program ~mode program] validates the statements of [program].

    Under [Repl] nothing is checked and the result is always [Ok ()]. Under
    [Strict], each statement that is an assignment whose right-hand side is
    directly a lambda is passed to [validate_lambda]; all other statements
    are skipped. Statements are examined in order and the first error is
    returned; [Ok ()] is returned when none fails.

    The location attached to an error is the lambda expression's own location
    when it has one, and the enclosing statement's location otherwise. *)
let validate_program ~(mode : mode) (program : program) =
  match mode with
  | Repl -> Ok ()
  | Strict ->
      (* [Some err] for a statement that breaks the rules, [None] otherwise. *)
      let violation stmt =
        match stmt.node with
        | Assignment { name; expr; _ } ->
            (match expr.node with
            | Lambda l ->
                let location =
                  match expr.loc with
                  | None -> stmt.loc
                  | found -> found
                in
                (match validate_lambda ?location name l with
                | Ok () -> None
                | Error err -> Some err)
            | _ -> None)
        | _ -> None
      in
      (* [List.find_map] stops at the first offending statement. *)
      (match List.find_map violation program with
      | Some err -> Error err
      | None -> Ok ())