"""Tests for the SX bytecode compiler + VM execution.

Compiles SX expressions with compiler.sx (Python-side), executes on the
OCaml VM via the bridge, verifies results match CEK evaluation.

Usage:
    pytest shared/sx/tests/test_vm_compile.py -v
"""

import asyncio
import os
import sys
import time
import unittest

_project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)

from shared.sx.parser import parse_all, serialize
from shared.sx.ref.sx_ref import eval_expr, trampoline, PRIMITIVES
from shared.sx.types import Symbol, Keyword, NIL
from shared.sx.ocaml_bridge import OcamlBridge, OcamlBridgeError, _DEFAULT_BIN


def _skip_if_no_binary():
    """Raise ``unittest.SkipTest`` when the OCaml bridge binary is not built."""
    bin_path = os.path.abspath(_DEFAULT_BIN)
    if not os.path.isfile(bin_path):
        raise unittest.SkipTest(f"OCaml binary not found at {bin_path}")


# Register primitives needed by compiler.sx
PRIMITIVES['serialize'] = lambda x: serialize(x)
PRIMITIVES['primitive?'] = lambda name: isinstance(name, str) and name in PRIMITIVES
PRIMITIVES['has-key?'] = lambda *a: isinstance(a[0], dict) and str(a[1]) in a[0]
PRIMITIVES['set-nth!'] = lambda *a: (a[0].__setitem__(int(a[1]), a[2]), NIL)[-1]
PRIMITIVES['init'] = lambda *a: a[0][:-1] if isinstance(a[0], list) else a[0]

# Register HO forms as primitives so compiler emits CALL_PRIM (direct dispatch)
# instead of CALL (which routes through CEK HO special forms)
for _ho_name in ['map', 'map-indexed', 'filter', 'reduce', 'for-each', 'some', 'every?']:
    PRIMITIVES[_ho_name] = lambda *a: NIL  # placeholder — OCaml primitives handle actual work

PRIMITIVES['make-symbol'] = lambda name: Symbol(name)
PRIMITIVES['concat'] = lambda *a: (a[0] or []) + (a[1] or [])
PRIMITIVES['slice'] = lambda *a: a[0][int(a[1]):int(a[2])] if len(a) == 3 else a[0][int(a[1]):]


def _load_compiler():
    """Evaluate spec/bytecode.sx and spec/compiler.sx into a fresh env.

    Returns:
        The env dict populated by evaluating both spec files. ``_compile``
        looks up the ``compile`` binding in it.
    """
    env = {}
    for f in ['spec/bytecode.sx', 'spec/compiler.sx']:
        path = os.path.join(_project_root, f)
        # Explicit encoding so parsing the spec does not depend on host locale.
        with open(path, encoding="utf-8") as fh:
            source = fh.read()
        for expr in parse_all(source):
            trampoline(eval_expr(expr, env))
    return env


def _compile(env, src):
    """Compile the first top-level expression of ``src`` to a bytecode dict.

    The AST is passed quoted to the SX-level ``compile`` function defined in
    ``env`` (see ``_load_compiler``).
    """
    ast = parse_all(src)[0]
    return trampoline(eval_expr(
        [Symbol('compile'), [Symbol('quote'), ast]], env))


# Load compiler once for all tests
_compiler_env = _load_compiler()


class TestCompilerOutput(unittest.TestCase):
    """Test that the compiler produces valid bytecode for various SX patterns."""

    def _compile(self, src):
        return _compile(_compiler_env, src)

    def test_arithmetic(self):
        result = self._compile('(+ 1 2)')
        self.assertIn('bytecode', result)
        self.assertIn('constants', result)
        bc = list(result['bytecode'])
        self.assertGreater(len(bc), 0)

    def test_if_produces_jumps(self):
        result = self._compile('(if true "a" "b")')
        bc = list(result['bytecode'])
        # Should contain OP_JUMP_IF_FALSE (33)
        self.assertIn(33, bc)

    def test_let_uses_local_slots(self):
        result = self._compile('(let ((x 1)) x)')
        bc = list(result['bytecode'])
        # Should contain OP_LOCAL_SET (17) and OP_LOCAL_GET (16)
        self.assertIn(17, bc)
        self.assertIn(16, bc)

    def test_lambda_produces_closure(self):
        result = self._compile('(fn (x) (+ x 1))')
        bc = list(result['bytecode'])
        # Should contain OP_CLOSURE (51)
        self.assertIn(51, bc)

    def test_closure_captures_upvalue(self):
        result = self._compile('(let ((x 10)) (fn (y) (+ x y)))')
        bc = list(result['bytecode'])
        # Should have OP_CLOSURE with upvalue descriptors
        self.assertIn(51, bc)
        # Find closure index and check upvalue-count in constants
        consts = list(result['constants'])
        code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c]
        self.assertTrue(code_objs, "expected a nested code object in constants")
        code = code_objs[0]
        self.assertEqual(code.get('upvalue-count', 0), 1)
        # Inner bytecode should use OP_UPVALUE_GET (18)
        inner_bc = list(code['bytecode'])
        self.assertIn(18, inner_bc)

    def test_cond_compiles(self):
        result = self._compile('(cond (= x 1) "a" :else "b")')
        self.assertGreater(len(list(result['bytecode'])), 0)

    def test_case_compiles(self):
        result = self._compile('(case x 1 "one" :else "other")')
        self.assertGreater(len(list(result['bytecode'])), 0)

    def test_thread_first_compiles(self):
        result = self._compile('(-> x (+ 1))')
        self.assertGreater(len(list(result['bytecode'])), 0)

    def test_begin_compiles(self):
        result = self._compile('(do (+ 1 2) (+ 3 4))')
        bc = list(result['bytecode'])
        # Should contain OP_POP (5) between expressions
        self.assertIn(5, bc)

    def test_define_compiles(self):
        result = self._compile('(define x 42)')
        bc = list(result['bytecode'])
        # Should contain OP_DEFINE (128)
        self.assertIn(128, bc)

    def test_nested_let_shares_frame(self):
        """Nested lets should use incrementing slot numbers, not restart at 0."""
        result = self._compile('(let ((a 1)) (let ((b 2)) (+ a b)))')
        bc = list(result['bytecode'])
        # First LOC_SET should be slot 0, second should be slot 1
        set_indices = []
        for i, op in enumerate(bc):
            if op == 17 and i + 1 < len(bc):  # OP_LOCAL_SET
                set_indices.append(bc[i + 1])
        self.assertEqual(set_indices, [0, 1])

    def test_tail_call(self):
        """Calls in tail position should use OP_TAIL_CALL."""
        result = self._compile('(fn (x) (if (> x 0) (foo (- x 1)) 0))')
        consts = list(result['constants'])
        code_objs = [c for c in consts if isinstance(c, dict) and 'bytecode' in c]
        # Fail with a clear message rather than IndexError if no closure was emitted.
        self.assertTrue(code_objs, "expected a nested code object in constants")
        inner_bc = list(code_objs[0]['bytecode'])
        # Should contain OP_TAIL_CALL (49) for the recursive call
        self.assertIn(49, inner_bc)


class _OcamlVMMixin:
    """Shared fixture for tests that talk to a live OCaml bridge process.

    Each test gets a freshly started ``OcamlBridge`` as ``self.bridge``,
    stopped again in ``asyncTearDown``. The class is skipped when the bridge
    binary is not built. Mix in before ``unittest.IsolatedAsyncioTestCase``.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        _skip_if_no_binary()

    async def asyncSetUp(self):
        await super().asyncSetUp()
        self.bridge = OcamlBridge()
        await self.bridge.start()

    async def asyncTearDown(self):
        await self.bridge.stop()
        await super().asyncTearDown()

    async def _vm_eval(self, src):
        """Compile SX source and execute on VM, return result."""
        compiled = _compile(_compiler_env, src)
        code_sx = serialize(compiled)
        # Hold the bridge lock so the send and its reply are not interleaved.
        async with self.bridge._lock:
            await self.bridge._send(f'(vm-exec {code_sx})')
            return await self.bridge._read_until_ok(ctx=None)

    async def _cek_eval(self, src):
        """Evaluate SX source on CEK machine, return result."""
        async with self.bridge._lock:
            await self.bridge._send(f'(eval "{_escape_for_ocaml(src)}")')
            return await self.bridge._read_until_ok(ctx=None)


class TestVMExecution(_OcamlVMMixin, unittest.IsolatedAsyncioTestCase):
    """Test that compiled bytecode executes correctly on the OCaml VM."""

    async def test_arithmetic(self):
        result = await self._vm_eval('(+ 1 2)')
        self.assertEqual(result.strip(), '3')

    async def test_nested_arithmetic(self):
        result = await self._vm_eval('(+ (* 3 4) (- 10 5))')
        self.assertEqual(result.strip(), '17')

    async def test_if_true(self):
        result = await self._vm_eval('(if true "yes" "no")')
        self.assertIn('yes', result)

    async def test_if_false(self):
        result = await self._vm_eval('(if false "yes" "no")')
        self.assertIn('no', result)

    async def test_let_binding(self):
        result = await self._vm_eval('(let ((x 10) (y 20)) (+ x y))')
        self.assertEqual(result.strip(), '30')

    async def test_nested_let(self):
        result = await self._vm_eval('(let ((a 1)) (let ((b 2)) (+ a b)))')
        self.assertEqual(result.strip(), '3')

    async def test_closure_captures_variable(self):
        result = await self._vm_eval(
            '(let ((x 10)) (let ((f (fn (y) (+ x y)))) (f 5)))')
        self.assertEqual(result.strip(), '15')

    async def test_closure_nested_capture(self):
        result = await self._vm_eval(
            '(let ((a 1) (b 2)) (let ((f (fn (c) (+ a (+ b c))))) (f 3)))')
        self.assertEqual(result.strip(), '6')

    async def test_and_short_circuit(self):
        result = await self._vm_eval('(and false (error "should not reach"))')
        self.assertEqual(result.strip(), 'false')

    async def test_or_short_circuit(self):
        result = await self._vm_eval('(or 42 (error "should not reach"))')
        self.assertEqual(result.strip(), '42')

    async def test_when_true(self):
        result = await self._vm_eval('(when true "yes")')
        self.assertIn('yes', result)

    async def test_when_false(self):
        result = await self._vm_eval('(when false "yes")')
        self.assertIn('nil', result.lower())

    async def test_cond(self):
        result = await self._vm_eval(
            '(let ((x 2)) (cond (= x 1) "one" (= x 2) "two" :else "other"))')
        self.assertIn('two', result)

    async def test_string_primitives(self):
        result = await self._vm_eval('(str "hello" " " "world")')
        self.assertIn('hello world', result)

    async def test_list_construction(self):
        result = await self._vm_eval('(list 1 2 3)')
        self.assertIn('1', result)
        self.assertIn('2', result)
        self.assertIn('3', result)

    async def test_define_and_call(self):
        result = await self._vm_eval(
            '(do (define double (fn (x) (* x 2))) (double 21))')
        self.assertEqual(result.strip(), '42')

    async def test_higher_order_call(self):
        """A function that takes another function as argument."""
        result = await self._vm_eval(
            '(let ((apply-fn (fn (f x) (f x)))) (apply-fn (fn (n) (* n 3)) 7))')
        self.assertEqual(result.strip(), '21')

    async def test_vm_matches_cek(self):
        """VM result must match CEK result (and the known value) for numeric expressions."""
        test_exprs = [
            ('(+ 1 2)', '3'),
            ('(* 3 (+ 4 5))', '27'),
            ('(let ((x 10)) (+ x 1))', '11'),
            ('(- 100 42)', '58'),
        ]
        for src, expected in test_exprs:
            # subTest reports every failing expression, not just the first.
            with self.subTest(src=src):
                vm_result = await self._vm_eval(src)
                self.assertEqual(vm_result.strip(), expected,
                                 f"VM wrong for {src}: got {vm_result}, expected {expected}")
                cek_result = await self._cek_eval(src)
                self.assertEqual(vm_result.strip(), cek_result.strip(),
                                 f"VM/CEK mismatch for {src}: vm={vm_result}, cek={cek_result}")


class TestVMAutoCompile(_OcamlVMMixin, unittest.IsolatedAsyncioTestCase):
    """Test patterns that auto-compile needs to handle.

    These represent the 111 functions that currently fail."""

    async def test_for_each_via_primitive(self):
        """for-each should work as a primitive call."""
        result = await self._vm_eval(
            '(let ((sum 0)) (for-each (fn (x) (set! sum (+ sum x))) (list 1 2 3)) sum)')
        self.assertEqual(result.strip(), '6')

    async def test_map_via_primitive(self):
        """map should work as a primitive call."""
        result = await self._vm_eval(
            '(map (fn (x) (* x 2)) (list 1 2 3))')
        self.assertIn('2', result)
        self.assertIn('4', result)
        self.assertIn('6', result)

    async def test_filter_via_primitive(self):
        """filter should work as a primitive call."""
        result = await self._vm_eval(
            '(filter (fn (x) (> x 2)) (list 1 2 3 4 5))')
        self.assertIn('3', result)
        self.assertIn('4', result)
        self.assertIn('5', result)

    async def test_closure_over_mutable(self):
        """Closure capturing a set! target must share the mutation."""
        result = await self._vm_eval(
            '(let ((count 0)) (let ((inc (fn () (set! count (+ count 1))))) (inc) (inc) (inc) count))')
        self.assertEqual(result.strip(), '3')

    async def test_recursive_function(self):
        """Recursive function via define."""
        result = await self._vm_eval(
            '(do (define fact (fn (n) (if (<= n 1) 1 (* n (fact (- n 1)))))) (fact 5))')
        self.assertEqual(result.strip(), '120')

    async def test_string_building(self):
        """String concatenation — hot path for aser."""
        result = await self._vm_eval(
            '(str "(" "div" " " ":class" ")")')
        self.assertIn('div', result)
        self.assertIn(':class', result)

    async def test_type_dispatch(self):
        """type-of dispatch — used heavily by aser."""
        result = await self._vm_eval(
            '(cond (= (type-of 42) "number") "num" (= (type-of "x") "string") "str" :else "other")')
        self.assertIn('num', result)

    async def test_type_of_number(self):
        """type-of dispatch — foundation for aser."""
        result = await self._vm_eval('(type-of 42)')
        self.assertIn('number', result)

    async def test_empty_list_check(self):
        result = await self._vm_eval('(empty? (list))')
        self.assertEqual(result.strip(), 'true')

    async def test_multiple_closures_same_scope(self):
        """Multiple closures capturing from the same let."""
        result = await self._vm_eval('''
            (let ((base 100))
              (let ((add (fn (x) (+ base x)))
                    (sub (fn (x) (- base x))))
                (+ (add 10) (sub 10))))''')
        self.assertEqual(result.strip(), '200')


def _escape_for_ocaml(s):
    """Escape a string for embedding in an OCaml SX command.

    Backslashes, double quotes and newlines are escaped so ``s`` can sit
    inside a double-quoted SX string literal on a single command line.
    """
    return s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')


if __name__ == "__main__":
    unittest.main()