Files
rose-ash/shared/sx/tests/test_continuations.py
giles 102a27e845 Implement delimited continuations (shift/reset) across all evaluators
Bootstrap shift/reset to both the Python and JS targets. The implementation
uses exception-based capture with re-evaluation: reset wraps its body in a
try/catch for ShiftSignal, shift raises a ShiftSignal to the nearest enclosing
reset, and invoking a continuation pushes a resume value and re-evaluates the
reset body.

- Add Continuation type and _ShiftSignal to shared/sx/types.py
- Add sf_reset/sf_shift to hand-written evaluator.py
- Add async versions to async_eval.py
- Add shift/reset dispatch to eval.sx spec
- Bootstrap to Python: FIXUPS_PY with sf_reset/sf_shift, regenerate sx_ref.py
- Bootstrap to JS: Continuation/ShiftSignal types, sfReset/sfShift in fixups
- Add continuation? primitive to both bootstrappers and primitives.sx
- Allow any callable (including Continuation) as the function argument of the hand-written evaluator's higher-order `map`
- 44 unit tests (22 per evaluator) covering: passthrough, abort, invoke,
  double invoke, predicate, stored continuation, nested reset, practical patterns
- Update continuations essay to reflect implemented status with examples

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-06 00:58:50 +00:00

202 lines
7.0 KiB
Python

"""Tests for delimited continuations (shift/reset).
Tests run against both the hand-written evaluator and the transpiled
sx_ref evaluator to verify both implementations match.
"""
import pytest
from shared.sx import parse, evaluate, EvalError, NIL
from shared.sx.types import Continuation
from shared.sx.ref import sx_ref
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def ev(text, env=None):
    """Run SX source through the hand-written evaluator.

    *text* is parsed into an expression tree first; *env* (``None`` by
    default) is forwarded unchanged to ``evaluate``.
    """
    tree = parse(text)
    return evaluate(tree, env)
def ev_ref(text, env=None):
    """Run SX source through the transpiled ``sx_ref`` evaluator.

    Mirrors ``ev`` exactly except for the evaluator used, so the two can be
    swapped via ``EVALUATORS``.
    """
    tree = parse(text)
    return sx_ref.evaluate(tree, env)
# Each parametrized test runs once per evaluator; the id becomes the
# bracketed suffix in the pytest node name (e.g. "[hand-written]").
EVALUATORS = [
    pytest.param(runner, id=label)
    for runner, label in (
        (ev, "hand-written"),
        (ev_ref, "sx_ref"),
    )
]
# ---------------------------------------------------------------------------
# Basic shift/reset
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestBasicReset:
    """A reset containing no shift simply yields the value of its body."""

    def test_reset_passthrough(self, evaluate):
        # A literal body comes back untouched.
        assert evaluate("(reset 42)") == 42

    def test_reset_expression(self, evaluate):
        # The body is evaluated normally: (+ 1 2) = 3.
        assert evaluate("(reset (+ 1 2))") == 3

    def test_reset_with_let(self, evaluate):
        # Local bindings inside the body behave as usual: x = 10 → 15.
        assert evaluate("(reset (let (x 10) (+ x 5)))") == 15
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestShiftAbort:
    """If the shift body never calls k, its value replaces the whole reset."""

    def test_abort_returns_shift_body(self, evaluate):
        # The pending (+ 1 ...) is discarded; reset yields the shift body, 42.
        assert evaluate("(reset (+ 1 (shift k 42)))") == 42

    def test_abort_string(self, evaluate):
        # Aborting works for non-numeric results too.
        assert evaluate('(reset (+ 1 (shift k "aborted")))') == "aborted"

    def test_abort_with_computation(self, evaluate):
        # The shift body itself is evaluated: (* 6 7) = 42, and (+ 1 ...) is skipped.
        assert evaluate("(reset (+ 1 (shift k (* 6 7))))") == 42
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestContinuationInvoke:
    """Invoking the captured continuation re-enters the reset body."""

    def test_invoke_once(self, evaluate):
        # k resumes the reset body with 10 in place of the shift:
        #   (+ 1 10) = 11
        assert evaluate("(reset (+ 1 (shift k (k 10))))") == 11

    def test_invoke_with_zero(self, evaluate):
        # Resuming with 0 (a falsy value) yields (+ 1 0) = 1.
        assert evaluate("(reset (+ 1 (shift k (k 0))))") == 1

    def test_invoke_twice(self, evaluate):
        # k is invoked twice in the shift body; each call re-runs the reset
        # body (+ 1 <hole>) independently:
        #   (k 1)  → (+ 1 1)  = 2
        #   (k 10) → (+ 1 10) = 11
        # The shift body then computes (+ 2 11) = 13, which reset returns.
        assert evaluate("(reset (+ 1 (shift k (+ (k 1) (k 10)))))") == 13

    def test_invoke_transforms_value(self, evaluate):
        # k's result is fed back into k:
        #   (k 3) → (* 2 3) = 6, then (k 6) → (* 2 6) = 12
        assert evaluate("(reset (* 2 (shift k (k (k 3)))))") == 12
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestContinuationPredicate:
    """`continuation?` distinguishes captured continuations from other values
    (numbers, nil, lambdas)."""

    def test_continuation_is_true(self, evaluate):
        # The k bound by shift must satisfy the predicate.
        assert evaluate("(reset (shift k (continuation? k)))") is True

    def test_non_continuation(self, evaluate):
        outcome = evaluate("(continuation? 42)")
        assert outcome is False

    def test_nil_not_continuation(self, evaluate):
        outcome = evaluate("(continuation? nil)")
        assert outcome is False

    def test_lambda_not_continuation(self, evaluate):
        # Ordinary lambdas are callable but are not continuations.
        outcome = evaluate("(continuation? (fn (x) x))")
        assert outcome is False
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestStoredContinuation:
    """Captured continuations are first-class values: they can be assigned
    to variables and can escape their reset boundary."""

    def test_stored_in_variable(self, evaluate):
        # The shift body stores k with set! and then aborts, so reset yields
        # the shift body's value ("captured").
        # NOTE(review): `saved` is never read back or invoked here; a
        # follow-up test could confirm the stored continuation is usable
        # after the reset has returned.
        code = (
            "\n"
            "(let (saved nil)\n"
            '(reset (+ 1 (shift k (do (set! saved k) "captured"))))\n'
            ")\n"
        )
        assert evaluate(code) == "captured"

    def test_continuation_type(self, evaluate):
        """A continuation returned out of reset is a ``Continuation`` object."""
        # The shift body returns k itself (abort path), so the continuation
        # escapes the reset as an ordinary value. Previously this test was an
        # exact duplicate of TestContinuationPredicate.test_continuation_is_true,
        # and the imported Continuation type was never checked.
        result = evaluate("(reset (shift k k))")
        assert isinstance(result, Continuation)
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestNestedReset:
    """A shift is delimited by its innermost reset; outer resets are unaffected."""

    def test_inner_reset(self, evaluate):
        # Inner reset: k(10) → (+ 2 10) = 12; outer reset: (+ 1 12) = 13.
        source = "(reset (+ 1 (reset (+ 2 (shift k (k 10))))))"
        assert evaluate(source) == 13

    def test_inner_abort_outer_continues(self, evaluate):
        # Inner reset aborts with 99; the outer (+ 1 ...) still runs → 100.
        source = "(reset (+ 1 (reset (shift k 99))))"
        assert evaluate(source) == 100
@pytest.mark.parametrize("evaluate", EVALUATORS)
class TestPracticalPatterns:
    """Everyday idioms built from shift/reset."""

    def test_early_return(self, evaluate):
        """Shift without invoking k acts as early return."""
        # x = 5 exceeds 3, so the shift fires and its body becomes the result.
        code = (
            "\n"
            "(reset\n"
            "(let (x 5)\n"
            "(if (> x 3)\n"
            '(shift k "too big")\n'
            "(* x x))))\n"
        )
        assert evaluate(code) == "too big"

    def test_normal_path(self, evaluate):
        """When condition doesn't trigger shift, normal result."""
        # x = 2 does not exceed 3, so the else-branch (* 2 2) runs.
        code = (
            "\n"
            "(reset\n"
            "(let (x 2)\n"
            "(if (> x 3)\n"
            '(shift k "too big")\n'
            "(* x x))))\n"
        )
        assert evaluate(code) == 4

    def test_continuation_as_function(self, evaluate):
        """Map over a continuation to apply it to multiple values."""
        # Each element resumes (+ 10 <hole>): 1 → 11, 2 → 12, 3 → 13.
        code = (
            "\n"
            "(reset\n"
            "(+ 10 (shift k\n"
            "(map k (list 1 2 3)))))\n"
        )
        assert evaluate(code) == [11, 12, 13]

    def test_default_value(self, evaluate):
        """Calling k with no args passes NIL."""
        # (k) resumes the reset body, whose only form is the shift itself, so
        # (k) evaluates to NIL and the shift body's (nil? NIL) is true.
        assert evaluate('(reset (shift k (nil? (k))))') is True