Files
rose-ash/shared/sx/tests/test_vm_compile.py
2026-03-19 20:41:23 +00:00

370 lines
14 KiB
Python

"""Tests for the SX bytecode compiler + VM execution.
Compiles SX expressions with compiler.sx (Python-side), executes
on the OCaml VM via the bridge, verifies results match CEK evaluation.
Usage:
pytest shared/sx/tests/test_vm_compile.py -v
"""
import asyncio
import os
import sys
import time
import unittest
# Repository root: three directory levels above this file's directory
# (shared/sx/tests -> repo root). Put it first on sys.path so the
# ``shared.sx`` package imports below resolve regardless of the cwd.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
from shared.sx.parser import parse_all, serialize
from shared.sx.ref.sx_ref import eval_expr, trampoline, PRIMITIVES
from shared.sx.types import Symbol, Keyword, NIL
from shared.sx.ocaml_bridge import OcamlBridge, OcamlBridgeError, _DEFAULT_BIN
def _skip_if_no_binary():
    """Raise ``unittest.SkipTest`` if the OCaml bridge binary is not present.

    Intended to be called from ``setUpClass`` so that every test of a
    VM-backed class is reported as skipped rather than failing when the
    binary at ``_DEFAULT_BIN`` has not been built.
    """
    binary = os.path.abspath(_DEFAULT_BIN)
    if os.path.isfile(binary):
        return
    raise unittest.SkipTest(f"OCaml binary not found at {binary}")
# Register primitives needed by compiler.sx.
# NOTE: these assignments mutate the PRIMITIVES table imported from
# shared.sx.ref.sx_ref at import time. They are therefore visible to, and may
# override existing entries for, every other user of that table in this process.
PRIMITIVES['serialize'] = lambda x: serialize(x)
# (primitive? name): True when *name* is a string naming a registered primitive.
PRIMITIVES['primitive?'] = lambda name: isinstance(name, str) and name in PRIMITIVES
# (has-key? d k): dict membership test; the key is coerced with str().
PRIMITIVES['has-key?'] = lambda *a: isinstance(a[0], dict) and str(a[1]) in a[0]
# (set-nth! lst i v): in-place item assignment; always evaluates to NIL.
PRIMITIVES['set-nth!'] = lambda *a: (a[0].__setitem__(int(a[1]), a[2]), NIL)[-1]
# (init lst): all elements but the last; non-list values are returned as-is.
PRIMITIVES['init'] = lambda *a: a[0][:-1] if isinstance(a[0], list) else a[0]
PRIMITIVES['make-symbol'] = lambda name: Symbol(name)
# (concat a b ...): concatenate any number of sequences into a new list.
# It is variadic, so calls with more (or fewer) than two arguments are not
# silently truncated. Falsy arguments contribute nothing. This covers empty
# lists, and nil on the assumption that NIL is falsy, which the previous
# `x or []` form also relied on.
PRIMITIVES['concat'] = lambda *a: [item for part in a if part for item in part]
# (slice seq start [end]): Python slicing with integer-coerced bounds.
PRIMITIVES['slice'] = lambda *a: a[0][int(a[1]):int(a[2])] if len(a) == 3 else a[0][int(a[1]):]
def _load_compiler():
    """Evaluate the SX bytecode spec and compiler into a fresh environment.

    Reads ``spec/bytecode.sx`` and then ``spec/compiler.sx`` from the project
    root. Every top-level form of each file is evaluated, in order, with the
    Python reference evaluator into one shared environment dict.

    Returns:
        The populated environment dict. This is the environment, not the
        compile function itself; ``_compile`` evaluates ``(compile ...)``
        inside it.

    Raises:
        OSError: if either spec file cannot be opened.
    """
    env = {}
    for f in ['spec/bytecode.sx', 'spec/compiler.sx']:
        path = os.path.join(_project_root, f)
        # Decode explicitly as UTF-8 so loading does not depend on the
        # platform's default locale encoding.
        with open(path, encoding='utf-8') as fh:
            source = fh.read()
        for expr in parse_all(source):
            trampoline(eval_expr(expr, env))
    return env
def _compile(env, src):
    """Compile the first SX form in *src* to a bytecode dict.

    *env* must be an environment returned by ``_load_compiler`` (it has to
    define ``compile``). Only the first top-level form of *src* is compiled;
    any further forms are ignored.
    """
    first_form = parse_all(src)[0]
    call = [Symbol('compile'), [Symbol('quote'), first_form]]
    return trampoline(eval_expr(call, env))
# Load compiler once for all tests: bytecode.sx + compiler.sx are evaluated a
# single time when this module is imported, and the resulting environment is
# shared by every test class below. Any error while loading them therefore
# surfaces at import (test-collection) time rather than inside a test.
_compiler_env = _load_compiler()
class TestCompilerOutput(unittest.TestCase):
    """Test that the compiler produces valid bytecode for various SX patterns.

    Only the Python-side compiler is used here (no OCaml bridge). The tests
    inspect the opcode stream and constant pool of the compiled result dict.
    """

    # Opcode values these tests look for. They must match the numbering the
    # compiler emits (presumably defined in spec/bytecode.sx).
    OP_POP = 5
    OP_LOCAL_GET = 16
    OP_LOCAL_SET = 17
    OP_UPVALUE_GET = 18
    OP_JUMP_IF_FALSE = 33
    OP_TAIL_CALL = 49
    OP_CLOSURE = 51
    OP_DEFINE = 128

    def _compile(self, src):
        """Compile *src* with the shared module-level compiler environment."""
        return _compile(_compiler_env, src)

    def _bytecode(self, src):
        """Compile *src* and return its top-level bytecode as a list."""
        return list(self._compile(src)['bytecode'])

    def _code_objects(self, result):
        """Return the nested code objects found in *result*'s constant pool.

        A code object is a constant that is a dict with a 'bytecode' key. The
        test fails with a clear message (instead of a later IndexError) when
        there are none.
        """
        code_objs = [c for c in result['constants']
                     if isinstance(c, dict) and 'bytecode' in c]
        self.assertTrue(code_objs, 'expected a nested code object in constants')
        return code_objs

    def test_arithmetic(self):
        result = self._compile('(+ 1 2)')
        self.assertIn('bytecode', result)
        self.assertIn('constants', result)
        self.assertGreater(len(list(result['bytecode'])), 0)

    def test_if_produces_jumps(self):
        self.assertIn(self.OP_JUMP_IF_FALSE, self._bytecode('(if true "a" "b")'))

    def test_let_uses_local_slots(self):
        bc = self._bytecode('(let ((x 1)) x)')
        self.assertIn(self.OP_LOCAL_SET, bc)
        self.assertIn(self.OP_LOCAL_GET, bc)

    def test_lambda_produces_closure(self):
        self.assertIn(self.OP_CLOSURE, self._bytecode('(fn (x) (+ x 1))'))

    def test_closure_captures_upvalue(self):
        result = self._compile('(let ((x 10)) (fn (y) (+ x y)))')
        # The outer code must build a closure ...
        self.assertIn(self.OP_CLOSURE, list(result['bytecode']))
        code = self._code_objects(result)[0]
        # ... whose code object records exactly one captured variable (x) ...
        self.assertEqual(code.get('upvalue-count', 0), 1)
        # ... and reads it through OP_UPVALUE_GET.
        self.assertIn(self.OP_UPVALUE_GET, list(code['bytecode']))

    def test_cond_compiles(self):
        self.assertGreater(len(self._bytecode('(cond (= x 1) "a" :else "b")')), 0)

    def test_case_compiles(self):
        self.assertGreater(len(self._bytecode('(case x 1 "one" :else "other")')), 0)

    def test_thread_first_compiles(self):
        self.assertGreater(len(self._bytecode('(-> x (+ 1))')), 0)

    def test_begin_compiles(self):
        # The value of every non-final `do` expression is discarded with OP_POP.
        self.assertIn(self.OP_POP, self._bytecode('(do (+ 1 2) (+ 3 4))'))

    def test_define_compiles(self):
        self.assertIn(self.OP_DEFINE, self._bytecode('(define x 42)'))

    def test_nested_let_shares_frame(self):
        """Nested lets should use incrementing slot numbers, not restart at 0."""
        bc = self._bytecode('(let ((a 1)) (let ((b 2)) (+ a b)))')
        # Operand byte following each OP_LOCAL_SET. NOTE: this is a plain byte
        # scan, so an operand that happened to equal OP_LOCAL_SET would be
        # misread as an opcode. That is acceptable for this tiny program.
        set_slots = [arg for op, arg in zip(bc, bc[1:]) if op == self.OP_LOCAL_SET]
        self.assertEqual(set_slots, [0, 1])

    def test_tail_call(self):
        """Calls in tail position should use OP_TAIL_CALL."""
        result = self._compile('(fn (x) (if (> x 0) (foo (- x 1)) 0))')
        inner_bc = list(self._code_objects(result)[0]['bytecode'])
        # The call to `foo` in the if's then-branch is in tail position.
        self.assertIn(self.OP_TAIL_CALL, inner_bc)
class TestVMExecution(unittest.IsolatedAsyncioTestCase):
    """Test that compiled bytecode executes correctly on the OCaml VM.

    Each test compiles SX with the Python-side compiler and sends the result
    to a fresh OcamlBridge (started per test) via ``vm-exec``. It then checks
    the returned text. The class is skipped when the OCaml binary is missing.
    """

    @classmethod
    def setUpClass(cls):
        # Skip before any class-level setup so there is nothing to tear down.
        _skip_if_no_binary()
        super().setUpClass()

    async def asyncSetUp(self):
        self.bridge = OcamlBridge()
        await self.bridge.start()

    async def asyncTearDown(self):
        await self.bridge.stop()

    async def _vm_eval(self, src):
        """Compile SX source and execute it on the VM; return the raw reply text."""
        compiled = _compile(_compiler_env, src)
        code_sx = serialize(compiled)
        # Send and read under the bridge lock, presumably so this request and
        # its reply are not interleaved with other bridge traffic.
        async with self.bridge._lock:
            await self.bridge._send(f'(vm-exec {code_sx})')
            return await self.bridge._read_until_ok(ctx=None)

    async def _cek_eval(self, src):
        """Evaluate SX source on the CEK machine; return the raw reply text."""
        async with self.bridge._lock:
            await self.bridge._send(f'(eval "{_escape_for_ocaml(src)}")')
            return await self.bridge._read_until_ok(ctx=None)

    async def test_arithmetic(self):
        result = await self._vm_eval('(+ 1 2)')
        self.assertEqual(result.strip(), '3')

    async def test_nested_arithmetic(self):
        result = await self._vm_eval('(+ (* 3 4) (- 10 5))')
        self.assertEqual(result.strip(), '17')

    async def test_if_true(self):
        result = await self._vm_eval('(if true "yes" "no")')
        self.assertIn('yes', result)

    async def test_if_false(self):
        result = await self._vm_eval('(if false "yes" "no")')
        self.assertIn('no', result)

    async def test_let_binding(self):
        result = await self._vm_eval('(let ((x 10) (y 20)) (+ x y))')
        self.assertEqual(result.strip(), '30')

    async def test_nested_let(self):
        result = await self._vm_eval('(let ((a 1)) (let ((b 2)) (+ a b)))')
        self.assertEqual(result.strip(), '3')

    async def test_closure_captures_variable(self):
        result = await self._vm_eval(
            '(let ((x 10)) (let ((f (fn (y) (+ x y)))) (f 5)))')
        self.assertEqual(result.strip(), '15')

    async def test_closure_nested_capture(self):
        result = await self._vm_eval(
            '(let ((a 1) (b 2)) (let ((f (fn (c) (+ a (+ b c))))) (f 3)))')
        self.assertEqual(result.strip(), '6')

    async def test_and_short_circuit(self):
        result = await self._vm_eval('(and false (error "should not reach"))')
        self.assertEqual(result.strip(), 'false')

    async def test_or_short_circuit(self):
        result = await self._vm_eval('(or 42 (error "should not reach"))')
        self.assertEqual(result.strip(), '42')

    async def test_when_true(self):
        result = await self._vm_eval('(when true "yes")')
        self.assertIn('yes', result)

    async def test_when_false(self):
        result = await self._vm_eval('(when false "yes")')
        self.assertIn('nil', result.lower())

    async def test_cond(self):
        result = await self._vm_eval(
            '(let ((x 2)) (cond (= x 1) "one" (= x 2) "two" :else "other"))')
        self.assertIn('two', result)

    async def test_string_primitives(self):
        result = await self._vm_eval('(str "hello" " " "world")')
        self.assertIn('hello world', result)

    async def test_list_construction(self):
        result = await self._vm_eval('(list 1 2 3)')
        self.assertIn('1', result)
        self.assertIn('2', result)
        self.assertIn('3', result)

    async def test_define_and_call(self):
        result = await self._vm_eval(
            '(do (define double (fn (x) (* x 2))) (double 21))')
        self.assertEqual(result.strip(), '42')

    async def test_higher_order_call(self):
        """A function that takes another function as argument."""
        result = await self._vm_eval(
            '(let ((apply-fn (fn (f x) (f x)))) (apply-fn (fn (n) (* n 3)) 7))')
        self.assertEqual(result.strip(), '21')

    async def test_vm_matches_cek(self):
        """VM result must match the CEK result, and the known value, for numeric expressions."""
        test_exprs = [
            ('(+ 1 2)', '3'),
            ('(* 3 (+ 4 5))', '27'),
            ('(let ((x 10)) (+ x 1))', '11'),
            ('(- 100 42)', '58'),
        ]
        for src, expected in test_exprs:
            # subTest reports every failing expression, not just the first.
            with self.subTest(src=src):
                vm_result = (await self._vm_eval(src)).strip()
                cek_result = (await self._cek_eval(src)).strip()
                self.assertEqual(
                    vm_result, expected,
                    f"VM wrong for {src}: got {vm_result}, expected {expected}")
                self.assertEqual(
                    vm_result, cek_result,
                    f"VM/CEK mismatch for {src}: vm={vm_result} cek={cek_result}")
class TestVMAutoCompile(unittest.IsolatedAsyncioTestCase):
    """Test patterns that auto-compile needs to handle.

    These represent the 111 functions that currently fail.
    """

    @classmethod
    def setUpClass(cls):
        # Skip before any class-level setup so there is nothing to tear down.
        _skip_if_no_binary()
        super().setUpClass()

    async def asyncSetUp(self):
        self.bridge = OcamlBridge()
        await self.bridge.start()

    async def asyncTearDown(self):
        await self.bridge.stop()

    async def _vm_eval(self, src):
        """Compile SX source with the Python compiler, run it on the VM, return the reply text."""
        compiled = _compile(_compiler_env, src)
        code_sx = serialize(compiled)
        async with self.bridge._lock:
            await self.bridge._send(f'(vm-exec {code_sx})')
            return await self.bridge._read_until_ok(ctx=None)

    async def test_for_each_via_primitive(self):
        """for-each should work as a primitive call."""
        result = await self._vm_eval(
            '(let ((sum 0)) (for-each (fn (x) (set! sum (+ sum x))) (list 1 2 3)) sum)')
        self.assertEqual(result.strip(), '6')

    async def test_map_via_primitive(self):
        """map should work as a primitive call."""
        result = await self._vm_eval(
            '(map (fn (x) (* x 2)) (list 1 2 3))')
        self.assertIn('2', result)
        self.assertIn('4', result)
        self.assertIn('6', result)

    async def test_filter_via_primitive(self):
        """filter should work as a primitive call and drop rejected elements."""
        result = await self._vm_eval(
            '(filter (fn (x) (> x 2)) (list 1 2 3 4 5))')
        self.assertIn('3', result)
        self.assertIn('4', result)
        self.assertIn('5', result)
        # Without these checks, an unfiltered (1 2 3 4 5) would also pass.
        self.assertNotIn('1', result)
        self.assertNotIn('2', result)

    async def test_closure_over_mutable(self):
        """Closure capturing a set! target must share the mutation."""
        result = await self._vm_eval(
            '(let ((count 0)) (let ((inc (fn () (set! count (+ count 1))))) (inc) (inc) (inc) count))')
        self.assertEqual(result.strip(), '3')

    async def test_recursive_function(self):
        """Recursive function via define."""
        result = await self._vm_eval(
            '(do (define fact (fn (n) (if (<= n 1) 1 (* n (fact (- n 1)))))) (fact 5))')
        self.assertEqual(result.strip(), '120')

    async def test_string_building(self):
        """String concatenation — hot path for aser."""
        result = await self._vm_eval(
            '(str "(" "div" " " ":class" ")")')
        self.assertIn('div', result)
        self.assertIn(':class', result)

    async def test_type_dispatch(self):
        """type-of dispatch — used heavily by aser."""
        result = await self._vm_eval(
            '(cond (= (type-of 42) "number") "num" (= (type-of "x") "string") "str" :else "other")')
        self.assertIn('num', result)

    async def test_type_of_number(self):
        """type-of on a number — foundation for aser."""
        result = await self._vm_eval('(type-of 42)')
        self.assertIn('number', result)

    async def test_empty_list_check(self):
        result = await self._vm_eval('(empty? (list))')
        self.assertEqual(result.strip(), 'true')

    async def test_multiple_closures_same_scope(self):
        """Multiple closures capturing from the same let."""
        result = await self._vm_eval('''
            (let ((base 100))
              (let ((add (fn (x) (+ base x)))
                    (sub (fn (x) (- base x))))
                (+ (add 10) (sub 10))))''')
        self.assertEqual(result.strip(), '200')
def _escape_for_ocaml(s):
    """Return *s* escaped for embedding in a double-quoted OCaml SX string.

    Backslashes, double quotes and newlines are replaced by their backslash
    escape sequences. All other characters are passed through unchanged.
    """
    # A single translate pass maps each character independently, so the
    # backslashes introduced for quotes/newlines are never re-escaped.
    escapes = str.maketrans({'\\': '\\\\', '"': '\\"', '\n': '\\n'})
    return s.translate(escapes)
if __name__ == "__main__":
    # Allow running the module directly (python test_vm_compile.py) as well
    # as via the pytest invocation given in the module docstring.
    unittest.main()