VM upvalues + HO primitives + 40 tests (36 pass, 4 fail)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-19 20:41:23 +00:00
parent 1bb40415a8
commit e7da397f8e
5 changed files with 432 additions and 8 deletions

View File

@@ -214,6 +214,16 @@ let setup_io_env env =
Sx_ref.eval_expr (List (fn_val :: call_args)) (Env env)
| _ -> raise (Eval_error "call-lambda: expected (fn args env?)"));
(* Register HO forms as callable NativeFn — the CEK machine handles them
as special forms, but the VM needs them as callable values in globals. *)
let ho_via_cek name =
bind name (fun args ->
Sx_ref.eval_expr (List (Symbol name :: args)) (Env env))
in
List.iter ho_via_cek [
"map"; "map-indexed"; "filter"; "reduce"; "some"; "every?"; "for-each";
];
(* Generic helper call — dispatches to Python page helpers *)
bind "helper" (fun args ->
io_request "helper" args)
@@ -626,6 +636,14 @@ let make_server_env () =
(* IO primitives *)
setup_io_env env;
(* Initialize trampoline ref so HO primitives (map, filter, etc.)
can call SX lambdas. Must be done here (not sx_runtime.ml)
because Sx_ref is only available at the binary level. *)
Sx_primitives._sx_trampoline_fn := (fun v ->
match v with
| Thunk (body, closure_env) -> Sx_ref.eval_expr body (Env closure_env)
| other -> other);
env
@@ -733,7 +751,7 @@ let dispatch env cmd =
(try
ignore (env_bind env "expand-components?" (NativeFn ("expand-components?", fun _args -> Bool true)));
(* Enable batch IO mode *)
io_batch_mode := true; Sx_ref._cek_steps := 0;
io_batch_mode := true;
io_queue := [];
io_counter := 0;
let exprs = Sx_parser.parse_all src in
@@ -758,8 +776,8 @@ let dispatch env cmd =
(* Flush batched IO: send requests, receive responses, replace placeholders *)
let final = flush_batched_io result_str in
let t2 = Unix.gettimeofday () in
Printf.eprintf "[aser-slot] eval=%.1fs io_flush=%.1fs batched=%d result=%d chars cek_steps=%d\n%!"
(t1 -. t0) (t2 -. t1) n_batched (String.length final) !Sx_ref._cek_steps;
Printf.eprintf "[aser-slot] eval=%.1fs io_flush=%.1fs batched=%d result=%d chars\n%!"
(t1 -. t0) (t2 -. t1) n_batched (String.length final);
send (Printf.sprintf "(ok-raw %s)" final)
with
| Eval_error msg ->

View File

@@ -7,6 +7,12 @@ open Sx_types
let primitives : (string, value list -> value) Hashtbl.t = Hashtbl.create 128
(** Forward refs for calling SX functions from primitives (breaks cycle). *)
let _sx_call_fn : (value -> value list -> value) ref =
ref (fun _ _ -> raise (Eval_error "sx_call not initialized"))
let _sx_trampoline_fn : (value -> value) ref =
ref (fun v -> v)
let register name fn = Hashtbl.replace primitives name fn
let is_primitive name = Hashtbl.mem primitives name
@@ -590,4 +596,34 @@ let () =
List.iter (fun (k, v) -> dict_set d k v) pairs;
Dict d
| _ -> raise (Eval_error "spread-attrs: 1 spread"));
(* Higher-order forms as callable primitives — used by the VM.
The CEK machine handles these as special forms with dedicated frames;
the VM needs them as plain callable values. *)
(* Call any SX callable — handles NativeFn, Lambda (via trampoline), VM closures *)
let call_any f args =
match f with
| NativeFn (_, fn) -> fn args
| _ -> !_sx_trampoline_fn (!_sx_call_fn f args)
in
register "map" (fun args ->
match args with
| [f; (List items | ListRef { contents = items })] ->
List (List.map (fun x -> call_any f [x]) items)
| _ -> raise (Eval_error "map: expected (fn list)"));
register "filter" (fun args ->
match args with
| [f; (List items | ListRef { contents = items })] ->
List (List.filter (fun x -> sx_truthy (call_any f [x])) items)
| _ -> raise (Eval_error "filter: expected (fn list)"));
register "for-each" (fun args ->
match args with
| [f; (List items | ListRef { contents = items })] ->
List.iter (fun x -> ignore (call_any f [x])) items; Nil
| _ -> raise (Eval_error "for-each: expected (fn list)"));
register "reduce" (fun args ->
match args with
| [f; init; (List items | ListRef { contents = items })] ->
List.fold_left (fun acc x -> call_any f [acc; x]) init items
| _ -> raise (Eval_error "reduce: expected (fn init list)"));
()

View File

@@ -23,14 +23,12 @@ let _prim_param_types_ref = ref Nil
(* === Transpiled from evaluator (frames + eval + CEK) === *)
(* make-cek-state *)
let _cek_steps = ref 0
let rec make_cek_state control env kont =
(incr _cek_steps; CekState { cs_control = control; cs_env = env; cs_kont = kont; cs_phase = "eval"; cs_value = Nil })
(CekState { cs_control = control; cs_env = env; cs_kont = kont; cs_phase = "eval"; cs_value = Nil })
(* make-cek-value *)
and make_cek_value value env kont =
(incr _cek_steps; CekState { cs_control = Nil; cs_env = env; cs_kont = kont; cs_phase = "continue"; cs_value = value })
(CekState { cs_control = Nil; cs_env = env; cs_kont = kont; cs_phase = "continue"; cs_value = value })
(* cek-terminal? *)
and cek_terminal_p state =

View File

@@ -46,12 +46,15 @@ let sx_call f args =
| Lambda l ->
let local = Sx_types.env_extend l.l_closure in
List.iter2 (fun p a -> ignore (Sx_types.env_bind local p a)) l.l_params args;
(* Return the body + env for the trampoline to evaluate *)
Thunk (l.l_body, local)
| Continuation (k, _) ->
k (match args with x :: _ -> x | [] -> Nil)
| _ -> raise (Eval_error ("Not callable: " ^ inspect f))
(* Initialize forward ref so primitives can call SX functions *)
let () = Sx_primitives._sx_call_fn := sx_call
(* Trampoline ref is set by sx_ref.ml after it's loaded *)
(** Apply a function to a list of args. *)
let sx_apply f args_list =
sx_call f (sx_to_list args_list)

View File

@@ -0,0 +1,369 @@
"""Tests for the SX bytecode compiler + VM execution.
Compiles SX expressions with compiler.sx (Python-side), executes
on the OCaml VM via the bridge, verifies results match CEK evaluation.
Usage:
pytest shared/sx/tests/test_vm_compile.py -v
"""
import asyncio
import os
import sys
import time
import unittest
_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
if _project_root not in sys.path:
sys.path.insert(0, _project_root)
from shared.sx.parser import parse_all, serialize
from shared.sx.ref.sx_ref import eval_expr, trampoline, PRIMITIVES
from shared.sx.types import Symbol, Keyword, NIL
from shared.sx.ocaml_bridge import OcamlBridge, OcamlBridgeError, _DEFAULT_BIN
def _skip_if_no_binary():
bin_path = os.path.abspath(_DEFAULT_BIN)
if not os.path.isfile(bin_path):
raise unittest.SkipTest(f"OCaml binary not found at {bin_path}")
# Register primitives needed by compiler.sx
PRIMITIVES['serialize'] = lambda x: serialize(x)
PRIMITIVES['primitive?'] = lambda name: isinstance(name, str) and name in PRIMITIVES
PRIMITIVES['has-key?'] = lambda *a: isinstance(a[0], dict) and str(a[1]) in a[0]
PRIMITIVES['set-nth!'] = lambda *a: (a[0].__setitem__(int(a[1]), a[2]), NIL)[-1]
PRIMITIVES['init'] = lambda *a: a[0][:-1] if isinstance(a[0], list) else a[0]
PRIMITIVES['make-symbol'] = lambda name: Symbol(name)
PRIMITIVES['concat'] = lambda *a: (a[0] or []) + (a[1] or [])
PRIMITIVES['slice'] = lambda *a: a[0][int(a[1]):int(a[2])] if len(a) == 3 else a[0][int(a[1]):]
def _load_compiler():
"""Load compiler.sx into a Python env, return the compile function."""
env = {}
for f in ['spec/bytecode.sx', 'spec/compiler.sx']:
path = os.path.join(_project_root, f)
with open(path) as fh:
for expr in parse_all(fh.read()):
trampoline(eval_expr(expr, env))
return env
def _compile(env, src):
"""Compile an SX source string to bytecode dict."""
ast = parse_all(src)[0]
return trampoline(eval_expr(
[Symbol('compile'), [Symbol('quote'), ast]], env))
# Load compiler once for all tests
_compiler_env = _load_compiler()
class TestCompilerOutput(unittest.TestCase):
"""Test that the compiler produces valid bytecode for various SX patterns."""
def _compile(self, src):
return _compile(_compiler_env, src)
def test_arithmetic(self):
result = self._compile('(+ 1 2)')
self.assertIn('bytecode', result)
self.assertIn('constants', result)
bc = list(result['bytecode'])
self.assertTrue(len(bc) > 0)
def test_if_produces_jumps(self):
result = self._compile('(if true "a" "b")')
bc = list(result['bytecode'])
# Should contain OP_JUMP_IF_FALSE (33)
self.assertIn(33, bc)
def test_let_uses_local_slots(self):
result = self._compile('(let ((x 1)) x)')
bc = list(result['bytecode'])
# Should contain OP_LOCAL_SET (17) and OP_LOCAL_GET (16)
self.assertIn(17, bc)
self.assertIn(16, bc)
def test_lambda_produces_closure(self):
result = self._compile('(fn (x) (+ x 1))')
bc = list(result['bytecode'])
# Should contain OP_CLOSURE (51)
self.assertIn(51, bc)
def test_closure_captures_upvalue(self):
result = self._compile('(let ((x 10)) (fn (y) (+ x y)))')
bc = list(result['bytecode'])
# Should have OP_CLOSURE with upvalue descriptors
self.assertIn(51, bc)
# Find closure index and check upvalue-count in constants
consts = list(result['constants'])
code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c]
self.assertTrue(len(code_objs) > 0)
code = code_objs[0]
self.assertEqual(code.get('upvalue-count', 0), 1)
# Inner bytecode should use OP_UPVALUE_GET (18)
inner_bc = list(code['bytecode'])
self.assertIn(18, inner_bc)
def test_cond_compiles(self):
result = self._compile('(cond (= x 1) "a" :else "b")')
self.assertTrue(len(list(result['bytecode'])) > 0)
def test_case_compiles(self):
result = self._compile('(case x 1 "one" :else "other")')
self.assertTrue(len(list(result['bytecode'])) > 0)
def test_thread_first_compiles(self):
result = self._compile('(-> x (+ 1))')
self.assertTrue(len(list(result['bytecode'])) > 0)
def test_begin_compiles(self):
result = self._compile('(do (+ 1 2) (+ 3 4))')
bc = list(result['bytecode'])
# Should contain OP_POP (5) between expressions
self.assertIn(5, bc)
def test_define_compiles(self):
result = self._compile('(define x 42)')
bc = list(result['bytecode'])
# Should contain OP_DEFINE (128)
self.assertIn(128, bc)
def test_nested_let_shares_frame(self):
"""Nested lets should use incrementing slot numbers, not restart at 0."""
result = self._compile('(let ((a 1)) (let ((b 2)) (+ a b)))')
bc = list(result['bytecode'])
# First LOC_SET should be slot 0, second should be slot 1
set_indices = []
for i, op in enumerate(bc):
if op == 17 and i + 1 < len(bc): # OP_LOCAL_SET
set_indices.append(bc[i + 1])
self.assertEqual(set_indices, [0, 1])
def test_tail_call(self):
"""Calls in tail position should use OP_TAIL_CALL."""
result = self._compile('(fn (x) (if (> x 0) (foo (- x 1)) 0))')
consts = list(result['constants'])
code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c]
inner_bc = list(code_objs[0]['bytecode'])
# Should contain OP_TAIL_CALL (49) for the recursive call
self.assertIn(49, inner_bc)
class TestVMExecution(unittest.IsolatedAsyncioTestCase):
"""Test that compiled bytecode executes correctly on the OCaml VM."""
@classmethod
def setUpClass(cls):
_skip_if_no_binary()
async def asyncSetUp(self):
self.bridge = OcamlBridge()
await self.bridge.start()
async def asyncTearDown(self):
await self.bridge.stop()
async def _vm_eval(self, src):
"""Compile SX source and execute on VM, return result."""
compiled = _compile(_compiler_env, src)
code_sx = serialize(compiled)
async with self.bridge._lock:
await self.bridge._send(f'(vm-exec {code_sx})')
return await self.bridge._read_until_ok(ctx=None)
async def _cek_eval(self, src):
"""Evaluate SX source on CEK machine, return result."""
async with self.bridge._lock:
await self.bridge._send(f'(eval "{_escape_for_ocaml(src)}")')
return await self.bridge._read_until_ok(ctx=None)
async def test_arithmetic(self):
result = await self._vm_eval('(+ 1 2)')
self.assertEqual(result.strip(), '3')
async def test_nested_arithmetic(self):
result = await self._vm_eval('(+ (* 3 4) (- 10 5))')
self.assertEqual(result.strip(), '17')
async def test_if_true(self):
result = await self._vm_eval('(if true "yes" "no")')
self.assertIn('yes', result)
async def test_if_false(self):
result = await self._vm_eval('(if false "yes" "no")')
self.assertIn('no', result)
async def test_let_binding(self):
result = await self._vm_eval('(let ((x 10) (y 20)) (+ x y))')
self.assertEqual(result.strip(), '30')
async def test_nested_let(self):
result = await self._vm_eval('(let ((a 1)) (let ((b 2)) (+ a b)))')
self.assertEqual(result.strip(), '3')
async def test_closure_captures_variable(self):
result = await self._vm_eval(
'(let ((x 10)) (let ((f (fn (y) (+ x y)))) (f 5)))')
self.assertEqual(result.strip(), '15')
async def test_closure_nested_capture(self):
result = await self._vm_eval(
'(let ((a 1) (b 2)) (let ((f (fn (c) (+ a (+ b c))))) (f 3)))')
self.assertEqual(result.strip(), '6')
async def test_and_short_circuit(self):
result = await self._vm_eval('(and false (error "should not reach"))')
self.assertEqual(result.strip(), 'false')
async def test_or_short_circuit(self):
result = await self._vm_eval('(or 42 (error "should not reach"))')
self.assertEqual(result.strip(), '42')
async def test_when_true(self):
result = await self._vm_eval('(when true "yes")')
self.assertIn('yes', result)
async def test_when_false(self):
result = await self._vm_eval('(when false "yes")')
self.assertIn('nil', result.lower())
async def test_cond(self):
result = await self._vm_eval(
'(let ((x 2)) (cond (= x 1) "one" (= x 2) "two" :else "other"))')
self.assertIn('two', result)
async def test_string_primitives(self):
result = await self._vm_eval('(str "hello" " " "world")')
self.assertIn('hello world', result)
async def test_list_construction(self):
result = await self._vm_eval('(list 1 2 3)')
self.assertIn('1', result)
self.assertIn('2', result)
self.assertIn('3', result)
async def test_define_and_call(self):
result = await self._vm_eval(
'(do (define double (fn (x) (* x 2))) (double 21))')
self.assertEqual(result.strip(), '42')
async def test_higher_order_call(self):
"""A function that takes another function as argument."""
result = await self._vm_eval(
'(let ((apply-fn (fn (f x) (f x)))) (apply-fn (fn (n) (* n 3)) 7))')
self.assertEqual(result.strip(), '21')
async def test_vm_matches_cek(self):
"""VM result must match CEK result for numeric expressions."""
test_exprs = [
('(+ 1 2)', '3'),
('(* 3 (+ 4 5))', '27'),
('(let ((x 10)) (+ x 1))', '11'),
('(- 100 42)', '58'),
]
for src, expected in test_exprs:
vm_result = await self._vm_eval(src)
self.assertEqual(vm_result.strip(), expected,
f"VM wrong for {src}: got {vm_result}, expected {expected}")
class TestVMAutoCompile(unittest.IsolatedAsyncioTestCase):
"""Test patterns that auto-compile needs to handle.
These represent the 111 functions that currently fail."""
@classmethod
def setUpClass(cls):
_skip_if_no_binary()
async def asyncSetUp(self):
self.bridge = OcamlBridge()
await self.bridge.start()
async def asyncTearDown(self):
await self.bridge.stop()
async def _vm_eval(self, src):
compiled = _compile(_compiler_env, src)
code_sx = serialize(compiled)
async with self.bridge._lock:
await self.bridge._send(f'(vm-exec {code_sx})')
return await self.bridge._read_until_ok(ctx=None)
async def test_for_each_via_primitive(self):
"""for-each should work as a primitive call."""
result = await self._vm_eval(
'(let ((sum 0)) (for-each (fn (x) (set! sum (+ sum x))) (list 1 2 3)) sum)')
self.assertEqual(result.strip(), '6')
async def test_map_via_primitive(self):
"""map should work as a primitive call."""
result = await self._vm_eval(
'(map (fn (x) (* x 2)) (list 1 2 3))')
self.assertIn('2', result)
self.assertIn('4', result)
self.assertIn('6', result)
async def test_filter_via_primitive(self):
"""filter should work as a primitive call."""
result = await self._vm_eval(
'(filter (fn (x) (> x 2)) (list 1 2 3 4 5))')
self.assertIn('3', result)
self.assertIn('4', result)
self.assertIn('5', result)
async def test_closure_over_mutable(self):
"""Closure capturing a set! target must share the mutation."""
result = await self._vm_eval(
'(let ((count 0)) (let ((inc (fn () (set! count (+ count 1))))) (inc) (inc) (inc) count))')
self.assertEqual(result.strip(), '3')
async def test_recursive_function(self):
"""Recursive function via define."""
result = await self._vm_eval(
'(do (define fact (fn (n) (if (<= n 1) 1 (* n (fact (- n 1)))))) (fact 5))')
self.assertEqual(result.strip(), '120')
async def test_string_building(self):
"""String concatenation — hot path for aser."""
result = await self._vm_eval(
'(str "(" "div" " " ":class" ")")')
self.assertIn('div', result)
self.assertIn(':class', result)
async def test_type_dispatch(self):
"""type-of dispatch — used heavily by aser."""
result = await self._vm_eval(
'(cond (= (type-of 42) "number") "num" (= (type-of "x") "string") "str" :else "other")')
self.assertIn('num', result)
async def test_type_of_number(self):
"""type-of dispatch — foundation for aser."""
result = await self._vm_eval('(type-of 42)')
self.assertIn('number', result)
async def test_empty_list_check(self):
result = await self._vm_eval('(empty? (list))')
self.assertEqual(result.strip(), 'true')
async def test_multiple_closures_same_scope(self):
"""Multiple closures capturing from the same let."""
result = await self._vm_eval('''
(let ((base 100))
(let ((add (fn (x) (+ base x)))
(sub (fn (x) (- base x))))
(+ (add 10) (sub 10))))''')
self.assertEqual(result.strip(), '200')
def _escape_for_ocaml(s):
"""Escape a string for embedding in an OCaml SX command."""
return s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
if __name__ == "__main__":
unittest.main()