Files
rose-ash/shared/sx/tests/test_scheme_forms.py
giles f34e55aa9b Add Scheme forms: named let, letrec, dynamic-wind, three-tier equality
Spec (eval.sx, primitives.sx):
- Named let: (let loop ((i 0)) body) — self-recursive lambda with TCO
- letrec: mutually recursive local bindings with closure patching
- dynamic-wind: entry/exit guards with wind stack for future continuations
- eq?/eqv?/equal?: identity, atom-value, and deep structural equality

Implementation (evaluator.py, async_eval.py, primitives.py):
- Both sync and async evaluators implement all four forms
- 33 new tests covering all forms including TCO at 10k depth

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 01:11:31 +00:00

328 lines
11 KiB
Python

"""Tests for Scheme-inspired forms: named let, letrec, dynamic-wind, eq?/eqv?/equal?."""
import pytest
from shared.sx import parse, evaluate, EvalError, Symbol, NIL
from shared.sx.types import Lambda
def ev(text, env=None):
    """Parse *text* and evaluate the resulting SX form.

    ``env`` is forwarded to :func:`evaluate` unchanged (``None`` by default).
    Several tests pass the same plain ``dict`` to successive calls, so that a
    ``define`` made in one call is visible in the next.
    """
    form = parse(text)
    return evaluate(form, env)
# ---------------------------------------------------------------------------
# Named let
# ---------------------------------------------------------------------------
class TestNamedLet:
    """Named let: ``(let NAME (BINDINGS...) BODY)``.

    Inside BODY, NAME is callable and re-enters the loop with new values.
    """

    def test_basic_loop(self):
        """A named let used as a counting loop sums 0 through 5."""
        program = """
            (let loop ((i 0) (acc 0))
              (if (> i 5)
                  acc
                  (loop (+ i 1) (+ acc i))))
        """
        assert ev(program) == 15  # 0 + 1 + 2 + 3 + 4 + 5

    def test_factorial(self):
        """Accumulator-style factorial of 10."""
        program = """
            (let fact ((n 10) (acc 1))
              (if (<= n 1)
                  acc
                  (fact (- n 1) (* acc n))))
        """
        assert ev(program) == 3628800

    def test_tco_deep_recursion(self):
        """10k self-calls must complete, which requires tail-call elimination."""
        program = """
            (let loop ((i 0))
              (if (>= i 10000)
                  i
                  (loop (+ i 1))))
        """
        assert ev(program) == 10000

    def test_clojure_style_bindings(self):
        """Flat ``name value name value`` binding vector is accepted."""
        program = """
            (let loop (i 0 acc (list))
              (if (>= i 3)
                  acc
                  (loop (+ i 1) (append acc i))))
        """
        assert ev(program) == [0, 1, 2]

    def test_scheme_style_bindings(self):
        """Paired ``((name value) ...)`` bindings are accepted."""
        program = """
            (let loop ((i 3) (result (list)))
              (if (= i 0)
                  result
                  (loop (- i 1) (cons i result))))
        """
        assert ev(program) == [1, 2, 3]

    def test_accumulator_pattern(self):
        """Build a list of squares purely through loop arguments (no set!)."""
        program = """
            (let loop ((i 0) (acc (list)))
              (if (>= i 3)
                  acc
                  (loop (+ i 1) (append acc (* i i)))))
        """
        assert ev(program) == [0, 1, 4]

    def test_init_evaluated_in_outer_env(self):
        """Initial binding values can read names from the enclosing scope."""
        # NOTE(review): this only shows that outer bindings are visible to the
        # init expressions; it would also pass if inits were evaluated
        # sequentially (let*-style). A shadowing case would tell them apart.
        program = """
            (let ((x 10))
              (let loop ((a x) (b (* x 2)))
                (+ a b)))
        """
        assert ev(program) == 30  # 10 + 20

    def test_build_list_with_named_let(self):
        """Walk a list with a named let, keeping only the even items."""
        program = """
            (let collect ((items (list 1 2 3 4 5)) (acc (list)))
              (if (empty? items)
                  acc
                  (collect (rest items)
                           (if (even? (first items))
                               (append acc (first items))
                               acc))))
        """
        assert ev(program) == [2, 4]

    def test_fibonacci(self):
        """Iterative Fibonacci: fib(10) == 55."""
        program = """
            (let fib ((n 10) (a 0) (b 1))
              (if (= n 0) a
                  (fib (- n 1) b (+ a b))))
        """
        assert ev(program) == 55

    def test_string_building(self):
        """Join strings with commas by threading an accumulator string."""
        program = """
            (let build ((items (list "a" "b" "c")) (acc ""))
              (if (empty? items)
                  acc
                  (build (rest items)
                         (str acc (if (= acc "") "" ",") (first items)))))
        """
        assert ev(program) == "a,b,c"
# ---------------------------------------------------------------------------
# letrec
# ---------------------------------------------------------------------------
class TestLetrec:
    """letrec: local bindings whose values may refer to each other or themselves."""

    def test_basic(self):
        """Smoke test: letrec with a single, non-recursive lambda binding.

        Self-reference is covered by ``test_clojure_style``; mutual reference
        by ``test_mutual_recursion`` and ``test_three_way_mutual``.
        """
        result = ev("""
            (letrec ((double (fn (x) (* x 2))))
              (double 21))
        """)
        assert result == 42

    def test_mutual_recursion(self):
        """Classic even?/odd? mutual recursion."""
        result = ev("""
            (letrec ((my-even? (fn (n)
                                 (if (= n 0) true (my-odd? (- n 1)))))
                     (my-odd? (fn (n)
                                (if (= n 0) false (my-even? (- n 1))))))
              (list (my-even? 10) (my-odd? 10)
                    (my-even? 7) (my-odd? 7)))
        """)
        assert result == [True, False, False, True]

    def test_clojure_style(self):
        """letrec with flat bindings and a self-recursive lambda (factorial)."""
        result = ev("""
            (letrec (f (fn (x) (if (= x 0) 1 (* x (f (- x 1))))))
              (f 5))
        """)
        assert result == 120

    def test_closures_see_each_other(self):
        """A lambda may call a binding defined after it in the same letrec."""
        result = ev("""
            (letrec ((a (fn () (b)))
                     (b (fn () 42)))
              (a))
        """)
        assert result == 42

    def test_non_forward_ref(self):
        """letrec with non-lambda values that don't reference each other."""
        result = ev("""
            (letrec ((x 10) (y 20))
              (+ x y))
        """)
        assert result == 30

    def test_three_way_mutual(self):
        """Three mutually recursive functions."""
        result = ev("""
            (letrec ((f (fn (n) (if (= n 0) 1 (g (- n 1)))))
                     (g (fn (n) (if (= n 0) 2 (h (- n 1)))))
                     (h (fn (n) (if (= n 0) 3 (f (- n 1))))))
              (list (f 0) (f 1) (f 2) (f 3)
                    (g 0) (g 1) (g 2)))
        """)
        # f(0)=1, f(1)=g(0)=2, f(2)=g(1)=h(0)=3, f(3)=g(2)=h(1)=f(0)=1
        # g(0)=2, g(1)=h(0)=3, g(2)=h(1)=f(0)=1
        assert result == [1, 2, 3, 1, 2, 3, 1]
# ---------------------------------------------------------------------------
# dynamic-wind
# ---------------------------------------------------------------------------
class TestDynamicWind:
    """dynamic-wind: run a before thunk, the body thunk, then an after thunk."""

    def _make_log_env(self):
        """Build an env exposing ``log!``, which records messages in a list.

        Returns:
            ``(env, log)``. ``env`` maps ``"log!"`` to a callable that appends
            its argument to ``log`` and returns ``NIL``. ``log`` is the Python
            list that collects the messages in call order.
        """
        log = []
        # list.append returns None, so `or NIL` makes log! yield SX nil.
        env = {"log!": lambda msg: log.append(msg) or NIL}
        return env, log

    def test_basic_flow(self):
        """Entry and exit thunks are called around the body."""
        env, log = self._make_log_env()
        result = ev("""
            (dynamic-wind
              (fn () (log! "enter"))
              (fn () (do (log! "body") 42))
              (fn () (log! "exit")))
        """, env)
        assert result == 42
        assert log == ["enter", "body", "exit"]

    def test_after_called_on_error(self):
        """The exit thunk runs when the body raises, and the body's error propagates."""
        env, log = self._make_log_env()
        # A bare pytest.raises(Exception) accepts any failure at all.
        # match= requires the propagated exception to be the body's own
        # (error "boom"), not an unrelated error.
        with pytest.raises(Exception, match="boom"):
            ev("""
                (dynamic-wind
                  (fn () (log! "enter"))
                  (fn () (do (log! "body") (error "boom")))
                  (fn () (log! "exit")))
            """, env)
        assert log == ["enter", "body", "exit"]

    def test_nested(self):
        """Nested dynamic-wind calls entry/exit in correct order."""
        env, log = self._make_log_env()
        ev("""
            (dynamic-wind
              (fn () (log! "outer-in"))
              (fn ()
                (dynamic-wind
                  (fn () (log! "inner-in"))
                  (fn () (log! "body"))
                  (fn () (log! "inner-out"))))
              (fn () (log! "outer-out")))
        """, env)
        assert log == [
            "outer-in", "inner-in", "body", "inner-out", "outer-out"
        ]

    def test_return_value(self):
        """The body's return value is the value of the dynamic-wind form."""
        result = ev("""
            (dynamic-wind
              (fn () nil)
              (fn () (+ 20 22))
              (fn () nil))
        """)
        assert result == 42

    def test_before_after_are_thunks(self):
        """Native Python zero-argument callables work as before/after thunks."""
        env, log = self._make_log_env()
        env["enter"] = lambda: log.append("in") or NIL
        env["leave"] = lambda: log.append("out") or NIL
        result = ev("""
            (dynamic-wind enter (fn () 99) leave)
        """, env)
        assert result == 99
        assert log == ["in", "out"]
# ---------------------------------------------------------------------------
# Three-tier equality
# ---------------------------------------------------------------------------
class TestEquality:
    """Three-tier equality: eq? (identity), eqv? (atom value), equal? (deep)."""

    @staticmethod
    def _check(cases, env=None):
        """Evaluate each ``(source, expected)`` pair in order and compare by identity.

        ``is`` (rather than ``==``) requires the result to be the actual
        ``True``/``False`` singleton, not merely a truthy or falsy value.
        """
        for source, expected in cases:
            assert ev(source, env) is expected

    def test_eq_identity_same_object(self):
        """A name compared with itself is eq?."""
        scope = {}
        ev("(define x (list 1 2 3))", scope)
        self._check([("(eq? x x)", True)], scope)

    def test_eq_identity_different_objects(self):
        """Two separately built but equal lists are not eq?."""
        self._check([("(eq? (list 1 2) (list 1 2))", False)])

    def test_eq_numbers(self):
        """eq? holds for two literal 42s.

        This relies on CPython caching small ints: presumably the parser turns
        both literals into the same cached int object. That is an interpreter
        detail, not a language guarantee.
        """
        self._check([("(eq? 42 42)", True)])

    def test_eqv_numbers(self):
        """Numbers are eqv? exactly when their values match."""
        self._check([
            ("(eqv? 42 42)", True),
            ("(eqv? 42 43)", False),
        ])

    def test_eqv_strings(self):
        """Strings are eqv? exactly when their contents match."""
        self._check([
            ('(eqv? "hello" "hello")', True),
            ('(eqv? "hello" "world")', False),
        ])

    def test_eqv_nil(self):
        """nil is eqv? to nil."""
        self._check([("(eqv? nil nil)", True)])

    def test_eqv_booleans(self):
        """Booleans are eqv? exactly when they are the same truth value."""
        self._check([
            ("(eqv? true true)", True),
            ("(eqv? true false)", False),
        ])

    def test_eqv_different_lists(self):
        """eqv? does not look inside lists: distinct list objects differ."""
        self._check([("(eqv? (list 1 2) (list 1 2))", False)])

    def test_equal_deep(self):
        """equal? compares list contents element by element."""
        self._check([
            ("(equal? (list 1 2 3) (list 1 2 3))", True),
            ("(equal? (list 1 2) (list 1 2 3))", False),
        ])

    def test_equal_nested(self):
        """equal? recurses into lists stored inside dicts."""
        self._check([("(equal? {:a (list 1 2)} {:a (list 1 2)})", True)])

    def test_equal_is_same_as_equals(self):
        """equal? agrees with = on both atoms and lists."""
        self._check([
            ("(equal? 42 42)", True),
            ("(= 42 42)", True),
            ("(equal? (list 1) (list 1))", True),
            ("(= (list 1) (list 1))", True),
        ])

    def test_eq_eqv_equal_hierarchy(self):
        """Each predicate is at least as permissive as the one before it."""
        scope = {}
        ev("(define x (list 1 2 3))", scope)
        # Same object: every tier agrees.
        # Fresh lists with equal contents: only equal? says yes.
        self._check([
            ("(eq? x x)", True),
            ("(eqv? x x)", True),
            ("(equal? x x)", True),
            ("(eq? (list 1 2) (list 1 2))", False),
            ("(eqv? (list 1 2) (list 1 2))", False),
            ("(equal? (list 1 2) (list 1 2))", True),
        ], scope)