From 1f49242ae35305882c0ff3b60a8d253a817c1991 Mon Sep 17 00:00:00 2001 From: giles Date: Wed, 6 May 2026 22:54:33 +0000 Subject: [PATCH] =?UTF-8?q?sx:=20step=205=20=E2=80=94=20OCaml=20AdtValue?= =?UTF-8?q?=20+=20define-type=20+=20match?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Native algebraic data type representation in the OCaml SX evaluator. Replaces the dict-based shim that simulated ADT values via tagged dicts. - sx_types.ml: add AdtValue variant + adt_value record (av_type, av_ctor, av_fields). type_of returns the type name (e.g. "Maybe"); inspect renders as a constructor call (e.g. "(Just 42)" or "(Nothing)"). - sx_runtime.ml: get_val handles AdtValue with :_adt/:_type/:_ctor/:_fields keys for back-compat with spec-level match-pattern code. - sx_primitives.ml: dict? returns true for AdtValue (so existing match dispatch keeps working); new adt? predicate distinguishes ADT values. - sx_ref.ml: sf_define_type now constructs AdtValue instead of Dict. Predicates match on AdtValue and compare the type name (Name?) or the constructor name (Ctor?); accessors (Ctor-field) match on AdtValue and bounds-check the field index (the constructor name is not checked). - spec/tests/test-adt.sx: 3 new tests covering type-of, adt?, and inspect. Tests: 4532 passed (was 4529 + 3 new), 1339 failed (unchanged baseline). All 43 ADT tests pass on the native representation. Co-Authored-By: Claude Opus 4.7 (1M context) --- hosts/ocaml/lib/sx_primitives.ml | 4 +- hosts/ocaml/lib/sx_ref.ml | 31 +++---- hosts/ocaml/lib/sx_runtime.ml | 7 ++ hosts/ocaml/lib/sx_types.ml | 17 ++++ plans/sx-improvements.md | 2 +- spec/tests/test-adt.sx | 144 ++++++++++++++++++++++++------- 6 files changed, 152 insertions(+), 53 deletions(-) diff --git a/hosts/ocaml/lib/sx_primitives.ml b/hosts/ocaml/lib/sx_primitives.ml index 603248d8..a06292b7 100644 --- a/hosts/ocaml/lib/sx_primitives.ml +++ b/hosts/ocaml/lib/sx_primitives.ml @@ -666,7 +666,9 @@ let () = register "list?" 
(fun args -> match args with [List _] | [ListRef _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "list?: 1 arg")); register "dict?" (fun args -> - match args with [Dict _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "dict?: 1 arg")); + match args with [Dict _] -> Bool true | [AdtValue _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "dict?: 1 arg")); + register "adt?" (fun args -> + match args with [AdtValue _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "adt?: 1 arg")); register "symbol?" (fun args -> match args with [Symbol _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "symbol?: 1 arg")); register "keyword?" (fun args -> diff --git a/hosts/ocaml/lib/sx_ref.ml b/hosts/ocaml/lib/sx_ref.ml index 545ddea7..c8d8d4ff 100644 --- a/hosts/ocaml/lib/sx_ref.ml +++ b/hosts/ocaml/lib/sx_ref.ml @@ -1054,8 +1054,7 @@ let sf_define_type args env_val = (match pargs with | [v] -> (match v with - | Dict d -> Bool (Hashtbl.mem d "_adt" && - (match Hashtbl.find_opt d "_type" with Some (String t) -> t = type_name | _ -> false)) + | AdtValue a -> Bool (a.av_type = type_name) | _ -> Bool false) | _ -> Bool false))); List.iter (fun spec -> @@ -1069,21 +1068,18 @@ let sf_define_type args env_val = if List.length ctor_args <> arity then raise (Eval_error (Printf.sprintf "%s: expected %d args, got %d" cn arity (List.length ctor_args))) - else begin - let d = Hashtbl.create 4 in - Hashtbl.replace d "_adt" (Bool true); - Hashtbl.replace d "_type" (String type_name); - Hashtbl.replace d "_ctor" (String cn); - Hashtbl.replace d "_fields" (List ctor_args); - Dict d - end)); + else + AdtValue { + av_type = type_name; + av_ctor = cn; + av_fields = Array.of_list ctor_args; + })); env_bind_v (cn ^ "?") (NativeFn (cn ^ "?", fun pargs -> (match pargs with | [v] -> (match v with - | Dict d -> Bool (Hashtbl.mem d "_adt" && - (match Hashtbl.find_opt d "_ctor" with Some (String c) -> c = cn | _ -> false)) + | AdtValue a -> Bool (a.av_ctor 
= cn) | _ -> Bool false) | _ -> Bool false))); List.iteri (fun idx fname -> @@ -1092,13 +1088,10 @@ let sf_define_type args env_val = (match pargs with | [v] -> (match v with - | Dict d -> - (match Hashtbl.find_opt d "_fields" with - | Some (List fs) -> - if idx < List.length fs then List.nth fs idx - else raise (Eval_error (cn ^ "-" ^ fname ^ ": index out of bounds")) - | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not an ADT"))) - | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not a dict"))) + | AdtValue a -> + if idx < Array.length a.av_fields then a.av_fields.(idx) + else raise (Eval_error (cn ^ "-" ^ fname ^ ": index out of bounds")) + | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not an ADT"))) | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": expected 1 arg"))))) ) field_names | _ -> ()) diff --git a/hosts/ocaml/lib/sx_runtime.ml b/hosts/ocaml/lib/sx_runtime.ml index 7eb03ad6..fa82df5b 100644 --- a/hosts/ocaml/lib/sx_runtime.ml +++ b/hosts/ocaml/lib/sx_runtime.ml @@ -209,6 +209,13 @@ let get_val container key = | _ -> Nil) | Dict d, String k -> dict_get d k | Dict d, Keyword k -> dict_get d k + | AdtValue a, String k | AdtValue a, Keyword k -> + (match k with + | "_adt" -> Bool true + | "_type" -> String a.av_type + | "_ctor" -> String a.av_ctor + | "_fields" -> List (Array.to_list a.av_fields) + | _ -> Nil) | (List l | ListRef { contents = l }), Number n -> (try List.nth l (int_of_float n) with _ -> Nil) | (List l | ListRef { contents = l }), Integer n -> diff --git a/hosts/ocaml/lib/sx_types.ml b/hosts/ocaml/lib/sx_types.ml index 490ce093..a4efa9e6 100644 --- a/hosts/ocaml/lib/sx_types.ml +++ b/hosts/ocaml/lib/sx_types.ml @@ -82,6 +82,16 @@ and value = | SxSet of (string, value) Hashtbl.t (** Mutable set keyed by inspect(value). *) | SxRegexp of string * string * Re.re (** Regexp: source, flags, compiled. *) | SxBytevector of bytes (** Mutable bytevector — R7RS bytevector type. 
*) + | AdtValue of adt_value (** Native algebraic data type instance — opaque sum type. *) + +(** Algebraic data type instance — produced by [define-type] constructors. + [av_type] is the type name (e.g. "Maybe"), [av_ctor] is the constructor + name (e.g. "Just"), [av_fields] are the positional field values. *) +and adt_value = { + av_type : string; + av_ctor : string; + av_fields : value array; +} (** String input port: source string + mutable cursor position. *) and sx_port_kind = @@ -520,6 +530,7 @@ let type_of = function | SxSet _ -> "set" | SxRegexp _ -> "regexp" | SxBytevector _ -> "bytevector" + | AdtValue a -> a.av_type let is_nil = function Nil -> true | _ -> false let is_lambda = function Lambda _ -> true | _ -> false @@ -885,3 +896,9 @@ let rec inspect = function | SxSet ht -> Printf.sprintf "" (Hashtbl.length ht) | SxRegexp (src, flags, _) -> Printf.sprintf "#/%s/%s" src flags | SxBytevector b -> Printf.sprintf "#u8(%s)" (String.concat " " (List.init (Bytes.length b) (fun i -> string_of_int (Char.code (Bytes.get b i))))) + | AdtValue a -> + if Array.length a.av_fields = 0 then + Printf.sprintf "(%s)" a.av_ctor + else + let parts = Array.to_list (Array.map inspect a.av_fields) in + Printf.sprintf "(%s %s)" a.av_ctor (String.concat " " parts) diff --git a/plans/sx-improvements.md b/plans/sx-improvements.md index ed9fea90..8cc20f60 100644 --- a/plans/sx-improvements.md +++ b/plans/sx-improvements.md @@ -189,7 +189,7 @@ these when operands are known numbers/lists. 
| 2 — letrec+resume | [x] | e80e655b | | 3 — tokenizer :end/:line | [x] | 023bc2d8 | | 4 — parser spans complete | [x] | b7ad5152 (subsumed by 023bc2d8) | -| 5 — OCaml AdtValue + define-type + match | [ ] | — | +| 5 — OCaml AdtValue + define-type + match | [x] | (pending) | | 6 — JS AdtValue + define-type + match | [ ] | — | | 7 — nested patterns | [ ] | — | | 8 — exhaustiveness warnings | [ ] | — | diff --git a/spec/tests/test-adt.sx b/spec/tests/test-adt.sx index bceb0f7a..2f6ab479 100644 --- a/spec/tests/test-adt.sx +++ b/spec/tests/test-adt.sx @@ -151,9 +151,15 @@ "match dispatches on first matching constructor" (do (define-type Color (Red) (Green) (Blue)) - (assert= "red" (match (Red) ((Red) "red") ((Green) "green") ((Blue) "blue"))) - (assert= "green" (match (Green) ((Red) "red") ((Green) "green") ((Blue) "blue"))) - (assert= "blue" (match (Blue) ((Red) "red") ((Green) "green") ((Blue) "blue"))))) + (assert= + "red" + (match (Red) ((Red) "red") ((Green) "green") ((Blue) "blue"))) + (assert= + "green" + (match (Green) ((Red) "red") ((Green) "green") ((Blue) "blue"))) + (assert= + "blue" + (match (Blue) ((Red) "red") ((Green) "green") ((Blue) "blue"))))) (deftest "match binds field to variable" (do @@ -170,13 +176,16 @@ "match multi-field constructor binds all fields" (do (define-type Vec2 (V2 x y)) - (let ((v (V2 3 4))) + (let + ((v (V2 3 4))) (assert= 7 (match v ((V2 a b) (+ a b))))))) (deftest "match with else clause" (do (define-type Opt2 (Some2 val) (None2)) - (assert= 10 (match (Some2 10) ((Some2 v) v) (else 0))) + (assert= + 10 + (match (Some2 10) ((Some2 v) v) (else 0))) (assert= 0 (match (None2) ((Some2 v) v) (else 0))))) (deftest "match else catches non-adt values" @@ -187,48 +196,69 @@ "match returns body expression value" (do (define-type Num (Num-of n)) - (assert= 100 (match (Num-of 10) ((Num-of n) (* n n)))))) + (assert= + 100 + (match (Num-of 10) ((Num-of n) (* n n)))))) (deftest "match second arm fires when first does not match" (do (define-type 
Either (Left val) (Right val)) - (assert= "left-1" (match (Left 1) ((Left v) (str "left-" v)) ((Right v) (str "right-" v)))) - (assert= "right-2" (match (Right 2) ((Left v) (str "left-" v)) ((Right v) (str "right-" v)))))) + (assert= + "left-1" + (match + (Left 1) + ((Left v) (str "left-" v)) + ((Right v) (str "right-" v)))) + (assert= + "right-2" + (match + (Right 2) + ((Left v) (str "left-" v)) + ((Right v) (str "right-" v)))))) (deftest "match wildcard _ in constructor pattern" (do (define-type Pair3 (Pair3-of a b)) - (assert= 5 (match (Pair3-of 5 99) ((Pair3-of x _) x))) - (assert= 99 (match (Pair3-of 5 99) ((Pair3-of _ y) y))))) + (assert= + 5 + (match (Pair3-of 5 99) ((Pair3-of x _) x))) + (assert= + 99 + (match (Pair3-of 5 99) ((Pair3-of _ y) y))))) (deftest "match nested adt constructor pattern" (do (define-type Tree2 (Leaf2) (Node2 left val right)) - (let ((t (Node2 (Leaf2) 7 (Leaf2)))) + (let + ((t (Node2 (Leaf2) 7 (Leaf2)))) (assert= 7 (match t ((Node2 _ v _) v))) (assert= true (match t ((Node2 (Leaf2) _ _) true) (else false)))))) (deftest "match literal pattern" (do - (assert= "zero" (match 0 (0 "zero") (else "nonzero"))) + (assert= + "zero" + (match 0 (0 "zero") (else "nonzero"))) (assert= "hello" (match "hello" ("hello" "hello") (else "other"))))) (deftest "match symbol binding pattern" - (do - (assert= 42 (match 42 (x x))))) + (do (assert= 42 (match 42 (x x))))) (deftest "match no matching clause raises error" (do (define-type AB (A-val) (B-val)) - (let ((ok false)) - (guard (exn (else (set! ok true))) + (let + ((ok false)) + (guard + (exn (else (set! 
ok true))) (match (A-val) ((B-val) "b"))) (assert ok)))) (deftest "match result used in further computation" (do (define-type Num2 (N v)) - (assert= 30 + (assert= + 30 (+ (match (N 10) ((N v) v)) (match (N 20) ((N v) v)))))) @@ -238,41 +268,91 @@ (define-type Tag (Tagged label value)) (define get-label (fn (t) (match t ((Tagged lbl _) lbl)))) (define get-value (fn (t) (match t ((Tagged _ val) val)))) - (let ((t (Tagged "name" 99))) + (let + ((t (Tagged "name" 99))) (assert= "name" (get-label t)) (assert= 99 (get-value t))))) (deftest "match three-field constructor" (do (define-type Triple2 (T3 a b c)) - (assert= 6 (match (T3 1 2 3) ((T3 a b c) (+ a b c)))))) + (assert= + 6 + (match + (T3 1 2 3) + ((T3 a b c) (+ a b c)))))) (deftest "match clauses tried in order" (do (define-type Expr2 (Lit n) (Add l r) (Mul l r)) - (define eval-expr2 (fn (e) - (match e - ((Lit n) n) - ((Add l r) (+ (eval-expr2 l) (eval-expr2 r))) - ((Mul l r) (* (eval-expr2 l) (eval-expr2 r)))))) - (assert= 7 (eval-expr2 (Add (Lit 3) (Lit 4)))) - (assert= 12 (eval-expr2 (Mul (Lit 3) (Lit 4)))) - (assert= 11 (eval-expr2 (Add (Lit 2) (Mul (Lit 3) (Lit 3))))))) + (define + eval-expr2 + (fn + (e) + (match + e + ((Lit n) n) + ((Add l r) (+ (eval-expr2 l) (eval-expr2 r))) + ((Mul l r) (* (eval-expr2 l) (eval-expr2 r)))))) + (assert= + 7 + (eval-expr2 (Add (Lit 3) (Lit 4)))) + (assert= + 12 + (eval-expr2 (Mul (Lit 3) (Lit 4)))) + (assert= + 11 + (eval-expr2 + (Add (Lit 2) (Mul (Lit 3) (Lit 3))))))) (deftest "match else binding captures value" (do (define-type Coin2 (Heads2) (Tails2)) - (assert= "Tails2" (match (Tails2) ((Heads2) "Heads2") (x (get x :_ctor)))))) + (assert= + "Tails2" + (match (Tails2) ((Heads2) "Heads2") (x (get x :_ctor)))))) (deftest "match on adt with string field" (do (define-type Msg (Hello name) (Bye name)) - (assert= "Hello, Alice" (match (Hello "Alice") ((Hello n) (str "Hello, " n)) ((Bye n) (str "Bye, " n)))) - (assert= "Bye, Bob" (match (Bye "Bob") ((Hello n) (str "Hello, " n)) 
((Bye n) (str "Bye, " n)))))) + (assert= + "Hello, Alice" + (match + (Hello "Alice") + ((Hello n) (str "Hello, " n)) + ((Bye n) (str "Bye, " n)))) + (assert= + "Bye, Bob" + (match + (Bye "Bob") + ((Hello n) (str "Hello, " n)) + ((Bye n) (str "Bye, " n)))))) + (deftest + "type-of returns adt type name" + (do + (define-type Maybe2 (Just2 v) (Nothing2)) + (assert= "Maybe2" (type-of (Just2 7))) + (assert= "Maybe2" (type-of (Nothing2))))) + (deftest + "adt? predicate distinguishes adt values" + (do + (define-type Box3 (Boxed3 x)) + (assert= true (adt? (Boxed3 1))) + (assert= false (adt? 1)) + (assert= false (adt? "str")) + (assert= false (adt? (list 1 2))) + (assert= false (adt? {:a 1})))) + (deftest + "inspect renders adt as constructor call" + (do + (define-type Pt (Pt-of x y) (Origin)) + (assert= "(Pt-of 3 4)" (inspect (Pt-of 3 4))) + (assert= "(Origin)" (inspect (Origin))))) (deftest "match nested pattern with variable binding" (do (define-type Box2 (Box2-of v)) (define-type Inner (Inner-of n)) - (assert= 5 (match (Box2-of (Inner-of 5)) ((Box2-of (Inner-of n)) n))))) -) + (assert= + 5 + (match (Box2-of (Inner-of 5)) ((Box2-of (Inner-of n)) n))))))