Files
rose-ash/lib/ocaml/infer.sx
giles 81247eb6ea
Some checks failed
Test, Build, and Deploy / test-build-deploy (push) Failing after 26s
ocaml: phase 5 HM ctor inference for option/result (+7 tests, 351 total)
ocaml-hm-ctor-env registers None : 'a option, Some : 'a -> 'a option,
Ok : 'a -> ('a, 'b) result and Error : 'b -> ('a, 'b) result. :con NAME
instantiates the scheme; :pcon NAME ARG-PATS walks the argument patterns
through the constructor's arrow type, unifying each.

Pretty-printer renders 'Int option' and '(Int, 'b) result'.

Examples now infer:
  fun x -> Some x : 'a -> 'a option
  match Some 5 with | None -> 0 | Some n -> n : Int
  fun o -> match o with | None -> 0 | Some n -> n : Int option -> Int
  Ok 1 : (Int, 'b) result
  Error "oops" : ('a, String) result

User type-defs would extend the registry — pending.
2026-05-08 13:05:22 +00:00

447 lines
19 KiB
Plaintext

;; lib/ocaml/infer.sx — Algorithm W type inference for OCaml-on-SX.
;;
;; Consumes lib/guest/hm.sx (algebra) and lib/guest/match.sx (unify) per
;; the Phase 5 sequencing. The kit ships fresh-tv, generalize,
;; instantiate, and substitution composition; this file assembles the
;; lambda / app / let / if rules of Algorithm W against the OCaml AST.
;;
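;; Coverage so far:
;;   literals  :int :float :string :char :bool :unit
;;             (no Float/Char types yet: floats type as Int, chars as String)
;;   core      :var :fun :app :let :if :neg :not
;;   :op       desugared to application of a builtin (+, -, *, /, mod,
;;             comparisons, &&, ||, ^, ...)
;;   data      :tuple :list :con (option / result ctors) and :match over
;;             wildcard, variable, literal, cons, list, tuple, alias and
;;             constructor patterns
;;
;; Out of scope: records, modules, user-defined ADTs (unknown constructors
;; get a fresh type variable), let-rec.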
;;
;; Inference state:
;; env — dict: name → scheme
;; counter — one-element list (mutable cell) used by hm-fresh-tv
;;
;; Returned value: {:subst S :type T}.
;; Create a fresh inference counter: a one-element list (starting at 0)
;; used as a mutable cell by hm-fresh-tv. One counter is threaded through
;; every ocaml-infer call of a single inference run.
(define ocaml-hm-counter (fn () (list 0)))
;; The empty substitution (an empty dict). Not referenced by the code
;; shown in this file -- the inference rules write `{}` literals directly.
(define ocaml-hm-empty-subst (fn () {}))
;; Build the constructor registry: ctor name -> type scheme. It is consulted
;; by :con expressions and :pcon patterns. Seeds OCaml's stdlib option /
;; result constructors plus the boolean literals; user type-defs would
;; extend this in a future iteration.
;;   None  : 'a option             Some  : 'a -> 'a option
;;   Ok    : 'a -> ('a, 'b) result Error : 'b -> ('a, 'b) result
(define ocaml-hm-ctor-env
  (fn ()
    (let ((a (hm-tv "a")) (b (hm-tv "b")))
      (let ((a-option (hm-con "option" (list a)))
            (ab-result (hm-con "result" (list a b))))
        {"None" (hm-scheme (list "a") a-option)
         "Some" (hm-scheme (list "a") (hm-arrow a a-option))
         "Ok" (hm-scheme (list "a" "b") (hm-arrow a ab-result))
         "Error" (hm-scheme (list "a" "b") (hm-arrow b ab-result))
         "true" (hm-monotype (hm-bool))
         "false" (hm-monotype (hm-bool))}))))
;; Initial typing environment: operator / function name -> type scheme.
;; Arithmetic and ordering comparisons are monomorphic over Int. Floats are
;; typed as Int elsewhere in this file, so "**" is Int as well. Structural
;; equality (=, <>) and physical equality (==, !=) are polymorphic,
;; 'a -> 'a -> Bool, matching OCaml's stdlib signatures.
(define ocaml-hm-builtin-env
  (fn ()
    (let ((int-int-int (hm-monotype (hm-arrow (hm-int) (hm-arrow (hm-int) (hm-int)))))
          (int-int-bool (hm-monotype (hm-arrow (hm-int) (hm-arrow (hm-int) (hm-bool)))))
          (bool-bool-bool (hm-monotype (hm-arrow (hm-bool) (hm-arrow (hm-bool) (hm-bool)))))
          (str-str-str (hm-monotype (hm-arrow (hm-string) (hm-arrow (hm-string) (hm-string)))))
          (int-int (hm-monotype (hm-arrow (hm-int) (hm-int))))
          ;; One scheme shared by every polymorphic equality operator;
          ;; each use site instantiates it afresh.
          (any-any-bool
           (let ((a (hm-tv "a")))
             (hm-scheme (list "a") (hm-arrow a (hm-arrow a (hm-bool)))))))
      {"+" int-int-int
       "-" int-int-int
       "*" int-int-int
       "/" int-int-int
       "mod" int-int-int
       "%" int-int-int
       "**" int-int-int
       "<" int-int-bool
       ">" int-int-bool
       "<=" int-int-bool
       ">=" int-int-bool
       "=" any-any-bool
       "<>" any-any-bool
       "==" any-any-bool
       "!=" any-any-bool
       "&&" bool-bool-bool
       "||" bool-bool-bool
       "^" str-str-str
       "not" (hm-monotype (hm-arrow (hm-bool) (hm-bool)))
       "succ" int-int
       "pred" int-int
       "abs" int-int})))
;; Forward declaration. The rule helpers below call ocaml-infer
;; recursively, so the name is bound now to a stub that returns nil. The
;; real dispatcher is installed further down with `set!`.
(define ocaml-infer (fn (expr env counter) nil))
;; Unify T1 with T2 under SUBST and return the extended substitution.
;; match.sx's `unify` signals failure by returning nil; turn that into an
;; error that names both types.
(define ocaml-hm-unify
  (fn (t1 t2 subst)
    (let ((extended (unify t1 t2 subst)))
      (cond
        ((not (= extended nil)) extended)
        (else
         (error (str "ocaml-infer: cannot unify " t1 " with " t2)))))))
;; Variable rule: look NAME up in ENV and instantiate its scheme with
;; fresh type variables. The substitution is empty. Raises on an unbound
;; name.
(define ocaml-infer-var
  (fn (name env counter)
    (cond
      ((not (has-key? env name))
       (error (str "ocaml-infer: unbound variable " name)))
      (else
       {:subst {} :type (hm-instantiate (get env name) counter)}))))
;; Application rule. Infer the function; infer the argument in the env
;; refined by the function's substitution; then unify the function type
;; with ARG -> R for a fresh R. The result type is R under the final
;; substitution.
(define ocaml-infer-app
  (fn (fn-expr arg-expr env counter)
    (let ((fn-r (ocaml-infer fn-expr env counter)))
      (let ((fn-s (get fn-r :subst)))
        (let ((arg-r (ocaml-infer arg-expr (hm-apply-env fn-s env) counter)))
          (let ((arg-s (get arg-r :subst))
                (result-tv (hm-fresh-tv counter)))
            (let ((s-final (ocaml-hm-unify
                            (hm-apply arg-s (get fn-r :type))
                            (hm-arrow (get arg-r :type) result-tv)
                            (hm-compose arg-s fn-s))))
              {:subst s-final :type (hm-apply s-final result-tv)})))))))
;; Lambda rule. Each parameter gets a fresh monotype variable in the body
;; env. Multiple parameters are curried: fun x y -> e is typed as
;; fun x -> (fun y -> e). Raises when PARAMS is empty.
(define ocaml-infer-fun
  (fn (params body env counter)
    (cond
      ((= (len params) 0)
       (error "ocaml-infer: fun without params"))
      (else
       (let ((param-tv (hm-fresh-tv counter)))
         (let ((inner-env (assoc env (first params) (hm-monotype param-tv))))
           ;; Last parameter: infer the body. Otherwise recurse on the rest.
           (let ((inner (cond
                          ((= (len params) 1)
                           (ocaml-infer body inner-env counter))
                          (else
                           (ocaml-infer-fun (rest params) body inner-env counter)))))
             (let ((s (get inner :subst)))
               {:subst s
                :type (hm-arrow (hm-apply s param-tv) (get inner :type))}))))))))
;; Let rule (non-recursive). `let f p1 .. pn = rhs` is treated as
;; `let f = fun p1 .. pn -> rhs`. The rhs type is generalised against the
;; env refined by the rhs substitution, and the body is inferred with NAME
;; bound to that scheme.
(define ocaml-infer-let
  (fn (name params rhs body env counter)
    (let ((bound-expr (cond
                        ((= (len params) 0) rhs)
                        (else (list :fun params rhs)))))
      (let ((rhs-r (ocaml-infer bound-expr env counter)))
        (let ((rhs-s (get rhs-r :subst)))
          (let ((env-after (hm-apply-env rhs-s env)))
            (let ((body-env (assoc env-after name
                                   (hm-generalize (get rhs-r :type) env-after))))
              (let ((body-r (ocaml-infer body body-env counter)))
                {:subst (hm-compose (get body-r :subst) rhs-s)
                 :type (get body-r :type)}))))))))
;; Conditional rule. The condition must be Bool. The then-branch and the
;; else-branch are inferred in turn, each in the env refined by the
;; substitutions so far, and their types are unified. The result is the
;; unified branch type.
(define ocaml-infer-if
  (fn (c-ast t-ast e-ast env counter)
    (let ((cond-r (ocaml-infer c-ast env counter)))
      (let ((s-cond (ocaml-hm-unify (get cond-r :type) (hm-bool) (get cond-r :subst))))
        (let ((env-c (hm-apply-env s-cond env)))
          (let ((then-r (ocaml-infer t-ast env-c counter)))
            (let ((s-then (get then-r :subst)))
              (let ((else-r (ocaml-infer e-ast (hm-apply-env s-then env-c) counter)))
                (let ((s-else (get else-r :subst)))
                  (let ((s-all (ocaml-hm-unify
                                (hm-apply s-else (get then-r :type))
                                (get else-r :type)
                                (hm-compose s-else (hm-compose s-then s-cond)))))
                    {:subst s-all
                     :type (hm-apply s-all (get else-r :type))}))))))))))
;; Tuple type T1 * T2 * ...: the n-ary constructor "*" over the component
;; types, in order.
(define ocaml-hm-tuple
  (fn (component-types) (hm-con "*" component-types)))
;; List type ELEM list: the unary constructor "list" applied to the
;; element type.
(define ocaml-hm-list
  (fn (elem-type) (hm-con "list" (list elem-type))))
;; Tuple expression rule. Infer each component left to right, refining the
;; env with each component's substitution before inferring the next.
;; Collected component types have the final substitution applied before
;; being packed into a tuple type.
(define ocaml-infer-tuple
  (fn (items env counter)
    (let ((acc-subst {}) (item-types (list)))
      (begin
        (define walk-items
          (fn (pending env-now)
            (when (not (= pending (list)))
              (let ((ri (ocaml-infer (first pending) env-now counter)))
                (let ((si (get ri :subst)))
                  (begin
                    (append! item-types (get ri :type))
                    (set! acc-subst (hm-compose si acc-subst))
                    (walk-items (rest pending) (hm-apply-env si env-now))))))))
        (walk-items items env)
        {:subst acc-subst
         :type (ocaml-hm-tuple
                (map (fn (ty) (hm-apply acc-subst ty)) item-types))}))))
;; Infer the type of a constructor pattern (:pcon NAME ARG-PAT...).
;; Returns {:type T :env ENV2 :subst S}, the same shape as
;; ocaml-infer-pat. ENV2 is ENV extended with every name the argument
;; patterns bind.
;;
;; Known constructors (present in ocaml-hm-ctors): the scheme is
;; instantiated and each argument pattern is threaded through the
;; constructor's arrow type. For each argument, the current type is
;; unified with ARG -> RES, ARG is unified with the sub-pattern's type,
;; and RES becomes the current type for the next argument.
;;
;; Unknown constructors (there is no registry for user type-defs yet): the
;; pattern gets a fresh type variable so that mixed clauses still infer.
;; The argument patterns are still walked, so names they bind (e.g. `v` in
;; `Node (l, v, r)`) are in scope for the clause body.
(define ocaml-infer-pcon
  (fn (name arg-pats env counter)
    (cond
      ((has-key? ocaml-hm-ctors name)
       (let ((ctor-type (hm-instantiate (get ocaml-hm-ctors name) counter))
             (env-cur env) (subst {}))
         ;; cur-type is a one-slot mutable cell holding the not-yet-consumed
         ;; part of the constructor's type.
         (let ((cur-type (list nil)))
           (begin
             (set-nth! cur-type 0 ctor-type)
             (define loop
               (fn (xs)
                 (when (not (= xs (list)))
                   (let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
                     (let ((arg-tv (hm-fresh-tv counter))
                           (res-tv (hm-fresh-tv counter)))
                       ;; The current type must be ARG -> RES ...
                       (let ((s1 (ocaml-hm-unify
                                  (nth cur-type 0)
                                  (hm-arrow arg-tv res-tv)
                                  (hm-compose (get rp :subst) subst))))
                         ;; ... and ARG must match the sub-pattern's type.
                         (let ((s2 (ocaml-hm-unify
                                    (hm-apply s1 arg-tv)
                                    (hm-apply s1 (get rp :type))
                                    s1)))
                           (begin
                             (set! subst s2)
                             (set-nth! cur-type 0 (hm-apply s2 res-tv))
                             (set! env-cur (get rp :env))
                             (loop (rest xs))))))))))
             (loop arg-pats)
             {:type (hm-apply subst (nth cur-type 0))
              :env env-cur
              :subst subst}))))
      (else
       ;; Allocate the pattern's tv first, so that type-variable numbering
       ;; for argument-less unknown ctors is unchanged.
       (let ((tv (hm-fresh-tv counter)) (env-cur env) (subst {}))
         (begin
           (define bind-args
             (fn (xs)
               (when (not (= xs (list)))
                 (let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
                   (begin
                     (set! subst (hm-compose (get rp :subst) subst))
                     (set! env-cur (get rp :env))
                     (bind-args (rest xs)))))))
           (bind-args arg-pats)
           {:type tv :env env-cur :subst subst}))))))
;; Infer the type of pattern PAT, dispatching on the tag in element 0.
;; Returns {:type T :env ENV2 :subst S}, where ENV2 is ENV extended with
;; every name PAT binds. Bound names get monotypes -- pattern variables
;; are not generalised.
(define ocaml-infer-pat
(fn (pat env counter)
(let ((tag (nth pat 0)))
(cond
;; `_`: fresh type variable, binds nothing.
((= tag "pwild")
(let ((tv (hm-fresh-tv counter)))
{:type tv :env env :subst {}}))
;; Variable: fresh type variable, bound to the name.
((= tag "pvar")
(let ((nm (nth pat 1)) (tv (hm-fresh-tv counter)))
{:type tv :env (assoc env nm (hm-monotype tv)) :subst {}}))
;; Literal: typed as the literal expression; binds nothing.
((= tag "plit")
(let ((r (ocaml-infer (nth pat 1) env counter)))
{:type (get r :type) :env env :subst (get r :subst)}))
;; HEAD :: TAIL -- TAIL must be a list of HEAD's type. TAIL is inferred in
;; the env already extended by HEAD's bindings.
((= tag "pcons")
(let ((rh (ocaml-infer-pat (nth pat 1) env counter)))
(let ((rt (ocaml-infer-pat (nth pat 2) (get rh :env) counter)))
(let ((s (ocaml-hm-unify
(ocaml-hm-list (get rh :type))
(get rt :type)
(hm-compose (get rt :subst) (get rh :subst)))))
{:type (hm-apply s (ocaml-hm-list (get rh :type)))
:env (get rt :env)
:subst s}))))
;; [p1; ...; pn] -- every element pattern unifies with one shared
;; element type variable; bindings accumulate left to right.
((= tag "plist")
(let ((items (rest pat)) (tv (hm-fresh-tv counter)) (env-cur env) (subst {}))
(begin
(define loop
(fn (xs)
(when (not (= xs (list)))
(let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
(let ((s (ocaml-hm-unify
(hm-apply (get rp :subst) tv)
(get rp :type)
(hm-compose (get rp :subst) subst))))
(begin
(set! subst s)
(set! env-cur (get rp :env))
(loop (rest xs))))))))
(loop items)
{:type (hm-apply subst (ocaml-hm-list tv))
:env env-cur
:subst subst})))
;; (p1, ..., pn) -- component types form a tuple type; bindings accumulate
;; left to right.
((= tag "ptuple")
(let ((items (rest pat)) (env-cur env) (subst {}) (types (list)))
(begin
(define loop
(fn (xs)
(when (not (= xs (list)))
(let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
(begin
(set! subst (hm-compose (get rp :subst) subst))
(append! types (get rp :type))
(set! env-cur (get rp :env))
(loop (rest xs)))))))
(loop items)
{:type (ocaml-hm-tuple
(map (fn (t) (hm-apply subst t)) types))
:env env-cur
:subst subst})))
;; P as NAME -- NAME is bound to P's type in addition to P's bindings.
((= tag "pas")
(let ((rp (ocaml-infer-pat (nth pat 1) env counter)))
(let ((alias (nth pat 2)))
{:type (get rp :type)
:env (assoc (get rp :env) alias (hm-monotype (get rp :type)))
:subst (get rp :subst)})))
;; Constructor pattern (:pcon NAME ARG-PAT...): see ocaml-infer-pcon.
((= tag "pcon")
(ocaml-infer-pcon (nth pat 1) (rest (rest pat)) env counter))
;; Unrecognised pattern tag: permissive fallback -- fresh type variable,
;; no bindings.
(else
(let ((tv (hm-fresh-tv counter)))
{:type tv :env env :subst {}}))))))
;; Match rule. First infer the scrutinee. Then, for each clause:
;;   - infer the pattern, which extends the env with its bindings;
;;   - unify the pattern type with the scrutinee type;
;;   - infer the clause body;
;;   - unify the body type with a single result type variable shared by
;;     all clauses.
;; Returns {:subst S :type T}. With no clauses, T is the unconstrained
;; fresh result variable.
(define ocaml-infer-match
(fn (scrut clauses env counter)
(let ((rs (ocaml-infer scrut env counter)))
(let ((s (get rs :subst)) (st (get rs :type)) (result-tv (hm-fresh-tv counter)))
;; `subst` accumulates across clauses; `loop` mutates it with set!.
(let ((subst s))
(begin
(define loop
(fn (cs)
(when (not (= cs (list)))
(let ((clause (first cs)))
(let ((ctag (nth clause 0)))
;; "case" clauses keep the body at index 2; any other clause tag keeps
;; it at index 3. For those, element 2 is presumably a `when` guard --
;; confirm against the parser. It is not type-checked here.
(let ((p (nth clause 1))
(body (cond
((= ctag "case") (nth clause 2))
(else (nth clause 3)))))
(let ((rp (ocaml-infer-pat p (hm-apply-env subst env) counter)))
;; The pattern type must agree with the scrutinee type.
(let ((s1 (ocaml-hm-unify
(hm-apply (get rp :subst) st)
(get rp :type)
(hm-compose (get rp :subst) subst))))
;; The body sees the pattern's bindings.
(let ((rb (ocaml-infer body
(hm-apply-env s1 (get rp :env)) counter)))
;; Every clause body must have the shared result type.
(let ((s2 (ocaml-hm-unify
(hm-apply (get rb :subst) result-tv)
(get rb :type)
(hm-compose (get rb :subst) s1))))
(begin
(set! subst s2)
(loop (rest cs)))))))))))))
(loop clauses)
{:subst subst :type (hm-apply subst result-tv)}))))))
;; List literal rule. `[]` is 'a list for a fresh 'a. Otherwise every
;; element is inferred left to right, each in the env refined so far, and
;; unified with one shared element type variable.
(define ocaml-infer-list
  (fn (items env counter)
    (cond
      ((= (len items) 0)
       {:subst {} :type (ocaml-hm-list (hm-fresh-tv counter))})
      (else
       (let ((acc {}) (elem-tv (hm-fresh-tv counter)))
         (begin
           (define walk-items
             (fn (pending env-now)
               (when (not (= pending (list)))
                 (let ((ri (ocaml-infer (first pending) env-now counter)))
                   (let ((si (get ri :subst)))
                     (let ((s-next (ocaml-hm-unify
                                    (hm-apply si elem-tv)
                                    (get ri :type)
                                    (hm-compose si acc))))
                       (begin
                         (set! acc s-next)
                         (walk-items (rest pending) (hm-apply-env s-next env-now)))))))))
           (walk-items items env)
           {:subst acc
            :type (ocaml-hm-list (hm-apply acc elem-tv))}))))))
;; Constructor registry, built once at load time. Read by the :con case of
;; ocaml-infer and by ocaml-infer-pcon.
(define ocaml-hm-ctors (ocaml-hm-ctor-env))
;; Main inference dispatcher. Switches on the AST tag (element 0 of EXPR)
;; and applies the matching rule; every rule returns {:subst S :type T}.
;; Installed with set! over the earlier stub, so the rule helpers can
;; recurse through ocaml-infer. Raises on an unsupported tag.
(set! ocaml-infer
(fn (expr env counter)
(let ((tag (nth expr 0)))
(cond
((= tag "con")
;; (:con NAME) — look up constructor type, instantiate fresh.
(let ((name (nth expr 1)))
(cond
((has-key? ocaml-hm-ctors name)
{:subst {}
:type (hm-instantiate (get ocaml-hm-ctors name) counter)})
(else
;; Unknown ctor — treat as a fresh polymorphic type.
{:subst {} :type (hm-fresh-tv counter)}))))
;; Literals. There are no Float or Char types yet: floats are typed as Int
;; and chars as String.
((= tag "int") {:subst {} :type (hm-int)})
((= tag "float") {:subst {} :type (hm-int)}) ;; treat float as int for now
((= tag "string") {:subst {} :type (hm-string)})
((= tag "char") {:subst {} :type (hm-string)})
((= tag "bool") {:subst {} :type (hm-bool)})
((= tag "unit") {:subst {} :type (hm-con "Unit" (list))})
((= tag "var") (ocaml-infer-var (nth expr 1) env counter))
((= tag "fun") (ocaml-infer-fun (nth expr 1) (nth expr 2) env counter))
((= tag "app") (ocaml-infer-app (nth expr 1) (nth expr 2) env counter))
;; (:let NAME PARAMS RHS BODY)
((= tag "let") (ocaml-infer-let (nth expr 1) (nth expr 2)
(nth expr 3) (nth expr 4) env counter))
;; (:if COND THEN ELSE)
((= tag "if") (ocaml-infer-if (nth expr 1) (nth expr 2)
(nth expr 3) env counter))
((= tag "tuple") (ocaml-infer-tuple (rest expr) env counter))
((= tag "list") (ocaml-infer-list (rest expr) env counter))
((= tag "match") (ocaml-infer-match (nth expr 1) (nth expr 2) env counter))
;; Unary minus: the operand must be Int; the result is Int.
((= tag "neg")
(let ((r (ocaml-infer (nth expr 1) env counter)))
(let ((s (get r :subst)) (t (get r :type)))
(let ((s2 (ocaml-hm-unify t (hm-int) s)))
{:subst s2 :type (hm-int)}))))
;; Logical not: the operand must be Bool; the result is Bool.
((= tag "not")
(let ((r (ocaml-infer (nth expr 1) env counter)))
(let ((s (get r :subst)) (t (get r :type)))
(let ((s2 (ocaml-hm-unify t (hm-bool) s)))
{:subst s2 :type (hm-bool)}))))
((= tag "op")
;; Treat (:op OP L R) as (:app (:app (:var OP) L) R) — same rule.
(ocaml-infer
(list :app (list :app (list :var (nth expr 1)) (nth expr 2)) (nth expr 3))
env counter))
(else (error (str "ocaml-infer: unsupported tag " tag)))))))
;; Top-level convenience. Parse SRC, infer its type against the builtin
;; environment with a fresh counter, apply the final substitution, and
;; render the result as an OCaml-style type string.
(define ocaml-type-of
  (fn (src)
    (let ((expr (ocaml-parse src))
          (env (ocaml-hm-builtin-env))
          (counter (ocaml-hm-counter)))
      (let ((result (ocaml-infer expr env counter)))
        (let ((resolved (hm-apply (get result :subst) (get result :type))))
          (ocaml-hm-format-type resolved))))))
;; Format T for a position where a constructor type with head H1 or H2
;; would bind too loosely, wrapping it in parentheses in that case.
(define ocaml-hm-format-type-wrapped
  (fn (t h1 h2)
    (let ((s (ocaml-hm-format-type t)))
      (cond
        ((not (is-ctor? t)) s)
        ((= (ctor-head t) h1) (str "(" s ")"))
        ((= (ctor-head t) h2) (str "(" s ")"))
        (else s)))))
;; Pretty-print a type as an OCaml-style string (used by ocaml-type-of and
;; tests). Handles:
;;   - type variables: 'a
;;   - arrows: A -> B
;;   - tuples: A * B * C
;;   - constructor application in OCaml's postfix syntax: nullary ctors
;;     print as their head (Int, Bool, String, Unit), unary ones as
;;     `ARG head` (Int list, 'a option), and n-ary ones as `(A, B) head`
;;     ((Int, 'b) result).
;; Precedence follows OCaml: postfix application binds tighter than `*`,
;; which binds tighter than `->`, and `->` is right-associative. Sub-terms
;; are parenthesised only where needed, e.g. (Int * Int) list,
;; (Int -> Int) option, (Int * Int) * Int, (Int -> Int) -> Int.
(define ocaml-hm-format-type
  (fn (t)
    (cond
      ((is-var? t) (str "'" (var-name t)))
      ((is-ctor? t)
       (let ((head (ctor-head t)) (args (ctor-args t)))
         (cond
           ((= head "->")
            ;; Only an arrow on the left of an arrow needs parentheses.
            (str (ocaml-hm-format-type-wrapped (nth args 0) "->" "->")
                 " -> "
                 (ocaml-hm-format-type (nth args 1))))
           ((= head "*")
            ;; Arrow or tuple components must be parenthesised.
            (join " * "
                  (map (fn (part) (ocaml-hm-format-type-wrapped part "->" "*"))
                       args)))
           ((= (len args) 0) head)
           ((= (len args) 1)
            (str (ocaml-hm-format-type-wrapped (nth args 0) "->" "*") " " head))
           (else
            (str "(" (join ", " (map ocaml-hm-format-type args)) ") " head)))))
      (else (str t)))))