diff --git a/lib/ocaml/infer.sx b/lib/ocaml/infer.sx index a5b83c18..212e80ee 100644 --- a/lib/ocaml/infer.sx +++ b/lib/ocaml/infer.sx @@ -171,6 +171,113 @@ :type (ocaml-hm-tuple (map (fn (t) (hm-apply subst t)) types))})))) +;; Pattern type inference. Returns {:type T :env ENV2 :subst S} where +;; ENV2 is the original env extended with any names the pattern binds. +;; Constructor patterns aren't supported here yet (need a type-def +;; registry) — :pcon falls through to a fresh tv so they don't break +;; inference of mixed clauses. +(define ocaml-infer-pat + (fn (pat env counter) + (let ((tag (nth pat 0))) + (cond + ((= tag "pwild") + (let ((tv (hm-fresh-tv counter))) + {:type tv :env env :subst {}})) + ((= tag "pvar") + (let ((nm (nth pat 1)) (tv (hm-fresh-tv counter))) + {:type tv :env (assoc env nm (hm-monotype tv)) :subst {}})) + ((= tag "plit") + (let ((r (ocaml-infer (nth pat 1) env counter))) + {:type (get r :type) :env env :subst (get r :subst)})) + ((= tag "pcons") + (let ((rh (ocaml-infer-pat (nth pat 1) env counter))) + (let ((rt (ocaml-infer-pat (nth pat 2) (get rh :env) counter))) + (let ((s (ocaml-hm-unify + (ocaml-hm-list (get rh :type)) + (get rt :type) + (hm-compose (get rt :subst) (get rh :subst))))) + {:type (hm-apply s (ocaml-hm-list (get rh :type))) + :env (get rt :env) + :subst s})))) + ((= tag "plist") + (let ((items (rest pat)) (tv (hm-fresh-tv counter)) (env-cur env) (subst {})) + (begin + (define loop + (fn (xs) + (when (not (= xs (list))) + (let ((rp (ocaml-infer-pat (first xs) env-cur counter))) + (let ((s (ocaml-hm-unify + (hm-apply (get rp :subst) tv) + (get rp :type) + (hm-compose (get rp :subst) subst)))) + (begin + (set! subst s) + (set! env-cur (get rp :env)) + (loop (rest xs)))))))) + (loop items) + {:type (hm-apply subst (ocaml-hm-list tv)) + :env env-cur + :subst subst}))) + ((= tag "ptuple") + (let ((items (rest pat)) (env-cur env) (subst {}) (types (list))) + (begin + (define loop + (fn (xs) + (when (not (= xs (list))) + (let ((rp (ocaml-infer-pat (first xs) env-cur counter))) + (begin + (set! subst (hm-compose (get rp :subst) subst)) + (append! types (get rp :type)) + (set! env-cur (get rp :env)) + (loop (rest xs))))))) + (loop items) + {:type (ocaml-hm-tuple + (map (fn (t) (hm-apply subst t)) types)) + :env env-cur + :subst subst}))) + ((= tag "pas") + (let ((rp (ocaml-infer-pat (nth pat 1) env counter))) + (let ((alias (nth pat 2))) + {:type (get rp :type) + :env (assoc (get rp :env) alias (hm-monotype (get rp :type))) + :subst (get rp :subst)}))) + (else + ;; :pcon and others — fall through to a fresh tv (sound but loose). + (let ((tv (hm-fresh-tv counter))) + {:type tv :env env :subst {}})))))) + +(define ocaml-infer-match + (fn (scrut clauses env counter) + (let ((rs (ocaml-infer scrut env counter))) + (let ((s (get rs :subst)) (st (get rs :type)) (result-tv (hm-fresh-tv counter))) + (let ((subst s)) + (begin + (define loop + (fn (cs) + (when (not (= cs (list))) + (let ((clause (first cs))) + (let ((ctag (nth clause 0))) + (let ((p (nth clause 1)) + (body (cond + ((= ctag "case") (nth clause 2)) + (else (nth clause 3))))) + (let ((rp (ocaml-infer-pat p (hm-apply-env subst env) counter))) + (let ((s1 (ocaml-hm-unify + (hm-apply (get rp :subst) st) + (get rp :type) + (hm-compose (get rp :subst) subst)))) + (let ((rb (ocaml-infer body + (hm-apply-env s1 (get rp :env)) counter))) + (let ((s2 (ocaml-hm-unify + (hm-apply (get rb :subst) result-tv) + (get rb :type) + (hm-compose (get rb :subst) s1)))) + (begin + (set! subst s2) + (loop (rest cs))))))))))))) + (loop clauses) + {:subst subst :type (hm-apply subst result-tv)})))))) + (define ocaml-infer-list (fn (items env counter) (cond @@ -214,6 +321,7 @@ (nth expr 3) env counter)) ((= tag "tuple") (ocaml-infer-tuple (rest expr) env counter)) ((= tag "list") (ocaml-infer-list (rest expr) env counter)) + ((= tag "match") (ocaml-infer-match (nth expr 1) (nth expr 2) env counter)) ((= tag "neg") (let ((r (ocaml-infer (nth expr 1) env counter))) (let ((s (get r :subst)) (t (get r :type))) diff --git a/lib/ocaml/test.sh b/lib/ocaml/test.sh index 03076b0b..552b33c9 100755 --- a/lib/ocaml/test.sh +++ b/lib/ocaml/test.sh @@ -844,6 +844,18 @@ cat > "$TMPFILE" << 'EPOCHS' (epoch 1706) (eval "(ocaml-run \"List.sort compare [\\\"b\\\"; \\\"a\\\"; \\\"c\\\"]\")") +;; ── HM pattern-match inference ───────────────────────────────── +(epoch 1800) +(eval "(ocaml-type-of \"match 1 with | n -> n + 1\")") +(epoch 1801) +(eval "(ocaml-type-of \"match [1;2] with | [] -> 0 | h :: t -> h\")") +(epoch 1802) +(eval "(ocaml-type-of \"match (1, 2) with | (a, b) -> a + b\")") +(epoch 1803) +(eval "(ocaml-type-of \"fun x -> match x with | 0 -> 0 | n -> n + 1\")") +(epoch 1804) +(eval "(ocaml-type-of \"fun lst -> match lst with | [] -> 0 | h :: _ -> h\")") + EPOCHS OUTPUT=$(timeout 180 "$SX_SERVER" < "$TMPFILE" 2>/dev/null) @@ -1335,6 +1347,13 @@ check 1704 "List.sort descending" '(4 3 1)' check 1705 "List.sort empty" '()' check 1706 "List.sort strings" '("a" "b" "c")' +# ── HM match inference ────────────────────────────────────────── +check 1800 "match int" '"Int"' +check 1801 "match list" '"Int"' +check 1802 "match tuple" '"Int"' +check 1803 "fn match int -> int" '"Int -> Int"' +check 1804 "fn list -> elem" '"Int list -> Int"' + TOTAL=$((PASS + FAIL)) if [ $FAIL -eq 0 ]; then echo "ok $PASS/$TOTAL OCaml-on-SX tests passed" diff --git a/plans/ocaml-on-sx.md b/plans/ocaml-on-sx.md index 5ef6223c..170c85e8 100644 --- a/plans/ocaml-on-sx.md +++ b/plans/ocaml-on-sx.md @@ -365,6 +365,14 @@ the "mother tongue" closure: OCaml → SX → OCaml. This means: _Newest first._ +- 2026-05-08 Phase 5 — HM pattern-matching inference (+5 tests, 344 + total). `ocaml-infer-pat` covers wild, var, lit, cons, list, tuple, + as. `ocaml-infer-match` unifies each clause's pattern type with the + scrutinee, runs the body in the env extended with pattern-bound vars, + and unifies all body types via a fresh result tv. Examples: + `fun lst -> match lst with | [] -> 0 | h :: _ -> h : Int list -> Int`. + Constructor patterns fall through to a fresh tv for now (need a ctor + type registry from `type` decls — pending). - 2026-05-08 Phase 6 — `List.sort` + polymorphic `compare` (+7 tests, 339 total). `compare` is a host primitive that returns -1/0/1 like Stdlib.compare, defers to host SX `<`/`>`. `List.sort` is implemented