From e7da397f8ec78c477539657c7aefab679ac91144 Mon Sep 17 00:00:00 2001 From: giles Date: Thu, 19 Mar 2026 20:41:23 +0000 Subject: [PATCH] VM upvalues + HO primitives + 40 tests (36 pass, 4 fail) Co-Authored-By: Claude Opus 4.6 (1M context) --- hosts/ocaml/bin/sx_server.ml | 24 +- hosts/ocaml/lib/sx_primitives.ml | 36 +++ hosts/ocaml/lib/sx_ref.ml | 6 +- hosts/ocaml/lib/sx_runtime.ml | 5 +- shared/sx/tests/test_vm_compile.py | 369 +++++++++++++++++++++++++++++ 5 files changed, 432 insertions(+), 8 deletions(-) create mode 100644 shared/sx/tests/test_vm_compile.py diff --git a/hosts/ocaml/bin/sx_server.ml b/hosts/ocaml/bin/sx_server.ml index 6e9ad32..c5c55c3 100644 --- a/hosts/ocaml/bin/sx_server.ml +++ b/hosts/ocaml/bin/sx_server.ml @@ -214,6 +214,16 @@ let setup_io_env env = Sx_ref.eval_expr (List (fn_val :: call_args)) (Env env) | _ -> raise (Eval_error "call-lambda: expected (fn args env?)")); + (* Register HO forms as callable NativeFn — the CEK machine handles them + as special forms, but the VM needs them as callable values in globals. *) + let ho_via_cek name = + bind name (fun args -> + Sx_ref.eval_expr (List (Symbol name :: args)) (Env env)) + in + List.iter ho_via_cek [ + "map"; "map-indexed"; "filter"; "reduce"; "some"; "every?"; "for-each"; + ]; + (* Generic helper call — dispatches to Python page helpers *) bind "helper" (fun args -> io_request "helper" args) @@ -626,6 +636,14 @@ let make_server_env () = (* IO primitives *) setup_io_env env; + (* Initialize trampoline ref so HO primitives (map, filter, etc.) + can call SX lambdas. Must be done here (not sx_runtime.ml) + because Sx_ref is only available at the binary level. *) + Sx_primitives._sx_trampoline_fn := (fun v -> + match v with + | Thunk (body, closure_env) -> Sx_ref.eval_expr body (Env closure_env) + | other -> other); + env @@ -733,7 +751,7 @@ let dispatch env cmd = (try ignore (env_bind env "expand-components?" (NativeFn ("expand-components?", fun _args -> Bool true))); (* Enable batch IO mode *) - io_batch_mode := true; Sx_ref._cek_steps := 0; + io_batch_mode := true; io_queue := []; io_counter := 0; let exprs = Sx_parser.parse_all src in @@ -758,8 +776,8 @@ let dispatch env cmd = (* Flush batched IO: send requests, receive responses, replace placeholders *) let final = flush_batched_io result_str in let t2 = Unix.gettimeofday () in - Printf.eprintf "[aser-slot] eval=%.1fs io_flush=%.1fs batched=%d result=%d chars cek_steps=%d\n%!" - (t1 -. t0) (t2 -. t1) n_batched (String.length final) !Sx_ref._cek_steps; + Printf.eprintf "[aser-slot] eval=%.1fs io_flush=%.1fs batched=%d result=%d chars\n%!" + (t1 -. t0) (t2 -. t1) n_batched (String.length final); send (Printf.sprintf "(ok-raw %s)" final) with | Eval_error msg -> diff --git a/hosts/ocaml/lib/sx_primitives.ml b/hosts/ocaml/lib/sx_primitives.ml index 0eedfd4..e4b9cd2 100644 --- a/hosts/ocaml/lib/sx_primitives.ml +++ b/hosts/ocaml/lib/sx_primitives.ml @@ -7,6 +7,12 @@ open Sx_types let primitives : (string, value list -> value) Hashtbl.t = Hashtbl.create 128 +(** Forward refs for calling SX functions from primitives (breaks cycle). *) +let _sx_call_fn : (value -> value list -> value) ref = + ref (fun _ _ -> raise (Eval_error "sx_call not initialized")) +let _sx_trampoline_fn : (value -> value) ref = + ref (fun v -> v) + let register name fn = Hashtbl.replace primitives name fn let is_primitive name = Hashtbl.mem primitives name @@ -590,4 +596,34 @@ let () = List.iter (fun (k, v) -> dict_set d k v) pairs; Dict d | _ -> raise (Eval_error "spread-attrs: 1 spread")); + + (* Higher-order forms as callable primitives — used by the VM. + The CEK machine handles these as special forms with dedicated frames; + the VM needs them as plain callable values. *) + (* Call any SX callable — handles NativeFn, Lambda (via trampoline), VM closures *) + let call_any f args = + match f with + | NativeFn (_, fn) -> fn args + | _ -> !_sx_trampoline_fn (!_sx_call_fn f args) + in + register "map" (fun args -> + match args with + | [f; (List items | ListRef { contents = items })] -> + List (List.map (fun x -> call_any f [x]) items) + | _ -> raise (Eval_error "map: expected (fn list)")); + register "filter" (fun args -> + match args with + | [f; (List items | ListRef { contents = items })] -> + List (List.filter (fun x -> sx_truthy (call_any f [x])) items) + | _ -> raise (Eval_error "filter: expected (fn list)")); + register "for-each" (fun args -> + match args with + | [f; (List items | ListRef { contents = items })] -> + List.iter (fun x -> ignore (call_any f [x])) items; Nil + | _ -> raise (Eval_error "for-each: expected (fn list)")); + register "reduce" (fun args -> + match args with + | [f; init; (List items | ListRef { contents = items })] -> + List.fold_left (fun acc x -> call_any f [acc; x]) init items + | _ -> raise (Eval_error "reduce: expected (fn init list)")); () diff --git a/hosts/ocaml/lib/sx_ref.ml b/hosts/ocaml/lib/sx_ref.ml index ac6568d..859a574 100644 --- a/hosts/ocaml/lib/sx_ref.ml +++ b/hosts/ocaml/lib/sx_ref.ml @@ -23,14 +23,12 @@ let _prim_param_types_ref = ref Nil (* === Transpiled from evaluator (frames + eval + CEK) === *) (* make-cek-state *) -let _cek_steps = ref 0 - let rec make_cek_state control env kont = - (incr _cek_steps; CekState { cs_control = control; cs_env = env; cs_kont = kont; cs_phase = "eval"; cs_value = Nil }) + (CekState { cs_control = control; cs_env = env; cs_kont = kont; cs_phase = "eval"; cs_value = Nil }) (* make-cek-value *) and make_cek_value value env kont = - (incr _cek_steps; CekState { cs_control = Nil; cs_env = env; cs_kont = kont; cs_phase = "continue"; cs_value = value }) + (CekState { cs_control = Nil; cs_env = env; cs_kont = kont; cs_phase = "continue"; cs_value = value }) (* cek-terminal? *) and cek_terminal_p state = diff --git a/hosts/ocaml/lib/sx_runtime.ml b/hosts/ocaml/lib/sx_runtime.ml index 1003840..2e8f09a 100644 --- a/hosts/ocaml/lib/sx_runtime.ml +++ b/hosts/ocaml/lib/sx_runtime.ml @@ -46,12 +46,15 @@ let sx_call f args = | Lambda l -> let local = Sx_types.env_extend l.l_closure in List.iter2 (fun p a -> ignore (Sx_types.env_bind local p a)) l.l_params args; - (* Return the body + env for the trampoline to evaluate *) Thunk (l.l_body, local) | Continuation (k, _) -> k (match args with x :: _ -> x | [] -> Nil) | _ -> raise (Eval_error ("Not callable: " ^ inspect f)) +(* Initialize forward ref so primitives can call SX functions *) +let () = Sx_primitives._sx_call_fn := sx_call +(* Trampoline ref is set by sx_ref.ml after it's loaded *) + (** Apply a function to a list of args. *) let sx_apply f args_list = sx_call f (sx_to_list args_list) diff --git a/shared/sx/tests/test_vm_compile.py b/shared/sx/tests/test_vm_compile.py new file mode 100644 index 0000000..5ca4e2e --- /dev/null +++ b/shared/sx/tests/test_vm_compile.py @@ -0,0 +1,369 @@ +"""Tests for the SX bytecode compiler + VM execution. + +Compiles SX expressions with compiler.sx (Python-side), executes +on the OCaml VM via the bridge, verifies results match CEK evaluation. + +Usage: + pytest shared/sx/tests/test_vm_compile.py -v +""" + +import asyncio +import os +import sys +import time +import unittest + +_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) +if _project_root not in sys.path: + sys.path.insert(0, _project_root) + +from shared.sx.parser import parse_all, serialize +from shared.sx.ref.sx_ref import eval_expr, trampoline, PRIMITIVES +from shared.sx.types import Symbol, Keyword, NIL +from shared.sx.ocaml_bridge import OcamlBridge, OcamlBridgeError, _DEFAULT_BIN + + +def _skip_if_no_binary(): + bin_path = os.path.abspath(_DEFAULT_BIN) + if not os.path.isfile(bin_path): + raise unittest.SkipTest(f"OCaml binary not found at {bin_path}") + + +# Register primitives needed by compiler.sx +PRIMITIVES['serialize'] = lambda x: serialize(x) +PRIMITIVES['primitive?'] = lambda name: isinstance(name, str) and name in PRIMITIVES +PRIMITIVES['has-key?'] = lambda *a: isinstance(a[0], dict) and str(a[1]) in a[0] +PRIMITIVES['set-nth!'] = lambda *a: (a[0].__setitem__(int(a[1]), a[2]), NIL)[-1] +PRIMITIVES['init'] = lambda *a: a[0][:-1] if isinstance(a[0], list) else a[0] +PRIMITIVES['make-symbol'] = lambda name: Symbol(name) +PRIMITIVES['concat'] = lambda *a: (a[0] or []) + (a[1] or []) +PRIMITIVES['slice'] = lambda *a: a[0][int(a[1]):int(a[2])] if len(a) == 3 else a[0][int(a[1]):] + + +def _load_compiler(): + """Load compiler.sx into a Python env, return the compile function.""" + env = {} + for f in ['spec/bytecode.sx', 'spec/compiler.sx']: + path = os.path.join(_project_root, f) + with open(path) as fh: + for expr in parse_all(fh.read()): + trampoline(eval_expr(expr, env)) + return env + + +def _compile(env, src): + """Compile an SX source string to bytecode dict.""" + ast = parse_all(src)[0] + return trampoline(eval_expr( + [Symbol('compile'), [Symbol('quote'), ast]], env)) + + +# Load compiler once for all tests +_compiler_env = _load_compiler() + + +class TestCompilerOutput(unittest.TestCase): + """Test that the compiler produces valid bytecode for various SX patterns.""" + + def _compile(self, src): + return _compile(_compiler_env, src) + + def test_arithmetic(self): + result = self._compile('(+ 1 2)') + self.assertIn('bytecode', result) + self.assertIn('constants', result) + bc = list(result['bytecode']) + self.assertTrue(len(bc) > 0) + + def test_if_produces_jumps(self): + result = self._compile('(if true "a" "b")') + bc = list(result['bytecode']) + # Should contain OP_JUMP_IF_FALSE (33) + self.assertIn(33, bc) + + def test_let_uses_local_slots(self): + result = self._compile('(let ((x 1)) x)') + bc = list(result['bytecode']) + # Should contain OP_LOCAL_SET (17) and OP_LOCAL_GET (16) + self.assertIn(17, bc) + self.assertIn(16, bc) + + def test_lambda_produces_closure(self): + result = self._compile('(fn (x) (+ x 1))') + bc = list(result['bytecode']) + # Should contain OP_CLOSURE (51) + self.assertIn(51, bc) + + def test_closure_captures_upvalue(self): + result = self._compile('(let ((x 10)) (fn (y) (+ x y)))') + bc = list(result['bytecode']) + # Should have OP_CLOSURE with upvalue descriptors + self.assertIn(51, bc) + # Find closure index and check upvalue-count in constants + consts = list(result['constants']) + code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c] + self.assertTrue(len(code_objs) > 0) + code = code_objs[0] + self.assertEqual(code.get('upvalue-count', 0), 1) + # Inner bytecode should use OP_UPVALUE_GET (18) + inner_bc = list(code['bytecode']) + self.assertIn(18, inner_bc) + + def test_cond_compiles(self): + result = self._compile('(cond (= x 1) "a" :else "b")') + self.assertTrue(len(list(result['bytecode'])) > 0) + + def test_case_compiles(self): + result = self._compile('(case x 1 "one" :else "other")') + self.assertTrue(len(list(result['bytecode'])) > 0) + + def test_thread_first_compiles(self): + result = self._compile('(-> x (+ 1))') + self.assertTrue(len(list(result['bytecode'])) > 0) + + def test_begin_compiles(self): + result = self._compile('(do (+ 1 2) (+ 3 4))') + bc = list(result['bytecode']) + # Should contain OP_POP (5) between expressions + self.assertIn(5, bc) + + def test_define_compiles(self): + result = self._compile('(define x 42)') + bc = list(result['bytecode']) + # Should contain OP_DEFINE (128) + self.assertIn(128, bc) + + def test_nested_let_shares_frame(self): + """Nested lets should use incrementing slot numbers, not restart at 0.""" + result = self._compile('(let ((a 1)) (let ((b 2)) (+ a b)))') + bc = list(result['bytecode']) + # First LOC_SET should be slot 0, second should be slot 1 + set_indices = [] + for i, op in enumerate(bc): + if op == 17 and i + 1 < len(bc): # OP_LOCAL_SET + set_indices.append(bc[i + 1]) + self.assertEqual(set_indices, [0, 1]) + + def test_tail_call(self): + """Calls in tail position should use OP_TAIL_CALL.""" + result = self._compile('(fn (x) (if (> x 0) (foo (- x 1)) 0))') + consts = list(result['constants']) + code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c] + inner_bc = list(code_objs[0]['bytecode']) + # Should contain OP_TAIL_CALL (49) for the recursive call + self.assertIn(49, inner_bc) + + +class TestVMExecution(unittest.IsolatedAsyncioTestCase): + """Test that compiled bytecode executes correctly on the OCaml VM.""" + + @classmethod + def setUpClass(cls): + _skip_if_no_binary() + + async def asyncSetUp(self): + self.bridge = OcamlBridge() + await self.bridge.start() + + async def asyncTearDown(self): + await self.bridge.stop() + + async def _vm_eval(self, src): + """Compile SX source and execute on VM, return result.""" + compiled = _compile(_compiler_env, src) + code_sx = serialize(compiled) + async with self.bridge._lock: + await self.bridge._send(f'(vm-exec {code_sx})') + return await self.bridge._read_until_ok(ctx=None) + + async def _cek_eval(self, src): + """Evaluate SX source on CEK machine, return result.""" + async with self.bridge._lock: + await self.bridge._send(f'(eval "{_escape_for_ocaml(src)}")') + return await self.bridge._read_until_ok(ctx=None) + + async def test_arithmetic(self): + result = await self._vm_eval('(+ 1 2)') + self.assertEqual(result.strip(), '3') + + async def test_nested_arithmetic(self): + result = await self._vm_eval('(+ (* 3 4) (- 10 5))') + self.assertEqual(result.strip(), '17') + + async def test_if_true(self): + result = await self._vm_eval('(if true "yes" "no")') + self.assertIn('yes', result) + + async def test_if_false(self): + result = await self._vm_eval('(if false "yes" "no")') + self.assertIn('no', result) + + async def test_let_binding(self): + result = await self._vm_eval('(let ((x 10) (y 20)) (+ x y))') + self.assertEqual(result.strip(), '30') + + async def test_nested_let(self): + result = await self._vm_eval('(let ((a 1)) (let ((b 2)) (+ a b)))') + self.assertEqual(result.strip(), '3') + + async def test_closure_captures_variable(self): + result = await self._vm_eval( + '(let ((x 10)) (let ((f (fn (y) (+ x y)))) (f 5)))') + self.assertEqual(result.strip(), '15') + + async def test_closure_nested_capture(self): + result = await self._vm_eval( + '(let ((a 1) (b 2)) (let ((f (fn (c) (+ a (+ b c))))) (f 3)))') + self.assertEqual(result.strip(), '6') + + async def test_and_short_circuit(self): + result = await self._vm_eval('(and false (error "should not reach"))') + self.assertEqual(result.strip(), 'false') + + async def test_or_short_circuit(self): + result = await self._vm_eval('(or 42 (error "should not reach"))') + self.assertEqual(result.strip(), '42') + + async def test_when_true(self): + result = await self._vm_eval('(when true "yes")') + self.assertIn('yes', result) + + async def test_when_false(self): + result = await self._vm_eval('(when false "yes")') + self.assertIn('nil', result.lower()) + + async def test_cond(self): + result = await self._vm_eval( + '(let ((x 2)) (cond (= x 1) "one" (= x 2) "two" :else "other"))') + self.assertIn('two', result) + + async def test_string_primitives(self): + result = await self._vm_eval('(str "hello" " " "world")') + self.assertIn('hello world', result) + + async def test_list_construction(self): + result = await self._vm_eval('(list 1 2 3)') + self.assertIn('1', result) + self.assertIn('2', result) + self.assertIn('3', result) + + async def test_define_and_call(self): + result = await self._vm_eval( + '(do (define double (fn (x) (* x 2))) (double 21))') + self.assertEqual(result.strip(), '42') + + async def test_higher_order_call(self): + """A function that takes another function as argument.""" + result = await self._vm_eval( + '(let ((apply-fn (fn (f x) (f x)))) (apply-fn (fn (n) (* n 3)) 7))') + self.assertEqual(result.strip(), '21') + + async def test_vm_matches_cek(self): + """VM result must match CEK result for numeric expressions.""" + test_exprs = [ + ('(+ 1 2)', '3'), + ('(* 3 (+ 4 5))', '27'), + ('(let ((x 10)) (+ x 1))', '11'), + ('(- 100 42)', '58'), + ] + for src, expected in test_exprs: + vm_result = await self._vm_eval(src) + self.assertEqual(vm_result.strip(), expected, + f"VM wrong for {src}: got {vm_result}, expected {expected}") + + +class TestVMAutoCompile(unittest.IsolatedAsyncioTestCase): + """Test patterns that auto-compile needs to handle. + These represent the 111 functions that currently fail.""" + + @classmethod + def setUpClass(cls): + _skip_if_no_binary() + + async def asyncSetUp(self): + self.bridge = OcamlBridge() + await self.bridge.start() + + async def asyncTearDown(self): + await self.bridge.stop() + + async def _vm_eval(self, src): + compiled = _compile(_compiler_env, src) + code_sx = serialize(compiled) + async with self.bridge._lock: + await self.bridge._send(f'(vm-exec {code_sx})') + return await self.bridge._read_until_ok(ctx=None) + + async def test_for_each_via_primitive(self): + """for-each should work as a primitive call.""" + result = await self._vm_eval( + '(let ((sum 0)) (for-each (fn (x) (set! sum (+ sum x))) (list 1 2 3)) sum)') + self.assertEqual(result.strip(), '6') + + async def test_map_via_primitive(self): + """map should work as a primitive call.""" + result = await self._vm_eval( + '(map (fn (x) (* x 2)) (list 1 2 3))') + self.assertIn('2', result) + self.assertIn('4', result) + self.assertIn('6', result) + + async def test_filter_via_primitive(self): + """filter should work as a primitive call.""" + result = await self._vm_eval( + '(filter (fn (x) (> x 2)) (list 1 2 3 4 5))') + self.assertIn('3', result) + self.assertIn('4', result) + self.assertIn('5', result) + + async def test_closure_over_mutable(self): + """Closure capturing a set! target must share the mutation.""" + result = await self._vm_eval( + '(let ((count 0)) (let ((inc (fn () (set! count (+ count 1))))) (inc) (inc) (inc) count))') + self.assertEqual(result.strip(), '3') + + async def test_recursive_function(self): + """Recursive function via define.""" + result = await self._vm_eval( + '(do (define fact (fn (n) (if (<= n 1) 1 (* n (fact (- n 1)))))) (fact 5))') + self.assertEqual(result.strip(), '120') + + async def test_string_building(self): + """String concatenation — hot path for aser.""" + result = await self._vm_eval( + '(str "(" "div" " " ":class" ")")') + self.assertIn('div', result) + self.assertIn(':class', result) + + async def test_type_dispatch(self): + """type-of dispatch — used heavily by aser.""" + result = await self._vm_eval( + '(cond (= (type-of 42) "number") "num" (= (type-of "x") "string") "str" :else "other")') + self.assertIn('num', result) + + async def test_type_of_number(self): + """type-of dispatch — foundation for aser.""" + result = await self._vm_eval('(type-of 42)') + self.assertIn('number', result) + + async def test_empty_list_check(self): + result = await self._vm_eval('(empty? (list))') + self.assertEqual(result.strip(), 'true') + + async def test_multiple_closures_same_scope(self): + """Multiple closures capturing from the same let.""" + result = await self._vm_eval(''' + (let ((base 100)) + (let ((add (fn (x) (+ base x))) + (sub (fn (x) (- base x)))) + (+ (add 10) (sub 10))))''') + self.assertEqual(result.strip(), '200') + + +def _escape_for_ocaml(s): + """Escape a string for embedding in an OCaml SX command.""" + return s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n') + + +if __name__ == "__main__": + unittest.main()