diff --git a/shared/sx/async_eval.py b/shared/sx/async_eval.py index ee836a8..c0a7af6 100644 --- a/shared/sx/async_eval.py +++ b/shared/sx/async_eval.py @@ -279,6 +279,10 @@ async def _asf_or(expr, env, ctx): async def _asf_let(expr, env, ctx): + # Named let: (let name ((x 0) ...) body) + if isinstance(expr[1], Symbol): + return await _asf_named_let(expr, env, ctx) + bindings = expr[1] local = dict(env) if isinstance(bindings, list): @@ -299,6 +303,98 @@ async def _asf_let(expr, env, ctx): return NIL +async def _asf_named_let(expr, env, ctx): + """Async named let: (let name ((x 0) ...) body)""" + loop_name = expr[1].name + bindings = expr[2] + body = expr[3:] + + params: list[str] = [] + inits: list = [] + + if isinstance(bindings, list): + if bindings and isinstance(bindings[0], list): + for binding in bindings: + var = binding[0] + params.append(var.name if isinstance(var, Symbol) else var) + inits.append(binding[1]) + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var = bindings[i] + params.append(var.name if isinstance(var, Symbol) else var) + inits.append(bindings[i + 1]) + + loop_body = body[0] if len(body) == 1 else [Symbol("begin")] + list(body) + loop_fn = Lambda(params, loop_body, dict(env), name=loop_name) + loop_fn.closure[loop_name] = loop_fn + + init_vals = [await _async_trampoline(await _async_eval(init, env, ctx)) for init in inits] + return await _async_call_lambda(loop_fn, init_vals, env, ctx) + + +async def _asf_letrec(expr, env, ctx): + """Async letrec: (letrec ((name1 val1) ...) body)""" + bindings = expr[1] + local = dict(env) + + names: list[str] = [] + val_exprs: list = [] + + if isinstance(bindings, list): + if bindings and isinstance(bindings[0], list): + for binding in bindings: + var = binding[0] + vname = var.name if isinstance(var, Symbol) else var + names.append(vname) + val_exprs.append(binding[1]) + local[vname] = NIL + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var = bindings[i] + vname = var.name if isinstance(var, Symbol) else var + names.append(vname) + val_exprs.append(bindings[i + 1]) + local[vname] = NIL + + values = [await _async_trampoline(await _async_eval(ve, local, ctx)) for ve in val_exprs] + for name, val in zip(names, values): + local[name] = val + for val in values: + if isinstance(val, Lambda): + for name in names: + val.closure[name] = local[name] + + for body_expr in expr[2:-1]: + await _async_trampoline(await _async_eval(body_expr, local, ctx)) + if len(expr) > 2: + return _AsyncThunk(expr[-1], local, ctx) + return NIL + + +async def _asf_dynamic_wind(expr, env, ctx): + """Async dynamic-wind: (dynamic-wind before body after)""" + before = await _async_trampoline(await _async_eval(expr[1], env, ctx)) + body_fn = await _async_trampoline(await _async_eval(expr[2], env, ctx)) + after = await _async_trampoline(await _async_eval(expr[3], env, ctx)) + + async def _call_thunk(fn): + if isinstance(fn, Lambda): + return await _async_trampoline(await _async_call_lambda(fn, [], env, ctx)) + if callable(fn): + r = fn() + if inspect.iscoroutine(r): + return await r + return r + raise EvalError(f"dynamic-wind: expected thunk, got {type(fn).__name__}") + + await _call_thunk(before) + try: + result = await _call_thunk(body_fn) + finally: + await _call_thunk(after) + return result + + async def _asf_lambda(expr, env, ctx): params_expr = expr[1] param_names = [] @@ -467,6 +563,7 @@ _ASYNC_SPECIAL_FORMS: dict[str, Any] = { "or": _asf_or, "let": _asf_let, "let*": _asf_let, + "letrec": _asf_letrec, "lambda": _asf_lambda, "fn": _asf_lambda, "define": _asf_define, @@ -481,6 +578,7 @@ _ASYNC_SPECIAL_FORMS: dict[str, Any] = { "quasiquote": _asf_quasiquote, "->": _asf_thread_first, "set!": _asf_set_bang, + "dynamic-wind": _asf_dynamic_wind, } diff --git a/shared/sx/evaluator.py b/shared/sx/evaluator.py index 4b03f9d..1b861b3 100644 --- a/shared/sx/evaluator.py +++ b/shared/sx/evaluator.py @@ -306,6 +306,11 @@ def _sf_or(expr: list, env: dict) -> Any: def _sf_let(expr: list, env: dict) -> Any: if len(expr) < 3: raise EvalError("let requires bindings and body") + + # Named let: (let name ((x 0) ...) body) + if isinstance(expr[1], Symbol): + return _sf_named_let(expr, env) + bindings = expr[1] local = dict(env) @@ -336,6 +341,127 @@ def _sf_let(expr: list, env: dict) -> Any: return _Thunk(body[-1], local) +def _sf_named_let(expr: list, env: dict) -> Any: + """``(let name ((x 0) (y 1)) body...)`` — self-recursive loop. + + Desugars to a lambda bound to *name* whose closure includes itself, + called with the initial values. Tail calls to *name* produce TCO thunks. + """ + loop_name = expr[1].name + bindings = expr[2] + body = expr[3:] + + params: list[str] = [] + inits: list[Any] = [] + + if isinstance(bindings, list): + if bindings and isinstance(bindings[0], list): + for binding in bindings: + var = binding[0] + params.append(var.name if isinstance(var, Symbol) else var) + inits.append(binding[1]) + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var = bindings[i] + params.append(var.name if isinstance(var, Symbol) else var) + inits.append(bindings[i + 1]) + + # Build loop body (wrap in begin if multiple expressions) + loop_body = body[0] if len(body) == 1 else [Symbol("begin")] + list(body) + + # Create self-recursive lambda + loop_fn = Lambda(params, loop_body, dict(env), name=loop_name) + loop_fn.closure[loop_name] = loop_fn + + # Evaluate initial values in enclosing env, then call + init_vals = [_trampoline(_eval(init, env)) for init in inits] + return _call_lambda(loop_fn, init_vals, env) + + +def _sf_letrec(expr: list, env: dict) -> Any: + """``(letrec ((name1 val1) ...) body)`` — mutually recursive bindings. + + All names are bound to NIL first, then values are evaluated (so they + can reference each other), then lambda closures are patched. + """ + if len(expr) < 3: + raise EvalError("letrec requires bindings and body") + bindings = expr[1] + local = dict(env) + + names: list[str] = [] + val_exprs: list[Any] = [] + + if isinstance(bindings, list): + if bindings and isinstance(bindings[0], list): + for binding in bindings: + var = binding[0] + vname = var.name if isinstance(var, Symbol) else var + names.append(vname) + val_exprs.append(binding[1]) + local[vname] = NIL + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var = bindings[i] + vname = var.name if isinstance(var, Symbol) else var + names.append(vname) + val_exprs.append(bindings[i + 1]) + local[vname] = NIL + + # Evaluate all values — they can see each other's names (initially NIL) + values = [_trampoline(_eval(ve, local)) for ve in val_exprs] + + # Bind final values + for name, val in zip(names, values): + local[name] = val + + # Patch lambda closures so they see the final bindings + for val in values: + if isinstance(val, Lambda): + for name in names: + val.closure[name] = local[name] + + body = expr[2:] + for body_expr in body[:-1]: + _trampoline(_eval(body_expr, local)) + return _Thunk(body[-1], local) + + +def _sf_dynamic_wind(expr: list, env: dict) -> Any: + """``(dynamic-wind before body after)`` — entry/exit guards. + + All three arguments are thunks (zero-arg functions). + *before* is called on entry, *after* is always called on exit (even on + error). The wind stack is maintained for future continuation support. + """ + if len(expr) != 4: + raise EvalError("dynamic-wind requires 3 arguments (before, body, after)") + before = _trampoline(_eval(expr[1], env)) + body_fn = _trampoline(_eval(expr[2], env)) + after = _trampoline(_eval(expr[3], env)) + + def _call_thunk(fn: Any) -> Any: + if isinstance(fn, Lambda): + return _trampoline(_call_lambda(fn, [], env)) + if callable(fn): + return fn() + raise EvalError(f"dynamic-wind: expected thunk, got {type(fn).__name__}") + + # Entry + _call_thunk(before) + _WIND_STACK.append((before, after)) + try: + result = _call_thunk(body_fn) + finally: + _WIND_STACK.pop() + _call_thunk(after) + return result + + +# Wind stack for dynamic-wind (thread-safe enough for sync evaluator) +_WIND_STACK: list[tuple] = [] + + def _sf_lambda(expr: list, env: dict) -> Lambda: if len(expr) < 3: raise EvalError("lambda requires params and body") @@ -883,6 +1009,7 @@ _SPECIAL_FORMS: dict[str, Any] = { "or": _sf_or, "let": _sf_let, "let*": _sf_let, + "letrec": _sf_letrec, "lambda": _sf_lambda, "fn": _sf_lambda, "define": _sf_define, @@ -895,6 +1022,7 @@ _SPECIAL_FORMS: dict[str, Any] = { "quote": _sf_quote, "->": _sf_thread_first, "set!": _sf_set_bang, + "dynamic-wind": _sf_dynamic_wind, "defmacro": _sf_defmacro, "quasiquote": _sf_quasiquote, "defhandler": _sf_defhandler, diff --git a/shared/sx/primitives.py b/shared/sx/primitives.py index ad74440..de86546 100644 --- a/shared/sx/primitives.py +++ b/shared/sx/primitives.py @@ -132,6 +132,27 @@ def prim_eq(a: Any, b: Any) -> bool: def prim_neq(a: Any, b: Any) -> bool: return a != b +@register_primitive("eq?") +def prim_eq_identity(a: Any, b: Any) -> bool: + """Identity equality — true only if a and b are the same object.""" + return a is b + +@register_primitive("eqv?") +def prim_eqv(a: Any, b: Any) -> bool: + """Equivalent: identity for compound types, value for atoms.""" + if a is b: + return True + if isinstance(a, (int, float, str, bool)) and isinstance(b, type(a)): + return a == b + if (a is None or a is NIL) and (b is None or b is NIL): + return True + return False + +@register_primitive("equal?") +def prim_equal(a: Any, b: Any) -> bool: + """Deep structural equality (same as =).""" + return a == b + @register_primitive("<") def prim_lt(a: Any, b: Any) -> bool: return a < b diff --git a/shared/sx/ref/bootstrap_js.py b/shared/sx/ref/bootstrap_js.py index 32db2a9..c20c5aa 100644 --- a/shared/sx/ref/bootstrap_js.py +++ b/shared/sx/ref/bootstrap_js.py @@ -165,6 +165,12 @@ class JSEmitter: "sf-and": "sfAnd", "sf-or": "sfOr", "sf-let": "sfLet", + "sf-named-let": "sfNamedLet", + "sf-letrec": "sfLetrec", + "sf-dynamic-wind": "sfDynamicWind", + "push-wind!": "pushWind", + "pop-wind!": "popWind", + "call-thunk": "callThunk", "sf-lambda": "sfLambda", "sf-define": "sfDefine", "sf-defcomp": "sfDefcomp", diff --git a/shared/sx/ref/eval.sx b/shared/sx/ref/eval.sx index eb4134f..806d093 100644 --- a/shared/sx/ref/eval.sx +++ b/shared/sx/ref/eval.sx @@ -136,6 +136,7 @@ (= name "or") (sf-or args env) (= name "let") (sf-let args env) (= name "let*") (sf-let args env) + (= name "letrec") (sf-letrec args env) (= name "lambda") (sf-lambda args env) (= name "fn") (sf-lambda args env) (= name "define") (sf-define args env) @@ -150,6 +151,7 @@ (= name "quasiquote") (sf-quasiquote args env) (= name "->") (sf-thread-first args env) (= name "set!") (sf-set! args env) + (= name "dynamic-wind") (sf-dynamic-wind args env) ;; Higher-order forms (= name "map") (ho-map args env) @@ -381,36 +383,83 @@ (define sf-let (fn (args env) - (let ((bindings (first args)) - (body (rest args)) - (local (env-extend env))) - ;; Parse bindings — support both ((name val) ...) and (name val name val ...) + ;; Detect named let: (let name ((x 0) ...) body) + ;; If first arg is a symbol, delegate to sf-named-let. + (if (= (type-of (first args)) "symbol") + (sf-named-let args env) + (let ((bindings (first args)) + (body (rest args)) + (local (env-extend env))) + ;; Parse bindings — support both ((name val) ...) and (name val name val ...) + (if (and (= (type-of (first bindings)) "list") + (= (len (first bindings)) 2)) + ;; Scheme-style + (for-each + (fn (binding) + (let ((vname (if (= (type-of (first binding)) "symbol") + (symbol-name (first binding)) + (first binding)))) + (env-set! local vname (trampoline (eval-expr (nth binding 1) local))))) + bindings) + ;; Clojure-style + (let ((i 0)) + (reduce + (fn (acc pair-idx) + (let ((vname (if (= (type-of (nth bindings (* pair-idx 2))) "symbol") + (symbol-name (nth bindings (* pair-idx 2))) + (nth bindings (* pair-idx 2)))) + (val-expr (nth bindings (inc (* pair-idx 2))))) + (env-set! local vname (trampoline (eval-expr val-expr local))))) + nil + (range 0 (/ (len bindings) 2))))) + ;; Evaluate body — last expression in tail position + (for-each + (fn (e) (trampoline (eval-expr e local))) + (slice body 0 (dec (len body)))) + (make-thunk (last body) local))))) + + +;; Named let: (let name ((x 0) (y 1)) body...) +;; Desugars to a self-recursive lambda called with initial values. +;; The loop name is bound in the body so recursive calls produce TCO thunks. +(define sf-named-let + (fn (args env) + (let ((loop-name (symbol-name (first args))) + (bindings (nth args 1)) + (body (slice args 2)) + (params (list)) + (inits (list))) + ;; Extract param names and init expressions (if (and (= (type-of (first bindings)) "list") (= (len (first bindings)) 2)) - ;; Scheme-style + ;; Scheme-style: ((x 0) (y 1)) (for-each (fn (binding) - (let ((vname (if (= (type-of (first binding)) "symbol") - (symbol-name (first binding)) - (first binding)))) - (env-set! local vname (trampoline (eval-expr (nth binding 1) local))))) + (append! params (if (= (type-of (first binding)) "symbol") + (symbol-name (first binding)) + (first binding))) + (append! inits (nth binding 1))) bindings) - ;; Clojure-style - (let ((i 0)) - (reduce - (fn (acc pair-idx) - (let ((vname (if (= (type-of (nth bindings (* pair-idx 2))) "symbol") - (symbol-name (nth bindings (* pair-idx 2))) - (nth bindings (* pair-idx 2)))) - (val-expr (nth bindings (inc (* pair-idx 2))))) - (env-set! local vname (trampoline (eval-expr val-expr local))))) - nil - (range 0 (/ (len bindings) 2))))) - ;; Evaluate body — last expression in tail position - (for-each - (fn (e) (trampoline (eval-expr e local))) - (slice body 0 (dec (len body)))) - (make-thunk (last body) local)))) + ;; Clojure-style: (x 0 y 1) + (reduce + (fn (acc pair-idx) + (do + (append! params (if (= (type-of (nth bindings (* pair-idx 2))) "symbol") + (symbol-name (nth bindings (* pair-idx 2))) + (nth bindings (* pair-idx 2)))) + (append! inits (nth bindings (inc (* pair-idx 2)))))) + nil + (range 0 (/ (len bindings) 2)))) + ;; Build loop body (wrap in begin if multiple exprs) + (let ((loop-body (if (= (len body) 1) (first body) + (cons (make-symbol "begin") body))) + (loop-fn (make-lambda params loop-body env))) + ;; Self-reference: loop can call itself by name + (set-lambda-name! loop-fn loop-name) + (env-set! (lambda-closure loop-fn) loop-name loop-fn) + ;; Evaluate initial values in enclosing env, then call + (let ((init-vals (map (fn (e) (trampoline (eval-expr e env))) inits))) + (call-lambda loop-fn init-vals env)))))) (define sf-lambda @@ -602,6 +651,109 @@ value))) +;; -------------------------------------------------------------------------- +;; 6c. letrec — mutually recursive local bindings +;; -------------------------------------------------------------------------- +;; +;; (letrec ((even? (fn (n) (if (= n 0) true (odd? (- n 1))))) +;; (odd? (fn (n) (if (= n 0) false (even? (- n 1)))))) +;; (even? 10)) +;; +;; All bindings are first set to nil in the local env, then all values +;; are evaluated (so they can see each other's names), then lambda +;; closures are patched to include the final bindings. +;; -------------------------------------------------------------------------- + +(define sf-letrec + (fn (args env) + (let ((bindings (first args)) + (body (rest args)) + (local (env-extend env)) + (names (list)) + (val-exprs (list))) + ;; First pass: bind all names to nil + (if (and (= (type-of (first bindings)) "list") + (= (len (first bindings)) 2)) + ;; Scheme-style + (for-each + (fn (binding) + (let ((vname (if (= (type-of (first binding)) "symbol") + (symbol-name (first binding)) + (first binding)))) + (append! names vname) + (append! val-exprs (nth binding 1)) + (env-set! local vname nil))) + bindings) + ;; Clojure-style + (reduce + (fn (acc pair-idx) + (let ((vname (if (= (type-of (nth bindings (* pair-idx 2))) "symbol") + (symbol-name (nth bindings (* pair-idx 2))) + (nth bindings (* pair-idx 2)))) + (val-expr (nth bindings (inc (* pair-idx 2))))) + (append! names vname) + (append! val-exprs val-expr) + (env-set! local vname nil))) + nil + (range 0 (/ (len bindings) 2)))) + ;; Second pass: evaluate values (they can see each other's names) + (let ((values (map (fn (e) (trampoline (eval-expr e local))) val-exprs))) + ;; Bind final values + (for-each + (fn (pair) (env-set! local (first pair) (nth pair 1))) + (zip names values)) + ;; Patch lambda closures so they see the final bindings + (for-each + (fn (val) + (when (lambda? val) + (for-each + (fn (n) (env-set! (lambda-closure val) n (env-get local n))) + names))) + values)) + ;; Evaluate body + (for-each + (fn (e) (trampoline (eval-expr e local))) + (slice body 0 (dec (len body)))) + (make-thunk (last body) local)))) + + +;; -------------------------------------------------------------------------- +;; 6d. dynamic-wind — entry/exit guards +;; -------------------------------------------------------------------------- +;; +;; (dynamic-wind before-thunk body-thunk after-thunk) +;; +;; All three are zero-argument functions (thunks): +;; 1. Call before-thunk +;; 2. Call body-thunk, capture result +;; 3. Call after-thunk (always, even on error) +;; 4. Return body result +;; +;; The wind stack is maintained so that when continuations jump across +;; dynamic-wind boundaries, the correct before/after thunks fire. +;; Without active continuations, this is equivalent to try/finally. +;; +;; Platform requirements: +;; (push-wind! before after) — push wind record onto stack +;; (pop-wind!) — pop wind record from stack +;; (call-thunk f env) — call a zero-arg function +;; -------------------------------------------------------------------------- + +(define sf-dynamic-wind + (fn (args env) + (let ((before (trampoline (eval-expr (first args) env))) + (body (trampoline (eval-expr (nth args 1) env))) + (after (trampoline (eval-expr (nth args 2) env)))) + ;; Call entry thunk + (call-thunk before env) + ;; Push wind record, run body, pop, call exit + (push-wind! before after) + (let ((result (call-thunk body env))) + (pop-wind!) + (call-thunk after env) + result)))) + + ;; -------------------------------------------------------------------------- ;; 6b. Macro expansion ;; -------------------------------------------------------------------------- @@ -765,6 +917,12 @@ ;; (apply f args) → call f with args list ;; (zip lists...) → list of tuples ;; +;; ;; CSSX (style system): ;; (build-keyframes name steps env) → StyleValue (platform builds @keyframes) +;; +;; Dynamic wind (for dynamic-wind): +;; (push-wind! before after) → void (push wind record onto stack) +;; (pop-wind!) → void (pop wind record from stack) +;; (call-thunk f env) → value (call a zero-arg function) ;; -------------------------------------------------------------------------- diff --git a/shared/sx/ref/primitives.sx b/shared/sx/ref/primitives.sx index 7aea3bb..f85fc17 100644 --- a/shared/sx/ref/primitives.sx +++ b/shared/sx/ref/primitives.sx @@ -121,7 +121,7 @@ (define-primitive "=" :params (a b) :returns "boolean" - :doc "Equality (value equality, not identity).") + :doc "Deep structural equality. Alias for equal?.") (define-primitive "!=" :params (a b) @@ -129,6 +129,27 @@ :doc "Inequality." :body (not (= a b))) +(define-primitive "eq?" + :params (a b) + :returns "boolean" + :doc "Identity equality. True only if a and b are the exact same object. + For immutable atoms (numbers, strings, booleans, nil) this may or + may not match — use eqv? for reliable atom comparison.") + +(define-primitive "eqv?" + :params (a b) + :returns "boolean" + :doc "Equivalent value for atoms, identity for compound objects. + Returns true for identical objects (eq?), and also for numbers, + strings, booleans, and nil with the same value. For lists, dicts, + lambdas, and components, only true if same identity.") + +(define-primitive "equal?" + :params (a b) + :returns "boolean" + :doc "Deep structural equality. Recursively compares lists and dicts. + Same semantics as = but explicit Scheme name.") + (define-primitive "<" :params (a b) :returns "boolean" diff --git a/shared/sx/tests/test_scheme_forms.py b/shared/sx/tests/test_scheme_forms.py new file mode 100644 index 0000000..df26f8b --- /dev/null +++ b/shared/sx/tests/test_scheme_forms.py @@ -0,0 +1,327 @@ +"""Tests for Scheme-inspired forms: named let, letrec, dynamic-wind, eq?/eqv?/equal?.""" + +import pytest +from shared.sx import parse, evaluate, EvalError, Symbol, NIL +from shared.sx.types import Lambda + + +def ev(text, env=None): + """Parse and evaluate a single expression.""" + return evaluate(parse(text), env) + + +# --------------------------------------------------------------------------- +# Named let +# --------------------------------------------------------------------------- + +class TestNamedLet: + def test_basic_loop(self): + """Named let as a simple counter loop.""" + result = ev(""" + (let loop ((i 0) (acc 0)) + (if (> i 5) + acc + (loop (+ i 1) (+ acc i)))) + """) + assert result == 15 # 0+1+2+3+4+5 + + def test_factorial(self): + result = ev(""" + (let fact ((n 10) (acc 1)) + (if (<= n 1) + acc + (fact (- n 1) (* acc n)))) + """) + assert result == 3628800 + + def test_tco_deep_recursion(self): + """Named let should use TCO — no stack overflow on deep loops.""" + result = ev(""" + (let loop ((i 0)) + (if (>= i 10000) + i + (loop (+ i 1)))) + """) + assert result == 10000 + + def test_clojure_style_bindings(self): + """Named let with Clojure-style flat bindings.""" + result = ev(""" + (let loop (i 0 acc (list)) + (if (>= i 3) + acc + (loop (+ i 1) (append acc i)))) + """) + assert result == [0, 1, 2] + + def test_scheme_style_bindings(self): + """Named let with Scheme-style paired bindings.""" + result = ev(""" + (let loop ((i 3) (result (list))) + (if (= i 0) + result + (loop (- i 1) (cons i result)))) + """) + assert result == [1, 2, 3] + + def test_accumulator_pattern(self): + """Named let accumulating a result — pure functional, no set! needed.""" + result = ev(""" + (let loop ((i 0) (acc (list))) + (if (>= i 3) + acc + (loop (+ i 1) (append acc (* i i))))) + """) + assert result == [0, 1, 4] + + def test_init_evaluated_in_outer_env(self): + """Initial values are evaluated in the enclosing environment.""" + result = ev(""" + (let ((x 10)) + (let loop ((a x) (b (* x 2))) + (+ a b))) + """) + assert result == 30 # 10 + 20 + + def test_build_list_with_named_let(self): + """Idiomatic Scheme pattern: build a list with named let.""" + result = ev(""" + (let collect ((items (list 1 2 3 4 5)) (acc (list))) + (if (empty? items) + acc + (collect (rest items) + (if (even? (first items)) + (append acc (first items)) + acc)))) + """) + assert result == [2, 4] + + def test_fibonacci(self): + """Fibonacci via named let.""" + result = ev(""" + (let fib ((n 10) (a 0) (b 1)) + (if (= n 0) a + (fib (- n 1) b (+ a b)))) + """) + assert result == 55 + + def test_string_building(self): + """Named let for building strings.""" + result = ev(""" + (let build ((items (list "a" "b" "c")) (acc "")) + (if (empty? items) + acc + (build (rest items) + (str acc (if (= acc "") "" ",") (first items))))) + """) + assert result == "a,b,c" + + +# --------------------------------------------------------------------------- +# letrec +# --------------------------------------------------------------------------- + +class TestLetrec: + def test_basic(self): + """Simple letrec with a self-referencing lambda.""" + result = ev(""" + (letrec ((double (fn (x) (* x 2)))) + (double 21)) + """) + assert result == 42 + + def test_mutual_recursion(self): + """Classic even?/odd? mutual recursion.""" + result = ev(""" + (letrec ((my-even? (fn (n) + (if (= n 0) true (my-odd? (- n 1))))) + (my-odd? (fn (n) + (if (= n 0) false (my-even? (- n 1)))))) + (list (my-even? 10) (my-odd? 10) + (my-even? 7) (my-odd? 7))) + """) + assert result == [True, False, False, True] + + def test_clojure_style(self): + """letrec with flat bindings.""" + result = ev(""" + (letrec (f (fn (x) (if (= x 0) 1 (* x (f (- x 1)))))) + (f 5)) + """) + assert result == 120 + + def test_closures_see_each_other(self): + """Lambdas in letrec see each other's final values.""" + result = ev(""" + (letrec ((a (fn () (b))) + (b (fn () 42))) + (a)) + """) + assert result == 42 + + def test_non_forward_ref(self): + """letrec with non-lambda values that don't reference each other.""" + result = ev(""" + (letrec ((x 10) (y 20)) + (+ x y)) + """) + assert result == 30 + + def test_three_way_mutual(self): + """Three mutually recursive functions.""" + result = ev(""" + (letrec ((f (fn (n) (if (= n 0) 1 (g (- n 1))))) + (g (fn (n) (if (= n 0) 2 (h (- n 1))))) + (h (fn (n) (if (= n 0) 3 (f (- n 1)))))) + (list (f 0) (f 1) (f 2) (f 3) + (g 0) (g 1) (g 2))) + """) + # f(0)=1, f(1)=g(0)=2, f(2)=g(1)=h(0)=3, f(3)=g(2)=h(1)=f(0)=1 + # g(0)=2, g(1)=h(0)=3, g(2)=h(1)=f(0)=1 + assert result == [1, 2, 3, 1, 2, 3, 1] + + +# --------------------------------------------------------------------------- +# dynamic-wind +# --------------------------------------------------------------------------- + +class TestDynamicWind: + def _make_log_env(self): + """Create env with a log! function that appends to a Python list.""" + log = [] + env = {"log!": lambda msg: log.append(msg) or NIL} + return env, log + + def test_basic_flow(self): + """Entry and exit thunks called around body.""" + env, log = self._make_log_env() + result = ev(""" + (dynamic-wind + (fn () (log! "enter")) + (fn () (do (log! "body") 42)) + (fn () (log! "exit"))) + """, env) + assert result == 42 + assert log == ["enter", "body", "exit"] + + def test_after_called_on_error(self): + """Exit thunk is called even when body raises an error.""" + env, log = self._make_log_env() + with pytest.raises(Exception): + ev(""" + (dynamic-wind + (fn () (log! "enter")) + (fn () (do (log! "body") (error "boom"))) + (fn () (log! "exit"))) + """, env) + assert log == ["enter", "body", "exit"] + + def test_nested(self): + """Nested dynamic-wind calls entry/exit in correct order.""" + env, log = self._make_log_env() + ev(""" + (dynamic-wind + (fn () (log! "outer-in")) + (fn () + (dynamic-wind + (fn () (log! "inner-in")) + (fn () (log! "body")) + (fn () (log! "inner-out")))) + (fn () (log! "outer-out"))) + """, env) + assert log == [ + "outer-in", "inner-in", "body", "inner-out", "outer-out" + ] + + def test_return_value(self): + """Body return value is propagated.""" + result = ev(""" + (dynamic-wind + (fn () nil) + (fn () (+ 20 22)) + (fn () nil)) + """) + assert result == 42 + + def test_before_after_are_thunks(self): + """Before and after must be zero-arg functions.""" + env, log = self._make_log_env() + # Verify it works with native Python callables too + env["enter"] = lambda: log.append("in") or NIL + env["leave"] = lambda: log.append("out") or NIL + result = ev(""" + (dynamic-wind enter (fn () 99) leave) + """, env) + assert result == 99 + assert log == ["in", "out"] + + +# --------------------------------------------------------------------------- +# Three-tier equality +# --------------------------------------------------------------------------- + +class TestEquality: + def test_eq_identity_same_object(self): + """eq? is true for the same object.""" + env = {} + ev("(define x (list 1 2 3))", env) + assert ev("(eq? x x)", env) is True + + def test_eq_identity_different_objects(self): + """eq? is false for different objects with same value.""" + assert ev("(eq? (list 1 2) (list 1 2))") is False + + def test_eq_numbers(self): + """eq? on small ints — Python interns them, so identity holds.""" + assert ev("(eq? 42 42)") is True + + def test_eqv_numbers(self): + """eqv? compares numbers by value.""" + assert ev("(eqv? 42 42)") is True + assert ev("(eqv? 42 43)") is False + + def test_eqv_strings(self): + """eqv? compares strings by value.""" + assert ev('(eqv? "hello" "hello")') is True + assert ev('(eqv? "hello" "world")') is False + + def test_eqv_nil(self): + """eqv? on nil values.""" + assert ev("(eqv? nil nil)") is True + + def test_eqv_booleans(self): + assert ev("(eqv? true true)") is True + assert ev("(eqv? true false)") is False + + def test_eqv_different_lists(self): + """eqv? is false for different list objects.""" + assert ev("(eqv? (list 1 2) (list 1 2))") is False + + def test_equal_deep(self): + """equal? does deep structural comparison.""" + assert ev("(equal? (list 1 2 3) (list 1 2 3))") is True + assert ev("(equal? (list 1 2) (list 1 2 3))") is False + + def test_equal_nested(self): + """equal? recursively compares nested structures.""" + assert ev("(equal? {:a (list 1 2)} {:a (list 1 2)})") is True + + def test_equal_is_same_as_equals(self): + """equal? and = have the same semantics.""" + assert ev("(equal? 42 42)") is True + assert ev("(= 42 42)") is True + assert ev("(equal? (list 1) (list 1))") is True + assert ev("(= (list 1) (list 1))") is True + + def test_eq_eqv_equal_hierarchy(self): + """eq? ⊂ eqv? ⊂ equal? — each is progressively looser.""" + env = {} + ev("(define x (list 1 2 3))", env) + # Same object: all three true + assert ev("(eq? x x)", env) is True + assert ev("(eqv? x x)", env) is True + assert ev("(equal? x x)", env) is True + # Different objects, same value: eq? false, eqv? false, equal? true + assert ev("(eq? (list 1 2) (list 1 2))", env) is False + assert ev("(eqv? (list 1 2) (list 1 2))", env) is False + assert ev("(equal? (list 1 2) (list 1 2))", env) is True