sx: step 5 — OCaml AdtValue + define-type + match

Native algebraic data type representation in the OCaml SX evaluator.
Replaces the dict-based shim that simulated ADT values via tagged dicts.

- sx_types.ml: add AdtValue variant + adt_value record (av_type, av_ctor,
  av_fields). type_of returns the type name (e.g. "Maybe"); inspect renders
  as a constructor call (e.g. "(Just 42)" or "(Nothing)").
- sx_runtime.ml: get_val handles AdtValue with :_adt/:_type/:_ctor/:_fields
  keys for back-compat with spec-level match-pattern code.
- sx_primitives.ml: dict? returns true for AdtValue (so existing match
  dispatch keeps working); new adt? predicate distinguishes ADT values.
- sx_ref.ml: sf_define_type now constructs AdtValue instead of Dict.
  Predicates (Name?, Ctor?) and accessors (Ctor-field) match on AdtValue
  with proper type/ctor name and field index checks.
- spec/tests/test-adt.sx: 3 new tests covering type-of, adt?, and inspect.

Tests: 4532 passed (was 4529 + 3 new), 1339 failed (unchanged baseline).
All 43 ADT tests pass on the native representation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-06 22:54:33 +00:00
parent b19f2017d0
commit 1f49242ae3
6 changed files with 152 additions and 53 deletions

View File

@@ -666,7 +666,9 @@ let () =
register "list?" (fun args -> register "list?" (fun args ->
match args with [List _] | [ListRef _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "list?: 1 arg")); match args with [List _] | [ListRef _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "list?: 1 arg"));
register "dict?" (fun args -> register "dict?" (fun args ->
match args with [Dict _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "dict?: 1 arg")); match args with [Dict _] -> Bool true | [AdtValue _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "dict?: 1 arg"));
register "adt?" (fun args ->
match args with [AdtValue _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "adt?: 1 arg"));
register "symbol?" (fun args -> register "symbol?" (fun args ->
match args with [Symbol _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "symbol?: 1 arg")); match args with [Symbol _] -> Bool true | [_] -> Bool false | _ -> raise (Eval_error "symbol?: 1 arg"));
register "keyword?" (fun args -> register "keyword?" (fun args ->

View File

@@ -1054,8 +1054,7 @@ let sf_define_type args env_val =
(match pargs with (match pargs with
| [v] -> | [v] ->
(match v with (match v with
| Dict d -> Bool (Hashtbl.mem d "_adt" && | AdtValue a -> Bool (a.av_type = type_name)
(match Hashtbl.find_opt d "_type" with Some (String t) -> t = type_name | _ -> false))
| _ -> Bool false) | _ -> Bool false)
| _ -> Bool false))); | _ -> Bool false)));
List.iter (fun spec -> List.iter (fun spec ->
@@ -1069,21 +1068,18 @@ let sf_define_type args env_val =
if List.length ctor_args <> arity then if List.length ctor_args <> arity then
raise (Eval_error (Printf.sprintf "%s: expected %d args, got %d" raise (Eval_error (Printf.sprintf "%s: expected %d args, got %d"
cn arity (List.length ctor_args))) cn arity (List.length ctor_args)))
else begin else
let d = Hashtbl.create 4 in AdtValue {
Hashtbl.replace d "_adt" (Bool true); av_type = type_name;
Hashtbl.replace d "_type" (String type_name); av_ctor = cn;
Hashtbl.replace d "_ctor" (String cn); av_fields = Array.of_list ctor_args;
Hashtbl.replace d "_fields" (List ctor_args); }));
Dict d
end));
env_bind_v (cn ^ "?") env_bind_v (cn ^ "?")
(NativeFn (cn ^ "?", fun pargs -> (NativeFn (cn ^ "?", fun pargs ->
(match pargs with (match pargs with
| [v] -> | [v] ->
(match v with (match v with
| Dict d -> Bool (Hashtbl.mem d "_adt" && | AdtValue a -> Bool (a.av_ctor = cn)
(match Hashtbl.find_opt d "_ctor" with Some (String c) -> c = cn | _ -> false))
| _ -> Bool false) | _ -> Bool false)
| _ -> Bool false))); | _ -> Bool false)));
List.iteri (fun idx fname -> List.iteri (fun idx fname ->
@@ -1092,13 +1088,10 @@ let sf_define_type args env_val =
(match pargs with (match pargs with
| [v] -> | [v] ->
(match v with (match v with
| Dict d -> | AdtValue a ->
(match Hashtbl.find_opt d "_fields" with if idx < Array.length a.av_fields then a.av_fields.(idx)
| Some (List fs) -> else raise (Eval_error (cn ^ "-" ^ fname ^ ": index out of bounds"))
if idx < List.length fs then List.nth fs idx | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not an ADT")))
else raise (Eval_error (cn ^ "-" ^ fname ^ ": index out of bounds"))
| _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not an ADT")))
| _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": not a dict")))
| _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": expected 1 arg"))))) | _ -> raise (Eval_error (cn ^ "-" ^ fname ^ ": expected 1 arg")))))
) field_names ) field_names
| _ -> ()) | _ -> ())

View File

@@ -209,6 +209,13 @@ let get_val container key =
| _ -> Nil) | _ -> Nil)
| Dict d, String k -> dict_get d k | Dict d, String k -> dict_get d k
| Dict d, Keyword k -> dict_get d k | Dict d, Keyword k -> dict_get d k
| AdtValue a, String k | AdtValue a, Keyword k ->
(match k with
| "_adt" -> Bool true
| "_type" -> String a.av_type
| "_ctor" -> String a.av_ctor
| "_fields" -> List (Array.to_list a.av_fields)
| _ -> Nil)
| (List l | ListRef { contents = l }), Number n -> | (List l | ListRef { contents = l }), Number n ->
(try List.nth l (int_of_float n) with _ -> Nil) (try List.nth l (int_of_float n) with _ -> Nil)
| (List l | ListRef { contents = l }), Integer n -> | (List l | ListRef { contents = l }), Integer n ->

View File

@@ -82,6 +82,16 @@ and value =
| SxSet of (string, value) Hashtbl.t (** Mutable set keyed by inspect(value). *) | SxSet of (string, value) Hashtbl.t (** Mutable set keyed by inspect(value). *)
| SxRegexp of string * string * Re.re (** Regexp: source, flags, compiled. *) | SxRegexp of string * string * Re.re (** Regexp: source, flags, compiled. *)
| SxBytevector of bytes (** Mutable bytevector — R7RS bytevector type. *) | SxBytevector of bytes (** Mutable bytevector — R7RS bytevector type. *)
| AdtValue of adt_value (** Native algebraic data type instance — opaque sum type. *)
(** Algebraic data type instance — produced by [define-type] constructors.
[av_type] is the type name (e.g. "Maybe"), [av_ctor] is the constructor
name (e.g. "Just"), [av_fields] are the positional field values. *)
and adt_value = {
av_type : string;
av_ctor : string;
av_fields : value array;
}
(** String input port: source string + mutable cursor position. *) (** String input port: source string + mutable cursor position. *)
and sx_port_kind = and sx_port_kind =
@@ -520,6 +530,7 @@ let type_of = function
| SxSet _ -> "set" | SxSet _ -> "set"
| SxRegexp _ -> "regexp" | SxRegexp _ -> "regexp"
| SxBytevector _ -> "bytevector" | SxBytevector _ -> "bytevector"
| AdtValue a -> a.av_type
let is_nil = function Nil -> true | _ -> false let is_nil = function Nil -> true | _ -> false
let is_lambda = function Lambda _ -> true | _ -> false let is_lambda = function Lambda _ -> true | _ -> false
@@ -885,3 +896,9 @@ let rec inspect = function
| SxSet ht -> Printf.sprintf "<set:%d>" (Hashtbl.length ht) | SxSet ht -> Printf.sprintf "<set:%d>" (Hashtbl.length ht)
| SxRegexp (src, flags, _) -> Printf.sprintf "#/%s/%s" src flags | SxRegexp (src, flags, _) -> Printf.sprintf "#/%s/%s" src flags
| SxBytevector b -> Printf.sprintf "#u8(%s)" (String.concat " " (List.init (Bytes.length b) (fun i -> string_of_int (Char.code (Bytes.get b i))))) | SxBytevector b -> Printf.sprintf "#u8(%s)" (String.concat " " (List.init (Bytes.length b) (fun i -> string_of_int (Char.code (Bytes.get b i)))))
| AdtValue a ->
if Array.length a.av_fields = 0 then
Printf.sprintf "(%s)" a.av_ctor
else
let parts = Array.to_list (Array.map inspect a.av_fields) in
Printf.sprintf "(%s %s)" a.av_ctor (String.concat " " parts)

View File

@@ -189,7 +189,7 @@ these when operands are known numbers/lists.
| 2 — letrec+resume | [x] | e80e655b | | 2 — letrec+resume | [x] | e80e655b |
| 3 — tokenizer :end/:line | [x] | 023bc2d8 | | 3 — tokenizer :end/:line | [x] | 023bc2d8 |
| 4 — parser spans complete | [x] | b7ad5152 (subsumed by 023bc2d8) | | 4 — parser spans complete | [x] | b7ad5152 (subsumed by 023bc2d8) |
| 5 — OCaml AdtValue + define-type + match | [ ] | | | 5 — OCaml AdtValue + define-type + match | [x] | (pending) |
| 6 — JS AdtValue + define-type + match | [ ] | — | | 6 — JS AdtValue + define-type + match | [ ] | — |
| 7 — nested patterns | [ ] | — | | 7 — nested patterns | [ ] | — |
| 8 — exhaustiveness warnings | [ ] | — | | 8 — exhaustiveness warnings | [ ] | — |

View File

@@ -151,9 +151,15 @@
"match dispatches on first matching constructor" "match dispatches on first matching constructor"
(do (do
(define-type Color (Red) (Green) (Blue)) (define-type Color (Red) (Green) (Blue))
(assert= "red" (match (Red) ((Red) "red") ((Green) "green") ((Blue) "blue"))) (assert=
(assert= "green" (match (Green) ((Red) "red") ((Green) "green") ((Blue) "blue"))) "red"
(assert= "blue" (match (Blue) ((Red) "red") ((Green) "green") ((Blue) "blue"))))) (match (Red) ((Red) "red") ((Green) "green") ((Blue) "blue")))
(assert=
"green"
(match (Green) ((Red) "red") ((Green) "green") ((Blue) "blue")))
(assert=
"blue"
(match (Blue) ((Red) "red") ((Green) "green") ((Blue) "blue")))))
(deftest (deftest
"match binds field to variable" "match binds field to variable"
(do (do
@@ -170,13 +176,16 @@
"match multi-field constructor binds all fields" "match multi-field constructor binds all fields"
(do (do
(define-type Vec2 (V2 x y)) (define-type Vec2 (V2 x y))
(let ((v (V2 3 4))) (let
((v (V2 3 4)))
(assert= 7 (match v ((V2 a b) (+ a b))))))) (assert= 7 (match v ((V2 a b) (+ a b)))))))
(deftest (deftest
"match with else clause" "match with else clause"
(do (do
(define-type Opt2 (Some2 val) (None2)) (define-type Opt2 (Some2 val) (None2))
(assert= 10 (match (Some2 10) ((Some2 v) v) (else 0))) (assert=
10
(match (Some2 10) ((Some2 v) v) (else 0)))
(assert= 0 (match (None2) ((Some2 v) v) (else 0))))) (assert= 0 (match (None2) ((Some2 v) v) (else 0)))))
(deftest (deftest
"match else catches non-adt values" "match else catches non-adt values"
@@ -187,48 +196,69 @@
"match returns body expression value" "match returns body expression value"
(do (do
(define-type Num (Num-of n)) (define-type Num (Num-of n))
(assert= 100 (match (Num-of 10) ((Num-of n) (* n n)))))) (assert=
100
(match (Num-of 10) ((Num-of n) (* n n))))))
(deftest (deftest
"match second arm fires when first does not match" "match second arm fires when first does not match"
(do (do
(define-type Either (Left val) (Right val)) (define-type Either (Left val) (Right val))
(assert= "left-1" (match (Left 1) ((Left v) (str "left-" v)) ((Right v) (str "right-" v)))) (assert=
(assert= "right-2" (match (Right 2) ((Left v) (str "left-" v)) ((Right v) (str "right-" v)))))) "left-1"
(match
(Left 1)
((Left v) (str "left-" v))
((Right v) (str "right-" v))))
(assert=
"right-2"
(match
(Right 2)
((Left v) (str "left-" v))
((Right v) (str "right-" v))))))
(deftest (deftest
"match wildcard _ in constructor pattern" "match wildcard _ in constructor pattern"
(do (do
(define-type Pair3 (Pair3-of a b)) (define-type Pair3 (Pair3-of a b))
(assert= 5 (match (Pair3-of 5 99) ((Pair3-of x _) x))) (assert=
(assert= 99 (match (Pair3-of 5 99) ((Pair3-of _ y) y))))) 5
(match (Pair3-of 5 99) ((Pair3-of x _) x)))
(assert=
99
(match (Pair3-of 5 99) ((Pair3-of _ y) y)))))
(deftest (deftest
"match nested adt constructor pattern" "match nested adt constructor pattern"
(do (do
(define-type Tree2 (Leaf2) (Node2 left val right)) (define-type Tree2 (Leaf2) (Node2 left val right))
(let ((t (Node2 (Leaf2) 7 (Leaf2)))) (let
((t (Node2 (Leaf2) 7 (Leaf2))))
(assert= 7 (match t ((Node2 _ v _) v))) (assert= 7 (match t ((Node2 _ v _) v)))
(assert= true (match t ((Node2 (Leaf2) _ _) true) (else false)))))) (assert= true (match t ((Node2 (Leaf2) _ _) true) (else false))))))
(deftest (deftest
"match literal pattern" "match literal pattern"
(do (do
(assert= "zero" (match 0 (0 "zero") (else "nonzero"))) (assert=
"zero"
(match 0 (0 "zero") (else "nonzero")))
(assert= "hello" (match "hello" ("hello" "hello") (else "other"))))) (assert= "hello" (match "hello" ("hello" "hello") (else "other")))))
(deftest (deftest
"match symbol binding pattern" "match symbol binding pattern"
(do (do (assert= 42 (match 42 (x x)))))
(assert= 42 (match 42 (x x)))))
(deftest (deftest
"match no matching clause raises error" "match no matching clause raises error"
(do (do
(define-type AB (A-val) (B-val)) (define-type AB (A-val) (B-val))
(let ((ok false)) (let
(guard (exn (else (set! ok true))) ((ok false))
(guard
(exn (else (set! ok true)))
(match (A-val) ((B-val) "b"))) (match (A-val) ((B-val) "b")))
(assert ok)))) (assert ok))))
(deftest (deftest
"match result used in further computation" "match result used in further computation"
(do (do
(define-type Num2 (N v)) (define-type Num2 (N v))
(assert= 30 (assert=
30
(+ (+
(match (N 10) ((N v) v)) (match (N 10) ((N v) v))
(match (N 20) ((N v) v)))))) (match (N 20) ((N v) v))))))
@@ -238,41 +268,91 @@
(define-type Tag (Tagged label value)) (define-type Tag (Tagged label value))
(define get-label (fn (t) (match t ((Tagged lbl _) lbl)))) (define get-label (fn (t) (match t ((Tagged lbl _) lbl))))
(define get-value (fn (t) (match t ((Tagged _ val) val)))) (define get-value (fn (t) (match t ((Tagged _ val) val))))
(let ((t (Tagged "name" 99))) (let
((t (Tagged "name" 99)))
(assert= "name" (get-label t)) (assert= "name" (get-label t))
(assert= 99 (get-value t))))) (assert= 99 (get-value t)))))
(deftest (deftest
"match three-field constructor" "match three-field constructor"
(do (do
(define-type Triple2 (T3 a b c)) (define-type Triple2 (T3 a b c))
(assert= 6 (match (T3 1 2 3) ((T3 a b c) (+ a b c)))))) (assert=
6
(match
(T3 1 2 3)
((T3 a b c) (+ a b c))))))
(deftest (deftest
"match clauses tried in order" "match clauses tried in order"
(do (do
(define-type Expr2 (Lit n) (Add l r) (Mul l r)) (define-type Expr2 (Lit n) (Add l r) (Mul l r))
(define eval-expr2 (fn (e) (define
(match e eval-expr2
((Lit n) n) (fn
((Add l r) (+ (eval-expr2 l) (eval-expr2 r))) (e)
((Mul l r) (* (eval-expr2 l) (eval-expr2 r)))))) (match
(assert= 7 (eval-expr2 (Add (Lit 3) (Lit 4)))) e
(assert= 12 (eval-expr2 (Mul (Lit 3) (Lit 4)))) ((Lit n) n)
(assert= 11 (eval-expr2 (Add (Lit 2) (Mul (Lit 3) (Lit 3))))))) ((Add l r) (+ (eval-expr2 l) (eval-expr2 r)))
((Mul l r) (* (eval-expr2 l) (eval-expr2 r))))))
(assert=
7
(eval-expr2 (Add (Lit 3) (Lit 4))))
(assert=
12
(eval-expr2 (Mul (Lit 3) (Lit 4))))
(assert=
11
(eval-expr2
(Add (Lit 2) (Mul (Lit 3) (Lit 3)))))))
(deftest (deftest
"match else binding captures value" "match else binding captures value"
(do (do
(define-type Coin2 (Heads2) (Tails2)) (define-type Coin2 (Heads2) (Tails2))
(assert= "Tails2" (match (Tails2) ((Heads2) "Heads2") (x (get x :_ctor)))))) (assert=
"Tails2"
(match (Tails2) ((Heads2) "Heads2") (x (get x :_ctor))))))
(deftest (deftest
"match on adt with string field" "match on adt with string field"
(do (do
(define-type Msg (Hello name) (Bye name)) (define-type Msg (Hello name) (Bye name))
(assert= "Hello, Alice" (match (Hello "Alice") ((Hello n) (str "Hello, " n)) ((Bye n) (str "Bye, " n)))) (assert=
(assert= "Bye, Bob" (match (Bye "Bob") ((Hello n) (str "Hello, " n)) ((Bye n) (str "Bye, " n)))))) "Hello, Alice"
(match
(Hello "Alice")
((Hello n) (str "Hello, " n))
((Bye n) (str "Bye, " n))))
(assert=
"Bye, Bob"
(match
(Bye "Bob")
((Hello n) (str "Hello, " n))
((Bye n) (str "Bye, " n))))))
(deftest
"type-of returns adt type name"
(do
(define-type Maybe2 (Just2 v) (Nothing2))
(assert= "Maybe2" (type-of (Just2 7)))
(assert= "Maybe2" (type-of (Nothing2)))))
(deftest
"adt? predicate distinguishes adt values"
(do
(define-type Box3 (Boxed3 x))
(assert= true (adt? (Boxed3 1)))
(assert= false (adt? 1))
(assert= false (adt? "str"))
(assert= false (adt? (list 1 2)))
(assert= false (adt? {:a 1}))))
(deftest
"inspect renders adt as constructor call"
(do
(define-type Pt (Pt-of x y) (Origin))
(assert= "(Pt-of 3 4)" (inspect (Pt-of 3 4)))
(assert= "(Origin)" (inspect (Origin)))))
(deftest (deftest
"match nested pattern with variable binding" "match nested pattern with variable binding"
(do (do
(define-type Box2 (Box2-of v)) (define-type Box2 (Box2-of v))
(define-type Inner (Inner-of n)) (define-type Inner (Inner-of n))
(assert= 5 (match (Box2-of (Inner-of 5)) ((Box2-of (Inner-of n)) n))))) (assert=
) 5
(match (Box2-of (Inner-of 5)) ((Box2-of (Inner-of n)) n))))))