"""Tests for delimited continuations (shift/reset). Tests run against both the hand-written evaluator and the transpiled sx_ref evaluator to verify both implementations match. """ import pytest from shared.sx import parse, evaluate, EvalError, NIL from shared.sx.types import Continuation from shared.sx.ref import sx_ref # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def ev(text, env=None): """Parse and evaluate via hand-written evaluator.""" return evaluate(parse(text), env) def ev_ref(text, env=None): """Parse and evaluate via transpiled sx_ref.""" return sx_ref.evaluate(parse(text), env) EVALUATORS = [ pytest.param(ev, id="hand-written"), pytest.param(ev_ref, id="sx_ref"), ] # --------------------------------------------------------------------------- # Basic shift/reset # --------------------------------------------------------------------------- class TestBasicReset: """Reset without shift is a no-op wrapper.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_reset_passthrough(self, evaluate): assert evaluate("(reset 42)") == 42 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_reset_expression(self, evaluate): assert evaluate("(reset (+ 1 2))") == 3 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_reset_with_let(self, evaluate): assert evaluate("(reset (let (x 10) (+ x 5)))") == 15 class TestShiftAbort: """Shift without invoking k aborts to the reset boundary.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_abort_returns_shift_body(self, evaluate): # (reset (+ 1 (shift k 42))) → shift body 42 is returned, + 1 is abandoned assert evaluate("(reset (+ 1 (shift k 42)))") == 42 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_abort_string(self, evaluate): assert evaluate('(reset (+ 1 (shift k "aborted")))') == "aborted" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_abort_with_computation(self, evaluate): assert evaluate("(reset (+ 1 (shift k (* 6 7))))") == 42 class TestContinuationInvoke: """Invoking the captured continuation re-enters the reset body.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_invoke_once(self, evaluate): # (reset (+ 1 (shift k (k 10)))) → k resumes with 10, so + 1 10 = 11 assert evaluate("(reset (+ 1 (shift k (k 10))))") == 11 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_invoke_with_zero(self, evaluate): assert evaluate("(reset (+ 1 (shift k (k 0))))") == 1 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_invoke_twice(self, evaluate): # k invoked twice: (+ (k 1) (k 10)) → (+ 1 1) + ... → (+ 2 11) = 13 # Actually: (k 1) re-evaluates (+ 1 ) where shift returns 1 → 2 # Then (k 10) re-evaluates (+ 1 ) where shift returns 10 → 11 # Then (+ 2 11) = 13 assert evaluate("(reset (+ 1 (shift k (+ (k 1) (k 10)))))") == 13 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_invoke_transforms_value(self, evaluate): # k wraps: (reset (* 2 (shift k (k (k 3))))) # k(3) → (* 2 3) = 6, k(6) → (* 2 6) = 12 assert evaluate("(reset (* 2 (shift k (k (k 3)))))") == 12 class TestContinuationPredicate: """The continuation? predicate identifies captured continuations.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_continuation_is_true(self, evaluate): result = evaluate("(reset (shift k (continuation? k)))") assert result is True @pytest.mark.parametrize("evaluate", EVALUATORS) def test_non_continuation(self, evaluate): assert evaluate("(continuation? 42)") is False @pytest.mark.parametrize("evaluate", EVALUATORS) def test_nil_not_continuation(self, evaluate): assert evaluate("(continuation? nil)") is False @pytest.mark.parametrize("evaluate", EVALUATORS) def test_lambda_not_continuation(self, evaluate): assert evaluate("(continuation? (fn (x) x))") is False class TestStoredContinuation: """Continuations can be stored and invoked later.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_stored_in_variable(self, evaluate): code = """ (let (saved nil) (reset (+ 1 (shift k (do (set! saved k) "captured")))) ) """ # The reset returns "captured" (abort path) assert evaluate(code) == "captured" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_continuation_type(self, evaluate): """Verify that a captured continuation is identified by continuation?.""" code = '(reset (shift k (continuation? k)))' result = evaluate(code) assert result is True class TestNestedReset: """Nested reset blocks delimit independently.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_inner_reset(self, evaluate): code = "(reset (+ 1 (reset (+ 2 (shift k (k 10))))))" # Inner reset: (+ 2 (shift k (k 10))) → k(10) → (+ 2 10) = 12 # Outer reset: (+ 1 12) = 13 assert evaluate(code) == 13 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_inner_abort_outer_continues(self, evaluate): code = "(reset (+ 1 (reset (shift k 99))))" # Inner reset aborts with 99 # Outer reset: (+ 1 99) = 100 assert evaluate(code) == 100 class TestPracticalPatterns: """Practical uses of delimited continuations.""" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_early_return(self, evaluate): """Shift without invoking k acts as early return.""" code = """ (reset (let (x 5) (if (> x 3) (shift k "too big") (* x x)))) """ assert evaluate(code) == "too big" @pytest.mark.parametrize("evaluate", EVALUATORS) def test_normal_path(self, evaluate): """When condition doesn't trigger shift, normal result.""" code = """ (reset (let (x 2) (if (> x 3) (shift k "too big") (* x x)))) """ assert evaluate(code) == 4 @pytest.mark.parametrize("evaluate", EVALUATORS) def test_continuation_as_function(self, evaluate): """Map over a continuation to apply it to multiple values.""" code = """ (reset (+ 10 (shift k (map k (list 1 2 3))))) """ result = evaluate(code) assert result == [11, 12, 13] @pytest.mark.parametrize("evaluate", EVALUATORS) def test_default_value(self, evaluate): """Calling k with no args passes NIL.""" code = '(reset (shift k (nil? (k))))' # k() passes NIL, reset body re-evals: (shift k ...) returns NIL # Then the outer shift body checks: (nil? NIL) = true assert evaluate(code) is True