"""Tests for Scheme-inspired forms: named let, letrec, dynamic-wind, eq?/eqv?/equal?.""" import pytest from shared.sx import parse, evaluate, EvalError, Symbol, NIL from shared.sx.types import Lambda def ev(text, env=None): """Parse and evaluate a single expression.""" return evaluate(parse(text), env) # --------------------------------------------------------------------------- # Named let # --------------------------------------------------------------------------- class TestNamedLet: def test_basic_loop(self): """Named let as a simple counter loop.""" result = ev(""" (let loop ((i 0) (acc 0)) (if (> i 5) acc (loop (+ i 1) (+ acc i)))) """) assert result == 15 # 0+1+2+3+4+5 def test_factorial(self): result = ev(""" (let fact ((n 10) (acc 1)) (if (<= n 1) acc (fact (- n 1) (* acc n)))) """) assert result == 3628800 def test_tco_deep_recursion(self): """Named let should use TCO — no stack overflow on deep loops.""" result = ev(""" (let loop ((i 0)) (if (>= i 10000) i (loop (+ i 1)))) """) assert result == 10000 def test_clojure_style_bindings(self): """Named let with Clojure-style flat bindings.""" result = ev(""" (let loop (i 0 acc (list)) (if (>= i 3) acc (loop (+ i 1) (append acc i)))) """) assert result == [0, 1, 2] def test_scheme_style_bindings(self): """Named let with Scheme-style paired bindings.""" result = ev(""" (let loop ((i 3) (result (list))) (if (= i 0) result (loop (- i 1) (cons i result)))) """) assert result == [1, 2, 3] def test_accumulator_pattern(self): """Named let accumulating a result — pure functional, no set! needed.""" result = ev(""" (let loop ((i 0) (acc (list))) (if (>= i 3) acc (loop (+ i 1) (append acc (* i i))))) """) assert result == [0, 1, 4] def test_init_evaluated_in_outer_env(self): """Initial values are evaluated in the enclosing environment.""" result = ev(""" (let ((x 10)) (let loop ((a x) (b (* x 2))) (+ a b))) """) assert result == 30 # 10 + 20 def test_build_list_with_named_let(self): """Idiomatic Scheme pattern: build a list with named let.""" result = ev(""" (let collect ((items (list 1 2 3 4 5)) (acc (list))) (if (empty? items) acc (collect (rest items) (if (even? (first items)) (append acc (first items)) acc)))) """) assert result == [2, 4] def test_fibonacci(self): """Fibonacci via named let.""" result = ev(""" (let fib ((n 10) (a 0) (b 1)) (if (= n 0) a (fib (- n 1) b (+ a b)))) """) assert result == 55 def test_string_building(self): """Named let for building strings.""" result = ev(""" (let build ((items (list "a" "b" "c")) (acc "")) (if (empty? items) acc (build (rest items) (str acc (if (= acc "") "" ",") (first items))))) """) assert result == "a,b,c" # --------------------------------------------------------------------------- # letrec # --------------------------------------------------------------------------- class TestLetrec: def test_basic(self): """Simple letrec with a self-referencing lambda.""" result = ev(""" (letrec ((double (fn (x) (* x 2)))) (double 21)) """) assert result == 42 def test_mutual_recursion(self): """Classic even?/odd? mutual recursion.""" result = ev(""" (letrec ((my-even? (fn (n) (if (= n 0) true (my-odd? (- n 1))))) (my-odd? (fn (n) (if (= n 0) false (my-even? (- n 1)))))) (list (my-even? 10) (my-odd? 10) (my-even? 7) (my-odd? 7))) """) assert result == [True, False, False, True] def test_clojure_style(self): """letrec with flat bindings.""" result = ev(""" (letrec (f (fn (x) (if (= x 0) 1 (* x (f (- x 1)))))) (f 5)) """) assert result == 120 def test_closures_see_each_other(self): """Lambdas in letrec see each other's final values.""" result = ev(""" (letrec ((a (fn () (b))) (b (fn () 42))) (a)) """) assert result == 42 def test_non_forward_ref(self): """letrec with non-lambda values that don't reference each other.""" result = ev(""" (letrec ((x 10) (y 20)) (+ x y)) """) assert result == 30 def test_three_way_mutual(self): """Three mutually recursive functions.""" result = ev(""" (letrec ((f (fn (n) (if (= n 0) 1 (g (- n 1))))) (g (fn (n) (if (= n 0) 2 (h (- n 1))))) (h (fn (n) (if (= n 0) 3 (f (- n 1)))))) (list (f 0) (f 1) (f 2) (f 3) (g 0) (g 1) (g 2))) """) # f(0)=1, f(1)=g(0)=2, f(2)=g(1)=h(0)=3, f(3)=g(2)=h(1)=f(0)=1 # g(0)=2, g(1)=h(0)=3, g(2)=h(1)=f(0)=1 assert result == [1, 2, 3, 1, 2, 3, 1] # --------------------------------------------------------------------------- # dynamic-wind # --------------------------------------------------------------------------- class TestDynamicWind: def _make_log_env(self): """Create env with a log! function that appends to a Python list.""" log = [] env = {"log!": lambda msg: log.append(msg) or NIL} return env, log def test_basic_flow(self): """Entry and exit thunks called around body.""" env, log = self._make_log_env() result = ev(""" (dynamic-wind (fn () (log! "enter")) (fn () (do (log! "body") 42)) (fn () (log! "exit"))) """, env) assert result == 42 assert log == ["enter", "body", "exit"] def test_after_called_on_error(self): """Exit thunk is called even when body raises an error.""" env, log = self._make_log_env() with pytest.raises(Exception): ev(""" (dynamic-wind (fn () (log! "enter")) (fn () (do (log! "body") (error "boom"))) (fn () (log! "exit"))) """, env) assert log == ["enter", "body", "exit"] def test_nested(self): """Nested dynamic-wind calls entry/exit in correct order.""" env, log = self._make_log_env() ev(""" (dynamic-wind (fn () (log! "outer-in")) (fn () (dynamic-wind (fn () (log! "inner-in")) (fn () (log! "body")) (fn () (log! "inner-out")))) (fn () (log! "outer-out"))) """, env) assert log == [ "outer-in", "inner-in", "body", "inner-out", "outer-out" ] def test_return_value(self): """Body return value is propagated.""" result = ev(""" (dynamic-wind (fn () nil) (fn () (+ 20 22)) (fn () nil)) """) assert result == 42 def test_before_after_are_thunks(self): """Before and after must be zero-arg functions.""" env, log = self._make_log_env() # Verify it works with native Python callables too env["enter"] = lambda: log.append("in") or NIL env["leave"] = lambda: log.append("out") or NIL result = ev(""" (dynamic-wind enter (fn () 99) leave) """, env) assert result == 99 assert log == ["in", "out"] # --------------------------------------------------------------------------- # Three-tier equality # --------------------------------------------------------------------------- class TestEquality: def test_eq_identity_same_object(self): """eq? is true for the same object.""" env = {} ev("(define x (list 1 2 3))", env) assert ev("(eq? x x)", env) is True def test_eq_identity_different_objects(self): """eq? is false for different objects with same value.""" assert ev("(eq? (list 1 2) (list 1 2))") is False def test_eq_numbers(self): """eq? on small ints — Python interns them, so identity holds.""" assert ev("(eq? 42 42)") is True def test_eqv_numbers(self): """eqv? compares numbers by value.""" assert ev("(eqv? 42 42)") is True assert ev("(eqv? 42 43)") is False def test_eqv_strings(self): """eqv? compares strings by value.""" assert ev('(eqv? "hello" "hello")') is True assert ev('(eqv? "hello" "world")') is False def test_eqv_nil(self): """eqv? on nil values.""" assert ev("(eqv? nil nil)") is True def test_eqv_booleans(self): assert ev("(eqv? true true)") is True assert ev("(eqv? true false)") is False def test_eqv_different_lists(self): """eqv? is false for different list objects.""" assert ev("(eqv? (list 1 2) (list 1 2))") is False def test_equal_deep(self): """equal? does deep structural comparison.""" assert ev("(equal? (list 1 2 3) (list 1 2 3))") is True assert ev("(equal? (list 1 2) (list 1 2 3))") is False def test_equal_nested(self): """equal? recursively compares nested structures.""" assert ev("(equal? {:a (list 1 2)} {:a (list 1 2)})") is True def test_equal_is_same_as_equals(self): """equal? and = have the same semantics.""" assert ev("(equal? 42 42)") is True assert ev("(= 42 42)") is True assert ev("(equal? (list 1) (list 1))") is True assert ev("(= (list 1) (list 1))") is True def test_eq_eqv_equal_hierarchy(self): """eq? ⊂ eqv? ⊂ equal? — each is progressively looser.""" env = {} ev("(define x (list 1 2 3))", env) # Same object: all three true assert ev("(eq? x x)", env) is True assert ev("(eqv? x x)", env) is True assert ev("(equal? x x)", env) is True # Different objects, same value: eq? false, eqv? false, equal? true assert ev("(eq? (list 1 2) (list 1 2))", env) is False assert ev("(eqv? (list 1 2) (list 1 2))", env) is False assert ev("(equal? (list 1 2) (list 1 2))", env) is True