Add TCO trampolining to async evaluator and sx.js client

Both evaluators now use thunk-based trampolining to eliminate stack
overflow on deep tail recursion (verified at 50K+ depth). Mirrors
the sync evaluator TCO added in 5069072.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-04 10:53:16 +00:00
parent da8d2e342f
commit e72f7485f4
2 changed files with 126 additions and 84 deletions

View File

@@ -33,12 +33,40 @@ from .html import (
)
# ---------------------------------------------------------------------------
# Async TCO — thunk + trampoline
# ---------------------------------------------------------------------------
class _AsyncThunk:
"""Deferred (expr, env, ctx) for tail-call optimization."""
__slots__ = ("expr", "env", "ctx")
def __init__(self, expr: Any, env: dict[str, Any], ctx: RequestContext) -> None:
self.expr = expr
self.env = env
self.ctx = ctx
async def _async_trampoline(val: Any) -> Any:
"""Iteratively resolve thunks from tail positions."""
while isinstance(val, _AsyncThunk):
val = await _async_eval(val.expr, val.env, val.ctx)
return val
# ---------------------------------------------------------------------------
# Async evaluate
# ---------------------------------------------------------------------------
async def async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any:
"""Evaluate *expr* in *env*, awaiting I/O primitives inline."""
"""Public entry — evaluates and trampolines thunks."""
result = await _async_eval(expr, env, ctx)
while isinstance(result, _AsyncThunk):
result = await _async_eval(result.expr, result.env, result.ctx)
return result
async def _async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any:
"""Internal evaluator — may return _AsyncThunk for tail positions."""
# --- literals ---
if isinstance(expr, (int, float, str, bool)):
return expr
@@ -66,7 +94,7 @@ async def async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any
# --- dict literal ---
if isinstance(expr, dict):
return {k: await async_eval(v, env, ctx) for k, v in expr.items()}
return {k: await _async_trampoline(await _async_eval(v, env, ctx)) for k, v in expr.items()}
# --- list ---
if not isinstance(expr, list):
@@ -77,7 +105,7 @@ async def async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any
head = expr[0]
if not isinstance(head, (Symbol, Lambda, list)):
return [await async_eval(x, env, ctx) for x in expr]
return [await _async_trampoline(await _async_eval(x, env, ctx)) for x in expr]
if isinstance(head, Symbol):
name = head.name
@@ -96,12 +124,12 @@ async def async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any
if ho is not None:
return await ho(expr, env, ctx)
# Macro expansion
# Macro expansion — tail position
if name in env:
val = env[name]
if isinstance(val, Macro):
expanded = _expand_macro(val, expr[1:], env)
return await async_eval(expanded, env, ctx)
return _AsyncThunk(expanded, env, ctx)
# Render forms in eval position — delegate to renderer and return
# as _RawHTML so it won't be double-escaped when used in render
@@ -111,8 +139,8 @@ async def async_eval(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any
return _RawHTML(html)
# --- function / lambda call ---
fn = await async_eval(head, env, ctx)
args = [await async_eval(a, env, ctx) for a in expr[1:]]
fn = await _async_trampoline(await _async_eval(head, env, ctx))
args = [await _async_trampoline(await _async_eval(a, env, ctx)) for a in expr[1:]]
if callable(fn) and not isinstance(fn, (Lambda, Component)):
result = fn(*args)
@@ -153,7 +181,7 @@ async def _async_call_lambda(
local.update(caller_env)
for p, v in zip(fn.params, args):
local[p] = v
return await async_eval(fn.body, local, ctx)
return _AsyncThunk(fn.body, local, ctx)
async def _async_call_component(
@@ -176,7 +204,7 @@ async def _async_call_component(
local[p] = kwargs.get(p, NIL)
if comp.has_children:
local["children"] = children
return await async_eval(comp.body, local, ctx)
return _AsyncThunk(comp.body, local, ctx)
# ---------------------------------------------------------------------------
@@ -184,28 +212,28 @@ async def _async_call_component(
# ---------------------------------------------------------------------------
async def _asf_if(expr, env, ctx):
cond = await async_eval(expr[1], env, ctx)
cond = await _async_trampoline(await _async_eval(expr[1], env, ctx))
if cond and cond is not NIL:
return await async_eval(expr[2], env, ctx)
return _AsyncThunk(expr[2], env, ctx)
if len(expr) > 3:
return await async_eval(expr[3], env, ctx)
return _AsyncThunk(expr[3], env, ctx)
return NIL
async def _asf_when(expr, env, ctx):
cond = await async_eval(expr[1], env, ctx)
cond = await _async_trampoline(await _async_eval(expr[1], env, ctx))
if cond and cond is not NIL:
result = NIL
for body_expr in expr[2:]:
result = await async_eval(body_expr, env, ctx)
return result
for body_expr in expr[2:-1]:
await _async_trampoline(await _async_eval(body_expr, env, ctx))
if len(expr) > 2:
return _AsyncThunk(expr[-1], env, ctx)
return NIL
async def _asf_and(expr, env, ctx):
result: Any = True
for arg in expr[1:]:
result = await async_eval(arg, env, ctx)
result = await _async_trampoline(await _async_eval(arg, env, ctx))
if not result:
return result
return result
@@ -214,7 +242,7 @@ async def _asf_and(expr, env, ctx):
async def _asf_or(expr, env, ctx):
result: Any = False
for arg in expr[1:]:
result = await async_eval(arg, env, ctx)
result = await _async_trampoline(await _async_eval(arg, env, ctx))
if result:
return result
return result
@@ -228,16 +256,17 @@ async def _asf_let(expr, env, ctx):
for binding in bindings:
var = binding[0]
vname = var.name if isinstance(var, Symbol) else var
local[vname] = await async_eval(binding[1], local, ctx)
local[vname] = await _async_trampoline(await _async_eval(binding[1], local, ctx))
elif len(bindings) % 2 == 0:
for i in range(0, len(bindings), 2):
var = bindings[i]
vname = var.name if isinstance(var, Symbol) else var
local[vname] = await async_eval(bindings[i + 1], local, ctx)
result: Any = NIL
for body_expr in expr[2:]:
result = await async_eval(body_expr, local, ctx)
return result
local[vname] = await _async_trampoline(await _async_eval(bindings[i + 1], local, ctx))
for body_expr in expr[2:-1]:
await _async_trampoline(await _async_eval(body_expr, local, ctx))
if len(expr) > 2:
return _AsyncThunk(expr[-1], local, ctx)
return NIL
async def _asf_lambda(expr, env, ctx):
@@ -253,7 +282,7 @@ async def _asf_lambda(expr, env, ctx):
async def _asf_define(expr, env, ctx):
name_sym = expr[1]
value = await async_eval(expr[2], env, ctx)
value = await _async_trampoline(await _async_eval(expr[2], env, ctx))
if isinstance(value, Lambda) and value.name is None:
value.name = name_sym.name
env[name_sym.name] = value
@@ -276,10 +305,11 @@ async def _asf_defhandler(expr, env, ctx):
async def _asf_begin(expr, env, ctx):
result: Any = NIL
for sub in expr[1:]:
result = await async_eval(sub, env, ctx)
return result
for sub in expr[1:-1]:
await _async_trampoline(await _async_eval(sub, env, ctx))
if len(expr) > 1:
return _AsyncThunk(expr[-1], env, ctx)
return NIL
async def _asf_quote(expr, env, ctx):
@@ -325,65 +355,65 @@ async def _asf_cond(expr, env, ctx):
for clause in clauses:
test = clause[0]
if isinstance(test, Symbol) and test.name in ("else", ":else"):
return await async_eval(clause[1], env, ctx)
return _AsyncThunk(clause[1], env, ctx)
if isinstance(test, Keyword) and test.name == "else":
return await async_eval(clause[1], env, ctx)
if await async_eval(test, env, ctx):
return await async_eval(clause[1], env, ctx)
return _AsyncThunk(clause[1], env, ctx)
if await _async_trampoline(await _async_eval(test, env, ctx)):
return _AsyncThunk(clause[1], env, ctx)
else:
i = 0
while i < len(clauses) - 1:
test = clauses[i]
result = clauses[i + 1]
if isinstance(test, Keyword) and test.name == "else":
return await async_eval(result, env, ctx)
return _AsyncThunk(result, env, ctx)
if isinstance(test, Symbol) and test.name in (":else", "else"):
return await async_eval(result, env, ctx)
if await async_eval(test, env, ctx):
return await async_eval(result, env, ctx)
return _AsyncThunk(result, env, ctx)
if await _async_trampoline(await _async_eval(test, env, ctx)):
return _AsyncThunk(result, env, ctx)
i += 2
return NIL
async def _asf_case(expr, env, ctx):
match_val = await async_eval(expr[1], env, ctx)
match_val = await _async_trampoline(await _async_eval(expr[1], env, ctx))
clauses = expr[2:]
i = 0
while i < len(clauses) - 1:
test = clauses[i]
result = clauses[i + 1]
if isinstance(test, Keyword) and test.name == "else":
return await async_eval(result, env, ctx)
return _AsyncThunk(result, env, ctx)
if isinstance(test, Symbol) and test.name in (":else", "else"):
return await async_eval(result, env, ctx)
if match_val == await async_eval(test, env, ctx):
return await async_eval(result, env, ctx)
return _AsyncThunk(result, env, ctx)
if match_val == await _async_trampoline(await _async_eval(test, env, ctx)):
return _AsyncThunk(result, env, ctx)
i += 2
return NIL
async def _asf_thread_first(expr, env, ctx):
result = await async_eval(expr[1], env, ctx)
result = await _async_trampoline(await _async_eval(expr[1], env, ctx))
for form in expr[2:]:
if isinstance(form, list):
fn = await async_eval(form[0], env, ctx)
args = [result] + [await async_eval(a, env, ctx) for a in form[1:]]
fn = await _async_trampoline(await _async_eval(form[0], env, ctx))
args = [result] + [await _async_trampoline(await _async_eval(a, env, ctx)) for a in form[1:]]
else:
fn = await async_eval(form, env, ctx)
fn = await _async_trampoline(await _async_eval(form, env, ctx))
args = [result]
if callable(fn) and not isinstance(fn, (Lambda, Component)):
result = fn(*args)
if inspect.iscoroutine(result):
result = await result
elif isinstance(fn, Lambda):
result = await _async_call_lambda(fn, args, env, ctx)
result = await _async_trampoline(await _async_call_lambda(fn, args, env, ctx))
else:
raise EvalError(f"-> form not callable: {fn!r}")
return result
async def _asf_set_bang(expr, env, ctx):
value = await async_eval(expr[2], env, ctx)
value = await _async_trampoline(await _async_eval(expr[2], env, ctx))
env[expr[1].name] = value
return value
@@ -422,7 +452,7 @@ async def _aho_map(expr, env, ctx):
results = []
for item in coll:
if isinstance(fn, Lambda):
results.append(await _async_call_lambda(fn, [item], env, ctx))
results.append(await _async_trampoline(await _async_call_lambda(fn, [item], env, ctx)))
elif callable(fn):
r = fn(item)
results.append(await r if inspect.iscoroutine(r) else r)
@@ -437,7 +467,7 @@ async def _aho_map_indexed(expr, env, ctx):
results = []
for i, item in enumerate(coll):
if isinstance(fn, Lambda):
results.append(await _async_call_lambda(fn, [i, item], env, ctx))
results.append(await _async_trampoline(await _async_call_lambda(fn, [i, item], env, ctx)))
elif callable(fn):
r = fn(i, item)
results.append(await r if inspect.iscoroutine(r) else r)
@@ -452,7 +482,7 @@ async def _aho_filter(expr, env, ctx):
results = []
for item in coll:
if isinstance(fn, Lambda):
val = await _async_call_lambda(fn, [item], env, ctx)
val = await _async_trampoline(await _async_call_lambda(fn, [item], env, ctx))
elif callable(fn):
val = fn(item)
if inspect.iscoroutine(val):
@@ -470,7 +500,7 @@ async def _aho_reduce(expr, env, ctx):
coll = await async_eval(expr[3], env, ctx)
for item in coll:
if isinstance(fn, Lambda):
acc = await _async_call_lambda(fn, [acc, item], env, ctx)
acc = await _async_trampoline(await _async_call_lambda(fn, [acc, item], env, ctx))
else:
acc = fn(acc, item)
if inspect.iscoroutine(acc):
@@ -483,7 +513,7 @@ async def _aho_some(expr, env, ctx):
coll = await async_eval(expr[2], env, ctx)
for item in coll:
if isinstance(fn, Lambda):
result = await _async_call_lambda(fn, [item], env, ctx)
result = await _async_trampoline(await _async_call_lambda(fn, [item], env, ctx))
else:
result = fn(item)
if inspect.iscoroutine(result):
@@ -498,7 +528,7 @@ async def _aho_every(expr, env, ctx):
coll = await async_eval(expr[2], env, ctx)
for item in coll:
if isinstance(fn, Lambda):
val = await _async_call_lambda(fn, [item], env, ctx)
val = await _async_trampoline(await _async_call_lambda(fn, [item], env, ctx))
else:
val = fn(item)
if inspect.iscoroutine(val):
@@ -513,7 +543,7 @@ async def _aho_for_each(expr, env, ctx):
coll = await async_eval(expr[2], env, ctx)
for item in coll:
if isinstance(fn, Lambda):
await _async_call_lambda(fn, [item], env, ctx)
await _async_trampoline(await _async_call_lambda(fn, [item], env, ctx))
elif callable(fn):
r = fn(item)
if inspect.iscoroutine(r):
@@ -1038,7 +1068,7 @@ async def _aser(expr: Any, env: dict[str, Any], ctx: RequestContext) -> Any:
return await result
return result
if isinstance(fn, Lambda):
return await _async_call_lambda(fn, args, env, ctx)
return await _async_trampoline(await _async_call_lambda(fn, args, env, ctx))
if isinstance(fn, Component):
# Component invoked as function — serialize the call
return await _aser_call(f"~{fn.name}", expr[1:], env, ctx)