;; lib/ocaml/infer.sx — Algorithm W type inference for OCaml-on-SX.
;;
;; Builds on lib/guest/hm.sx (type algebra: fresh-tv, generalize,
;; instantiate, substitution composition) and lib/guest/match.sx (unify),
;; following the Phase 5 sequencing. This file supplies the Algorithm W
;; rules for the OCaml AST.
;;
;; Handled here:
;;   atoms     :int :float :string :char :bool :unit
;;   core      :var :fun :app :if :neg :not
;;             :op (via a builtin operator table)
;;   bindings  :let :let-rec :let-mut (let … and …) :let-rec-mut
;;   data      :tuple :list :con (a ctor registry that `type` decls extend)
;;   matching  :match over wildcard / var / literal / cons / list / tuple /
;;             as / ctor patterns
;;   programs  ocaml-type-of-program threads top-level defs and type-defs
;;
;; Not handled: records and modules (ocaml-infer raises "unsupported tag").
;; exception-def declarations are skipped.
;;
;; Inference state:
;;   env     — dict: name → scheme
;;   counter — one-element list (mutable cell) used by hm-fresh-tv
;;
;; Expression rules return {:subst S :type T}.

;; A fresh inference counter cell, starting at 0.
(define ocaml-hm-counter
  (fn () (list 0)))

;; The empty substitution.
(define ocaml-hm-empty-subst
  (fn () {}))

;; Seed table of constructor type schemes, so :con expressions and :pcon
;; patterns can be inferred. It covers the stdlib ctors None / Some,
;; Ok / Error and the boolean literals true / false. User `type`
;; declarations extend a registry that is seeded from this table.
(define ocaml-hm-ctor-env
  (fn ()
    (let ((tv-a (hm-tv "a"))
          (tv-b (hm-tv "b")))
      (let ((option-a (hm-con "option" (list tv-a)))
            (result-ab (hm-con "result" (list tv-a tv-b))))
        {"None" (hm-scheme (list "a") option-a)
         "Some" (hm-scheme (list "a") (hm-arrow tv-a option-a))
         "Ok" (hm-scheme (list "a" "b") (hm-arrow tv-a result-ab))
         "Error" (hm-scheme (list "a" "b") (hm-arrow tv-b result-ab))
         "true" (hm-monotype (hm-bool))
         "false" (hm-monotype (hm-bool))}))))

;; Float type isn't in the kit; use a named ctor.
(define ocaml-hm-float
  (fn () (hm-con "Float" (list))))

;; Builtin value environment: operator / function name → type scheme.
;;   + - * / mod % ** and < > <= >=   over int
;;   +. -. *. /.                      over Float
;;   = <>                             'a -> 'a -> bool
;;   && || not                        over bool
;;   ^                                over string
;;   :: @                             list-polymorphic
;;   succ pred abs                    int -> int
;; NOTE(review): real OCaml types ( ** ) as float -> float -> float and
;; < > <= >= as 'a -> 'a -> bool. Confirm that the int-only signatures
;; here are intentional for OCaml-on-SX.
(define ocaml-hm-builtin-env
  (fn ()
    (let ((binop (fn (x y r) (hm-arrow x (hm-arrow y r))))
          (list-t (fn (elem) (hm-con "list" (list elem))))
          (poly-a (fn (build)
                    (let ((a (hm-tv "a")))
                      (hm-scheme (list "a") (build a))))))
      (let ((int-op (binop (hm-int) (hm-int) (hm-int)))
            (float-op (binop (ocaml-hm-float) (ocaml-hm-float) (ocaml-hm-float)))
            (int-cmp (binop (hm-int) (hm-int) (hm-bool)))
            (bool-op (binop (hm-bool) (hm-bool) (hm-bool)))
            (str-op (binop (hm-string) (hm-string) (hm-string)))
            (eq-scheme (poly-a (fn (a) (binop a a (hm-bool)))))
            (cons-scheme (poly-a (fn (a) (binop a (list-t a) (list-t a)))))
            (append-scheme (poly-a (fn (a) (binop (list-t a) (list-t a) (list-t a)))))
            (int-unary (fn () (hm-monotype (hm-arrow (hm-int) (hm-int))))))
        {"+" (hm-monotype int-op)
         "-" (hm-monotype int-op)
         "*" (hm-monotype int-op)
         "/" (hm-monotype int-op)
         "+." (hm-monotype float-op)
         "-." (hm-monotype float-op)
         "*." (hm-monotype float-op)
         "/." (hm-monotype float-op)
         "mod" (hm-monotype int-op)
         "%" (hm-monotype int-op)
         "**" (hm-monotype int-op)
         "<" (hm-monotype int-cmp)
         ">" (hm-monotype int-cmp)
         "<=" (hm-monotype int-cmp)
         ">=" (hm-monotype int-cmp)
         "=" eq-scheme
         "<>" eq-scheme
         "&&" (hm-monotype bool-op)
         "||" (hm-monotype bool-op)
         "^" (hm-monotype str-op)
         "::" cons-scheme
         "@" append-scheme
         "not" (hm-monotype (hm-arrow (hm-bool) (hm-bool)))
         "succ" (int-unary)
         "pred" (int-unary)
         "abs" (int-unary)}))))

;; Forward declaration so the rule functions can recurse through
;; ocaml-infer. The real dispatcher is installed later with set!.
(define ocaml-infer (fn (expr env counter) nil))

;; Unify two types; raise on failure. The match.sx unify returns nil on
;; failure so we wrap it for clearer errors.
(define ocaml-hm-unify
  (fn (t1 t2 subst)
    (let ((s2 (unify t1 t2 subst)))
      (cond
        ((= s2 nil) (error (str "ocaml-infer: cannot unify " t1 " with " t2)))
        (else s2)))))

;; Build the rhs expression of a binding. `f p1 .. pn = e` becomes
;; (:fun (p1 .. pn) e); a binding without params keeps its rhs unchanged.
(define ocaml-hm-rhs-expr
  (fn (params rhs)
    (cond
      ((= (len params) 0) rhs)
      (else (list :fun params rhs)))))

;; Look up name; instantiate scheme to a fresh monotype.
(define ocaml-infer-var
  (fn (name env counter)
    (cond
      ((has-key? env name)
       (let ((scheme (get env name)))
         (let ((t (hm-instantiate scheme counter)))
           {:subst {} :type t})))
      (else (error (str "ocaml-infer: unbound variable " name))))))

;; Application. Infer the function, then the argument under the function's
;; substitution, then unify the function type with ARG -> fresh result tv.
(define ocaml-infer-app
  (fn (fn-expr arg-expr env counter)
    (let ((r1 (ocaml-infer fn-expr env counter)))
      (let ((s1 (get r1 :subst)) (t1 (get r1 :type)))
        (let ((env2 (hm-apply-env s1 env)))
          (let ((r2 (ocaml-infer arg-expr env2 counter)))
            (let ((s2 (get r2 :subst)) (t2 (get r2 :type)))
              (let ((tv (hm-fresh-tv counter)))
                (let ((s3 (ocaml-hm-unify (hm-apply s2 t1)
                                          (hm-arrow t2 tv)
                                          (hm-compose s2 s1))))
                  {:subst s3 :type (hm-apply s3 tv)})))))))))

;; Lambda. Bind the first param to a fresh monotype. Multi-param funs are
;; curried one param at a time; zero params is an error.
(define ocaml-infer-fun
  (fn (params body env counter)
    (cond
      ((= (len params) 0) (error "ocaml-infer: fun without params"))
      ((= (len params) 1)
       (let ((tv (hm-fresh-tv counter)))
         (let ((env2 (assoc env (first params) (hm-monotype tv))))
           (let ((r (ocaml-infer body env2 counter)))
             (let ((s (get r :subst)) (t-body (get r :type)))
               {:subst s :type (hm-arrow (hm-apply s tv) t-body)})))))
      (else
       ;; Curry: fun x y -> e ≡ fun x -> fun y -> e
       (let ((tv (hm-fresh-tv counter)))
         (let ((env2 (assoc env (first params) (hm-monotype tv))))
           (let ((r (ocaml-infer-fun (rest params) body env2 counter)))
             (let ((s (get r :subst)) (t-rest (get r :type)))
               {:subst s :type (hm-arrow (hm-apply s tv) t-rest)}))))))))

;; let name params = rhs in body. Infer rhs, generalize it against the env
;; under rhs's substitution, then infer body with name bound to the scheme.
(define ocaml-infer-let
  (fn (name params rhs body env counter)
    (let ((r1 (ocaml-infer (ocaml-hm-rhs-expr params rhs) env counter)))
      (let ((s1 (get r1 :subst)) (t1 (get r1 :type)))
        (let ((env2 (hm-apply-env s1 env)))
          (let ((env3 (assoc env2 name (hm-generalize t1 env2))))
            (let ((r2 (ocaml-infer body env3 counter)))
              (let ((s2 (get r2 :subst)) (t2 (get r2 :type)))
                {:subst (hm-compose s2 s1) :type t2}))))))))

;; let x = e1 and y = e2 in body — non-recursive multi-binding.
;; - Every rhs is inferred against the parent env, with the substitution
;;   accumulated so far applied.
;; - Names bound earlier in the same group are NOT visible to later rhs,
;;   matching OCaml's `let … and …` scoping.
;; - Each rhs type is generalized and added to the env used for body.
(define ocaml-infer-let-mut
  (fn (bindings body env counter)
    (let ((subst {}) (env-rhs env) (env-body env))
      (begin
        (define one
          (fn (b)
            (let ((nm (nth b 0)) (ps (nth b 1)) (rh (nth b 2)))
              (let ((r (ocaml-infer (ocaml-hm-rhs-expr ps rh) env-rhs counter)))
                (let ((s (get r :subst)) (t (get r :type)))
                  (begin
                    (set! subst (hm-compose s subst))
                    (set! env-rhs (hm-apply-env s env-rhs))
                    (set! env-body
                          (assoc (hm-apply-env s env-body) nm
                                 (hm-generalize t env-rhs)))))))))
        (define loop
          (fn (xs)
            (when (not (= xs (list)))
              (begin (one (first xs)) (loop (rest xs))))))
        (loop bindings)
        (let ((rb (ocaml-infer body env-body counter)))
          (let ((sb (get rb :subst)) (tb (get rb :type)))
            {:subst (hm-compose sb subst) :type tb}))))))

;; let rec f = ... and g = ... in body — mutually recursive multi-binding.
;; 1. Pre-bind all names to fresh tvs.
;; 2. Infer each rhs in the joint env. The substitution accumulated so far
;;    is applied first, so constraints found in earlier rhs reach later
;;    ones instead of being dropped by composition.
;; 3. Unify each rhs type with its tv.
;; 4. Generalize, then infer body.
(define ocaml-infer-let-rec-mut
  (fn (bindings body env counter)
    (let ((tvs (list)) (env-rec env))
      (begin
        (define alloc
          (fn (xs)
            (when (not (= xs (list)))
              (let ((nm (nth (first xs) 0)) (tv (hm-fresh-tv counter)))
                (begin
                  (append! tvs tv)
                  (set! env-rec (assoc env-rec nm (hm-monotype tv)))
                  (alloc (rest xs)))))))
        (alloc bindings)
        (let ((subst {}) (idx 0))
          (begin
            (define infer-one
              (fn (b)
                (let ((r (ocaml-infer (ocaml-hm-rhs-expr (nth b 1) (nth b 2))
                                      (hm-apply-env subst env-rec)
                                      counter)))
                  (let ((s (get r :subst)) (t (get r :type)))
                    (let ((s2 (ocaml-hm-unify (hm-apply s (nth tvs idx)) t
                                              (hm-compose s subst))))
                      (begin
                        (set! subst s2)
                        (set! idx (+ idx 1))))))))
            (define loop
              (fn (xs)
                (when (not (= xs (list)))
                  (begin (infer-one (first xs)) (loop (rest xs))))))
            (loop bindings)
            (let ((env-final (hm-apply-env subst env)))
              (begin
                (set! idx 0)
                (define gen-one
                  (fn (b)
                    (let ((scheme (hm-generalize (hm-apply subst (nth tvs idx)) env-final)))
                      (begin
                        (set! env-final (assoc env-final (nth b 0) scheme))
                        (set! idx (+ idx 1))))))
                (define loop2
                  (fn (xs)
                    (when (not (= xs (list)))
                      (begin (gen-one (first xs)) (loop2 (rest xs))))))
                (loop2 bindings)
                (let ((rb (ocaml-infer body env-final counter)))
                  (let ((sb (get rb :subst)) (tb (get rb :type)))
                    {:subst (hm-compose sb subst) :type tb}))))))))))

;; let rec name params = rhs in body. Bind name to a fresh tv before
;; inferring rhs, then unify the inferred rhs type with that tv. This lets
;; rhs reference name (a recursive call). Generalize afterwards.
(define ocaml-infer-let-rec
  (fn (name params rhs body env counter)
    (let ((rec-tv (hm-fresh-tv counter)))
      (let ((r1 (ocaml-infer (ocaml-hm-rhs-expr params rhs)
                             (assoc env name (hm-monotype rec-tv))
                             counter)))
        (let ((s1 (get r1 :subst)) (t1 (get r1 :type)))
          (let ((s2 (ocaml-hm-unify (hm-apply s1 rec-tv) t1 s1)))
            (let ((env2 (hm-apply-env s2 env)))
              (let ((env3 (assoc env2 name (hm-generalize (hm-apply s2 t1) env2))))
                (let ((r2 (ocaml-infer body env3 counter)))
                  (let ((s3 (get r2 :subst)) (t2 (get r2 :type)))
                    {:subst (hm-compose s3 s2) :type t2}))))))))))

;; if c then t else e. c must be bool; both branches must share one type.
(define ocaml-infer-if
  (fn (c-ast t-ast e-ast env counter)
    (let ((rc (ocaml-infer c-ast env counter)))
      (let ((sc (get rc :subst)) (tc (get rc :type)))
        (let ((sc2 (ocaml-hm-unify tc (hm-bool) sc)))
          (let ((env2 (hm-apply-env sc2 env)))
            (let ((rt (ocaml-infer t-ast env2 counter)))
              (let ((st (get rt :subst)) (tt (get rt :type)))
                (let ((env3 (hm-apply-env st env2)))
                  (let ((re (ocaml-infer e-ast env3 counter)))
                    (let ((se (get re :subst)) (te (get re :type)))
                      (let ((sf (ocaml-hm-unify (hm-apply se tt) te
                                                (hm-compose se (hm-compose st sc2)))))
                        {:subst sf :type (hm-apply sf te)}))))))))))))

;; Tuple type: (hm-con "*" (list T1 T2 ...)).
(define ocaml-hm-tuple (fn (types) (hm-con "*" types)))

;; List type: (hm-con "list" (list ELEM)).
(define ocaml-hm-list (fn (elem) (hm-con "list" (list elem))))

;; Tuple expression. Infer items left to right, threading the substitution.
(define ocaml-infer-tuple
  (fn (items env counter)
    (let ((subst {}) (types (list)))
      (begin
        (define loop
          (fn (xs env-cur)
            (when (not (= xs (list)))
              (let ((r (ocaml-infer (first xs) env-cur counter)))
                (let ((s (get r :subst)) (t (get r :type)))
                  (begin
                    (set! subst (hm-compose s subst))
                    (append! types t)
                    (loop (rest xs) (hm-apply-env s env-cur))))))))
        (loop items env)
        {:subst subst
         :type (ocaml-hm-tuple (map (fn (t) (hm-apply subst t)) types))}))))

;; Constructor pattern NAME ARG-PATS.
;; - Ctor in the registry: its instantiated type is peeled one arrow per
;;   argument pattern, unifying each arrow's domain with that pattern's
;;   type. The remaining type is the pattern's type.
;; - Ctor missing from the registry: yields a fresh tv, and its argument
;;   patterns bind nothing.
(define ocaml-infer-pcon
  (fn (name arg-pats env counter)
    (cond
      ((ocaml-hm-ctor-has? name)
       (let ((ctor-type (hm-instantiate (ocaml-hm-ctor-lookup name) counter))
             (env-cur env)
             (subst {}))
         (let ((cur-type (list nil)))
           (begin
             (set-nth! cur-type 0 ctor-type)
             (define loop
               (fn (xs)
                 (when (not (= xs (list)))
                   (let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
                     (let ((arg-tv (hm-fresh-tv counter))
                           (res-tv (hm-fresh-tv counter)))
                       (let ((s1 (ocaml-hm-unify (nth cur-type 0)
                                                 (hm-arrow arg-tv res-tv)
                                                 (hm-compose (get rp :subst) subst))))
                         (let ((s2 (ocaml-hm-unify (hm-apply s1 arg-tv)
                                                   (hm-apply s1 (get rp :type))
                                                   s1)))
                           (begin
                             (set! subst s2)
                             (set-nth! cur-type 0 (hm-apply s2 res-tv))
                             (set! env-cur (get rp :env))
                             (loop (rest xs))))))))))
             (loop arg-pats)
             {:type (hm-apply subst (nth cur-type 0)) :env env-cur :subst subst}))))
      (else
       (let ((tv (hm-fresh-tv counter)))
         {:type tv :env env :subst {}})))))

;; Pattern type inference. Returns {:type T :env ENV2 :subst S}, where ENV2
;; is the original env extended with any names the pattern binds.
;; Constructor patterns (pcon) are checked against the ctor registry.
;; Unknown pattern tags fall through to a fresh tv.
(define ocaml-infer-pat
  (fn (pat env counter)
    (let ((tag (nth pat 0)))
      (cond
        ((= tag "pwild")
         (let ((tv (hm-fresh-tv counter)))
           {:type tv :env env :subst {}}))
        ((= tag "pvar")
         (let ((nm (nth pat 1)) (tv (hm-fresh-tv counter)))
           {:type tv :env (assoc env nm (hm-monotype tv)) :subst {}}))
        ((= tag "plit")
         (let ((r (ocaml-infer (nth pat 1) env counter)))
           {:type (get r :type) :env env :subst (get r :subst)}))
        ((= tag "pcons")
         (let ((rh (ocaml-infer-pat (nth pat 1) env counter)))
           (let ((rt (ocaml-infer-pat (nth pat 2) (get rh :env) counter)))
             (let ((s (ocaml-hm-unify (ocaml-hm-list (get rh :type))
                                      (get rt :type)
                                      (hm-compose (get rt :subst) (get rh :subst)))))
               {:type (hm-apply s (ocaml-hm-list (get rh :type)))
                :env (get rt :env)
                :subst s}))))
        ((= tag "plist")
         (let ((items (rest pat))
               (tv (hm-fresh-tv counter))
               (env-cur env)
               (subst {}))
           (begin
             (define loop
               (fn (xs)
                 (when (not (= xs (list)))
                   (let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
                     (let ((s (ocaml-hm-unify (hm-apply (get rp :subst) tv)
                                              (get rp :type)
                                              (hm-compose (get rp :subst) subst))))
                       (begin
                         (set! subst s)
                         (set! env-cur (get rp :env))
                         (loop (rest xs))))))))
             (loop items)
             {:type (hm-apply subst (ocaml-hm-list tv)) :env env-cur :subst subst})))
        ((= tag "ptuple")
         (let ((items (rest pat)) (env-cur env) (subst {}) (types (list)))
           (begin
             (define loop
               (fn (xs)
                 (when (not (= xs (list)))
                   (let ((rp (ocaml-infer-pat (first xs) env-cur counter)))
                     (begin
                       (set! subst (hm-compose (get rp :subst) subst))
                       (append! types (get rp :type))
                       (set! env-cur (get rp :env))
                       (loop (rest xs)))))))
             (loop items)
             {:type (ocaml-hm-tuple (map (fn (t) (hm-apply subst t)) types))
              :env env-cur
              :subst subst})))
        ((= tag "pas")
         (let ((rp (ocaml-infer-pat (nth pat 1) env counter)))
           (let ((alias (nth pat 2)))
             {:type (get rp :type)
              :env (assoc (get rp :env) alias (hm-monotype (get rp :type)))
              :subst (get rp :subst)})))
        ((= tag "pcon")
         (ocaml-infer-pcon (nth pat 1) (rest (rest pat)) env counter))
        (else
         (let ((tv (hm-fresh-tv counter)))
           {:type tv :env env :subst {}}))))))

;; match scrut with clauses. Every pattern is unified with the scrutinee
;; type, and every clause body with one shared fresh result tv.
;; NOTE(review): for clauses not tagged "case" the body is read from slot
;; 3, so slot 2 is presumably a `when` guard. That guard is not
;; type-checked here — confirm the clause layout against the parser.
(define ocaml-infer-match
  (fn (scrut clauses env counter)
    (let ((rs (ocaml-infer scrut env counter)))
      (let ((s (get rs :subst))
            (st (get rs :type))
            (result-tv (hm-fresh-tv counter)))
        (let ((subst s))
          (begin
            (define loop
              (fn (cs)
                (when (not (= cs (list)))
                  (let ((clause (first cs)))
                    (let ((ctag (nth clause 0)))
                      (let ((p (nth clause 1))
                            (body (cond ((= ctag "case") (nth clause 2))
                                        (else (nth clause 3)))))
                        (let ((rp (ocaml-infer-pat p (hm-apply-env subst env) counter)))
                          (let ((s1 (ocaml-hm-unify (hm-apply (get rp :subst) st)
                                                    (get rp :type)
                                                    (hm-compose (get rp :subst) subst))))
                            (let ((rb (ocaml-infer body (hm-apply-env s1 (get rp :env)) counter)))
                              (let ((s2 (ocaml-hm-unify (hm-apply (get rb :subst) result-tv)
                                                        (get rb :type)
                                                        (hm-compose (get rb :subst) s1))))
                                (begin
                                  (set! subst s2)
                                  (loop (rest cs)))))))))))))
            (loop clauses)
            {:subst subst :type (hm-apply subst result-tv)}))))))

;; List literal. All elements share one element tv; [] gets a fresh one.
(define ocaml-infer-list
  (fn (items env counter)
    (cond
      ((= (len items) 0)
       {:subst {} :type (ocaml-hm-list (hm-fresh-tv counter))})
      (else
       (let ((subst {}) (elem-tv (hm-fresh-tv counter)))
         (begin
           (define loop
             (fn (xs env-cur)
               (when (not (= xs (list)))
                 (let ((r (ocaml-infer (first xs) env-cur counter)))
                   (let ((s (get r :subst)) (t (get r :type)))
                     (let ((s2 (ocaml-hm-unify (hm-apply s elem-tv) t (hm-compose s subst))))
                       (begin
                         (set! subst s2)
                         (loop (rest xs) (hm-apply-env s2 env-cur)))))))))
           (loop items env)
           {:subst subst :type (ocaml-hm-list (hm-apply subst elem-tv))}))))))

;; Mutable cell so user `type` declarations can extend the registry.
(define ocaml-hm-ctors (list (ocaml-hm-ctor-env)))

(define ocaml-hm-ctor-lookup
  (fn (name) (get (nth ocaml-hm-ctors 0) name)))

(define ocaml-hm-ctor-has?
  (fn (name) (has-key? (nth ocaml-hm-ctors 0) name)))

(define ocaml-hm-ctor-register!
  (fn (name scheme)
    (set-nth! ocaml-hm-ctors 0 (merge (nth ocaml-hm-ctors 0) (dict name scheme)))))

;; Placeholder for type sources the parser cannot read. NOTE: this is one
;; shared type variable, not a fresh one per occurrence.
(define ocaml-hm-unknown-type (fn () (hm-tv "_unknown")))

;; True when string S contains the one-character string CH at index I or later.
(define ocaml-hm-str-has-char-from?
  (fn (s ch i)
    (and (< i (len s))
         (or (= (nth s i) ch)
             (ocaml-hm-str-has-char-from? s ch (+ i 1))))))

;; Split token list TOKS at every SEP token; returns a list of token lists.
(define ocaml-hm-split-tokens
  (fn (toks sep)
    (let ((groups (list)) (cur (list)))
      (begin
        (for-each
          (fn (tk)
            (cond
              ((= tk sep) (begin (append! groups cur) (set! cur (list))))
              (else (append! cur tk))))
          toks)
        (append! groups cur)
        groups))))

;; Right-fold a non-empty list of types into T1 -> (T2 -> ... -> Tn).
(define ocaml-hm-arrow-chain
  (fn (types)
    (cond
      ((= (len types) 1) (first types))
      (else (hm-arrow (first types) (ocaml-hm-arrow-chain (rest types)))))))

;; Parse one token: a primitive name, a 'var, or the placeholder.
;; char maps to the string type, like :char literals in ocaml-infer.
(define ocaml-hm-parse-type-atom
  (fn (tok)
    (cond
      ((= tok "int") (hm-int))
      ((= tok "bool") (hm-bool))
      ((= tok "string") (hm-string))
      ((= tok "char") (hm-string))
      ((= tok "float") (ocaml-hm-float))
      ((= tok "unit") (hm-con "Unit" (list)))
      ((and (> (len tok) 1) (= (nth tok 0) "'")) (hm-tv (slice tok 1 (len tok))))
      (else (ocaml-hm-unknown-type)))))

;; Parse a paren-free token list. Binding strength, loosest first:
;; `->`, then `*`, then postfix application.
(define ocaml-hm-parse-type-tokens
  (fn (toks)
    (let ((arrow-groups (ocaml-hm-split-tokens toks "->")))
      (cond
        ((= (len toks) 0) (ocaml-hm-unknown-type))
        ((> (len arrow-groups) 1)
         (ocaml-hm-arrow-chain (map ocaml-hm-parse-type-tokens arrow-groups)))
        (else
         (let ((star-groups (ocaml-hm-split-tokens toks "*")))
           (cond
             ((> (len star-groups) 1)
              (ocaml-hm-tuple (map ocaml-hm-parse-type-tokens star-groups)))
             (else
              ;; Postfix application: `int list option` = ((int list) option).
              (let ((acc (ocaml-hm-parse-type-atom (first toks))))
                (begin
                  (for-each (fn (head) (set! acc (hm-con head (list acc))))
                            (rest toks))
                  acc))))))))))

;; Parse a simple type source string into an HM type. Tokens must be
;; separated by spaces. Supported forms, loosest-binding first:
;;   T1 -> T2 -> ...          function type, right-associative
;;   T1 * T2 * ...            tuple type
;;   T c1 c2 ...              postfix application, e.g. `'a list`
;;   int bool string char float unit, and 'a
;; Anything else becomes the shared placeholder tv '_unknown, so the
;; result is at worst polymorphic. That includes parenthesised or
;; multi-parameter types such as `('a, 'b) result`, bare unknown names,
;; and the empty string.
(define ocaml-hm-parse-type-src
  (fn (src)
    (let ((s (trim src)))
      (cond
        ((or (ocaml-hm-str-has-char-from? s "(" 0)
             (ocaml-hm-str-has-char-from? s ")" 0)
             (ocaml-hm-str-has-char-from? s "," 0))
         (ocaml-hm-unknown-type))
        (else
         (ocaml-hm-parse-type-tokens
           (filter (fn (p) (not (= p ""))) (split s " "))))))))

;; Process a :type-def AST (type-def NAME PARAMS CTORS): register one
;; scheme per ctor, quantified over the declared PARAMS.
;; - A nullary ctor has the declared type itself.
;; - A ctor with argument sources has type ARG -> self. A single source
;;   string is parsed as-is (it is often a combined `T1 * T2 * ...`);
;;   several source strings are modelled as a tuple argument.
(define ocaml-hm-register-type-def!
  (fn (type-def)
    (let ((name (nth type-def 1))
          (params (nth type-def 2))
          (ctors (nth type-def 3)))
      (let ((self-type (hm-con name (map hm-tv params))))
        (begin
          (define register-ctor
            (fn (ctor)
              (let ((cname (first ctor)) (arg-srcs (rest ctor)))
                (cond
                  ((= (len arg-srcs) 0)
                   (ocaml-hm-ctor-register! cname (hm-scheme params self-type)))
                  (else
                   (let ((arg-type
                          (cond
                            ((= (len arg-srcs) 1)
                             (ocaml-hm-parse-type-src (first arg-srcs)))
                            (else
                             (ocaml-hm-tuple (map ocaml-hm-parse-type-src arg-srcs))))))
                     (ocaml-hm-ctor-register!
                       cname (hm-scheme params (hm-arrow arg-type self-type)))))))))
          (for-each register-ctor ctors))))))

;; The real dispatcher: route an AST node to its rule by tag.
(set! ocaml-infer
  (fn (expr env counter)
    (let ((tag (nth expr 0)))
      (cond
        ((= tag "con")
         (let ((name (nth expr 1)))
           (cond
             ((ocaml-hm-ctor-has? name)
              {:subst {} :type (hm-instantiate (ocaml-hm-ctor-lookup name) counter)})
             (else {:subst {} :type (hm-fresh-tv counter)}))))
        ((= tag "int") {:subst {} :type (hm-int)})
        ((= tag "float") {:subst {} :type (ocaml-hm-float)})
        ((= tag "string") {:subst {} :type (hm-string)})
        ((= tag "char") {:subst {} :type (hm-string)})
        ((= tag "bool") {:subst {} :type (hm-bool)})
        ((= tag "unit") {:subst {} :type (hm-con "Unit" (list))})
        ((= tag "var") (ocaml-infer-var (nth expr 1) env counter))
        ((= tag "fun") (ocaml-infer-fun (nth expr 1) (nth expr 2) env counter))
        ((= tag "app") (ocaml-infer-app (nth expr 1) (nth expr 2) env counter))
        ((= tag "let")
         (ocaml-infer-let (nth expr 1) (nth expr 2) (nth expr 3) (nth expr 4) env counter))
        ((= tag "let-rec")
         (ocaml-infer-let-rec (nth expr 1) (nth expr 2) (nth expr 3) (nth expr 4) env counter))
        ((= tag "let-mut") (ocaml-infer-let-mut (nth expr 1) (nth expr 2) env counter))
        ((= tag "let-rec-mut") (ocaml-infer-let-rec-mut (nth expr 1) (nth expr 2) env counter))
        ((= tag "if") (ocaml-infer-if (nth expr 1) (nth expr 2) (nth expr 3) env counter))
        ((= tag "tuple") (ocaml-infer-tuple (rest expr) env counter))
        ((= tag "list") (ocaml-infer-list (rest expr) env counter))
        ((= tag "match") (ocaml-infer-match (nth expr 1) (nth expr 2) env counter))
        ((= tag "neg")
         (let ((r (ocaml-infer (nth expr 1) env counter)))
           (let ((s (get r :subst)) (t (get r :type)))
             (let ((s2 (ocaml-hm-unify t (hm-int) s)))
               {:subst s2 :type (hm-int)}))))
        ((= tag "not")
         (let ((r (ocaml-infer (nth expr 1) env counter)))
           (let ((s (get r :subst)) (t (get r :type)))
             (let ((s2 (ocaml-hm-unify t (hm-bool) s)))
               {:subst s2 :type (hm-bool)}))))
        ((= tag "op")
         ;; Treat (:op OP L R) as (:app (:app (:var OP) L) R) — same rule.
         (ocaml-infer (list :app (list :app (list :var (nth expr 1)) (nth expr 2)) (nth expr 3))
                      env counter))
        (else (error (str "ocaml-infer: unsupported tag " tag)))))))

;; Top-level convenience: parse + infer + render the type.
(define ocaml-type-of
  (fn (src)
    (let ((expr (ocaml-parse src))
          (env (ocaml-hm-builtin-env))
          (counter (ocaml-hm-counter)))
      (let ((r (ocaml-infer expr env counter)))
        (ocaml-hm-format-type (hm-apply (get r :subst) (get r :type)))))))

;; Program-level type inference. Process decls in order:
;; - register type-defs with the ctor registry;
;; - thread let-bindings into the env;
;; - return the formatted type of the last binding or expression processed.
(define ocaml-type-of-program
  (fn (src)
    (let ((prog (ocaml-parse-program src))
          (env (ocaml-hm-builtin-env))
          (counter (ocaml-hm-counter))
          (last-type (hm-tv "?")))
      (begin
        (define run-decl
          (fn (decl)
            (let ((tag (nth decl 0)))
              (cond
                ((= tag "type-def") (ocaml-hm-register-type-def! decl))
                ((= tag "exception-def") nil)
                ((= tag "def")
                 (let ((nm (nth decl 1)) (ps (nth decl 2)) (rh (nth decl 3)))
                   (let ((r (ocaml-infer (ocaml-hm-rhs-expr ps rh) env counter)))
                     (let ((s (get r :subst)) (t (get r :type)))
                       (let ((env2 (hm-apply-env s env)))
                         (begin
                           (set! env (assoc env2 nm (hm-generalize t env2)))
                           (set! last-type (hm-apply s t))))))))
                ((= tag "def-rec")
                 (let ((nm (nth decl 1)) (ps (nth decl 2)) (rh (nth decl 3)))
                   (let ((rec-tv (hm-fresh-tv counter)))
                     (let ((r (ocaml-infer (ocaml-hm-rhs-expr ps rh)
                                           (assoc env nm (hm-monotype rec-tv))
                                           counter)))
                       (let ((s (get r :subst)) (t (get r :type)))
                         (let ((s2 (ocaml-hm-unify (hm-apply s rec-tv) t s)))
                           (let ((env2 (hm-apply-env s2 env)) (t-final (hm-apply s2 t)))
                             (begin
                               (set! env (assoc env2 nm (hm-generalize t-final env2)))
                               ;; Report the type under the final substitution
                               ;; s2, not the pre-unification t.
                               (set! last-type t-final)))))))))
                ((= tag "expr")
                 (let ((r (ocaml-infer (nth decl 1) env counter)))
                   (set! last-type (hm-apply (get r :subst) (get r :type)))))
                ((= tag "def-mut")
                 ;; let x = e and y = e' (top level, no rec). Every rhs is
                 ;; inferred in the env from before the group, as in OCaml;
                 ;; the new names become visible together afterwards.
                 (let ((bindings (nth decl 1)) (env-rhs env) (env-body env))
                   (begin
                     (define one
                       (fn (b)
                         (let ((nm (nth b 0)) (ps (nth b 1)) (rh (nth b 2)))
                           (let ((r (ocaml-infer (ocaml-hm-rhs-expr ps rh) env-rhs counter)))
                             (let ((s (get r :subst)) (t (get r :type)))
                               (begin
                                 (set! env-rhs (hm-apply-env s env-rhs))
                                 (set! env-body
                                       (assoc (hm-apply-env s env-body) nm
                                              (hm-generalize t env-rhs)))
                                 (set! last-type (hm-apply s t))))))))
                     (define loop
                       (fn (xs)
                         (when (not (= xs (list)))
                           (begin (one (first xs)) (loop (rest xs))))))
                     (loop bindings)
                     (set! env env-body))))
                ((= tag "def-rec-mut")
                 ;; let rec f = ... and g = ... — mutual recursion at top level.
                 (let ((bindings (nth decl 1)) (tvs (list)) (env-rec env))
                   (begin
                     (define alloc
                       (fn (xs)
                         (when (not (= xs (list)))
                           (let ((nm (nth (first xs) 0)) (tv (hm-fresh-tv counter)))
                             (begin
                               (append! tvs tv)
                               (set! env-rec (assoc env-rec nm (hm-monotype tv)))
                               (alloc (rest xs)))))))
                     (alloc bindings)
                     (let ((subst {}) (idx 0))
                       (begin
                         (define infer-one
                           (fn (b)
                             ;; Apply the substitution so far to the joint env
                             ;; so constraints from earlier rhs reach this one.
                             (let ((r (ocaml-infer (ocaml-hm-rhs-expr (nth b 1) (nth b 2))
                                                   (hm-apply-env subst env-rec)
                                                   counter)))
                               (let ((s (get r :subst)) (t (get r :type)))
                                 (let ((s2 (ocaml-hm-unify (hm-apply s (nth tvs idx)) t
                                                           (hm-compose s subst))))
                                   (begin
                                     (set! subst s2)
                                     (set! idx (+ idx 1))
                                     (set! last-type (hm-apply s2 t))))))))
                         (define loop2
                           (fn (xs)
                             (when (not (= xs (list)))
                               (begin (infer-one (first xs)) (loop2 (rest xs))))))
                         (loop2 bindings)
                         (set! env (hm-apply-env subst env))
                         (set! idx 0)
                         (define gen-one
                           (fn (b)
                             (let ((scheme (hm-generalize (hm-apply subst (nth tvs idx)) env)))
                               (begin
                                 (set! env (assoc env (nth b 0) scheme))
                                 (set! idx (+ idx 1))))))
                         (define loop3
                           (fn (xs)
                             (when (not (= xs (list)))
                               (begin (gen-one (first xs)) (loop3 (rest xs))))))
                         (loop3 bindings))))))
                (else nil)))))
        (define loop
          (fn (xs)
            (when (not (= xs (list)))
              (begin (run-decl (first xs)) (loop (rest xs))))))
        (loop (rest prog))
        (ocaml-hm-format-type last-type)))))

;; True for types that need parentheses when they appear as a tuple
;; component or as the argument of a postfix type constructor.
(define ocaml-hm-compound-type?
  (fn (t)
    (and (is-ctor? t)
         (or (= (ctor-head t) "->") (= (ctor-head t) "*")))))

;; Format T, wrapping it in parentheses when it is compound.
(define ocaml-hm-format-type-operand
  (fn (t)
    (cond
      ((ocaml-hm-compound-type? t) (str "(" (ocaml-hm-format-type t) ")"))
      (else (ocaml-hm-format-type t)))))

;; Pretty-print a type as an OCaml-style string, for ocaml-type-of and
;; ocaml-type-of-program.
;; - Type variables print as 'name.
;; - `->` and `*` print infix.
;; - Any other ctor prints postfix: `T name` for one argument,
;;   `(T1, T2) name` for several, and bare `name` for none. So primitives
;;   print as their ctor head, e.g. "Float".
;; - Arrow arguments that are arrows, and tuple components or ctor
;;   arguments that are arrows or tuples, are parenthesised so the output
;;   re-reads correctly under OCaml precedence.
(define ocaml-hm-format-type
  (fn (t)
    (cond
      ((is-var? t) (str "'" (var-name t)))
      ((is-ctor? t)
       (let ((head (ctor-head t)) (args (ctor-args t)))
         (cond
           ((= head "->")
            (let ((a (nth args 0)) (b (nth args 1)))
              (str (cond ((and (is-ctor? a) (= (ctor-head a) "->"))
                          (str "(" (ocaml-hm-format-type a) ")"))
                         (else (ocaml-hm-format-type a)))
                   " -> "
                   (ocaml-hm-format-type b))))
           ((= head "*")
            (join " * " (map ocaml-hm-format-type-operand args)))
           ((= (len args) 0) head)
           ((= (len args) 1)
            (str (ocaml-hm-format-type-operand (nth args 0)) " " head))
           (else
            (str "(" (join ", " (map ocaml-hm-format-type args)) ") " head)))))
      (else (str t)))))