370 lines
14 KiB
Python
370 lines
14 KiB
Python
"""Tests for the SX bytecode compiler + VM execution.
|
|
|
|
Compiles SX expressions with compiler.sx (on the Python side), executes
them on the OCaml VM via the bridge, and verifies that the results match
CEK evaluation.

Usage:
    pytest shared/sx/tests/test_vm_compile.py -v
|
|
"""
|
|
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
import time
|
|
import unittest
|
|
|
|
# Put the directory three levels above this file at the front of sys.path
# (unless it is already there) so the `shared.sx` imports below resolve.
_project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir)
)
if _project_root not in sys.path:
    sys.path.insert(0, _project_root)
|
|
|
|
from shared.sx.parser import parse_all, serialize
|
|
from shared.sx.ref.sx_ref import eval_expr, trampoline, PRIMITIVES
|
|
from shared.sx.types import Symbol, Keyword, NIL
|
|
from shared.sx.ocaml_bridge import OcamlBridge, OcamlBridgeError, _DEFAULT_BIN
|
|
|
|
|
|
def _skip_if_no_binary():
    """Skip the current test (or test class) when the OCaml binary is missing.

    Raises:
        unittest.SkipTest: if ``_DEFAULT_BIN`` does not resolve to an
            existing file.
    """
    resolved = os.path.abspath(_DEFAULT_BIN)
    if os.path.isfile(resolved):
        return
    raise unittest.SkipTest(f"OCaml binary not found at {resolved}")
|
|
|
|
|
|
def _sx_concat(*parts):
    """Concatenate any number of sequences, treating falsy arguments as empty.

    With two arguments this is exactly ``(a or []) + (b or [])``.
    Additional arguments are appended in order rather than ignored.
    Zero arguments yield ``[]``.
    """
    pieces = [p or [] for p in parts]
    if not pieces:
        return []
    result = pieces[0]
    for piece in pieces[1:]:
        result = result + piece
    return result


# Register primitives needed by compiler.sx. This mutates the shared
# PRIMITIVES table at import time.
PRIMITIVES['serialize'] = lambda x: serialize(x)
# True only for string names that are registered in PRIMITIVES.
PRIMITIVES['primitive?'] = lambda name: isinstance(name, str) and name in PRIMITIVES
# (has-key? d k): k is looked up by its str() form; non-dicts give False.
PRIMITIVES['has-key?'] = lambda *a: isinstance(a[0], dict) and str(a[1]) in a[0]
# (set-nth! xs i v): assigns in place, then returns NIL.
PRIMITIVES['set-nth!'] = lambda *a: (a[0].__setitem__(int(a[1]), a[2]), NIL)[-1]
# (init xs): all but the last element of a list; non-lists pass through.
PRIMITIVES['init'] = lambda *a: a[0][:-1] if isinstance(a[0], list) else a[0]
PRIMITIVES['make-symbol'] = lambda name: Symbol(name)
# (concat a b ...): variadic list concatenation.
PRIMITIVES['concat'] = _sx_concat
# (slice xs start) or (slice xs start end).
PRIMITIVES['slice'] = lambda *a: a[0][int(a[1]):int(a[2])] if len(a) == 3 else a[0][int(a[1]):]
|
|
|
|
|
|
def _load_compiler():
    """Load the bytecode and compiler specs into a fresh Python-side env.

    Every top-level form of ``spec/bytecode.sx`` is evaluated with the
    reference evaluator, followed by every form of ``spec/compiler.sx``.
    Both paths are relative to the project root.

    Returns:
        dict: the environment populated by those definitions. Callers
        evaluate ``(compile ...)`` against it.
    """
    env = {}
    for rel_path in ('spec/bytecode.sx', 'spec/compiler.sx'):
        path = os.path.join(_project_root, rel_path)
        # Decode explicitly as UTF-8 so loading does not depend on the
        # platform's locale encoding.
        with open(path, encoding='utf-8') as fh:
            source = fh.read()
        for expr in parse_all(source):
            trampoline(eval_expr(expr, env))
    return env
|
|
|
|
|
|
def _compile(env, src):
    """Compile the first form of SX source ``src`` to a bytecode dict.

    The form is passed, quoted, to the ``compile`` function defined in
    ``env``, and the evaluation result is trampolined to completion.
    """
    first_form = parse_all(src)[0]
    call = [Symbol('compile'), [Symbol('quote'), first_form]]
    return trampoline(eval_expr(call, env))


# Build the compiler environment once at import time; all tests share it.
_compiler_env = _load_compiler()
|
|
|
|
|
|
class TestCompilerOutput(unittest.TestCase):
    """Test that the compiler produces valid bytecode for various SX patterns.

    NOTE: opcode assertions scan the raw bytecode, so a membership check
    would also match an operand byte that happens to equal the opcode value.
    """

    # Opcode byte values asserted by the tests below.
    OP_POP = 5
    OP_LOCAL_GET = 16
    OP_LOCAL_SET = 17
    OP_UPVALUE_GET = 18
    OP_JUMP_IF_FALSE = 33
    OP_TAIL_CALL = 49
    OP_CLOSURE = 51
    OP_DEFINE = 128

    def _compile(self, src):
        """Compile ``src`` using the shared module-level compiler env."""
        return _compile(_compiler_env, src)

    @staticmethod
    def _bytecode(code):
        """Return the bytecode of a compiled code dict as a list."""
        return list(code['bytecode'])

    @staticmethod
    def _code_objects(code):
        """Return nested code objects (dicts with 'bytecode') from the constants."""
        return [c for c in code['constants']
                if isinstance(c, dict) and 'bytecode' in c]

    def test_arithmetic(self):
        result = self._compile('(+ 1 2)')
        self.assertIn('bytecode', result)
        self.assertIn('constants', result)
        self.assertGreater(len(self._bytecode(result)), 0)

    def test_if_produces_jumps(self):
        bc = self._bytecode(self._compile('(if true "a" "b")'))
        self.assertIn(self.OP_JUMP_IF_FALSE, bc)

    def test_let_uses_local_slots(self):
        bc = self._bytecode(self._compile('(let ((x 1)) x)'))
        self.assertIn(self.OP_LOCAL_SET, bc)
        self.assertIn(self.OP_LOCAL_GET, bc)

    def test_lambda_produces_closure(self):
        bc = self._bytecode(self._compile('(fn (x) (+ x 1))'))
        self.assertIn(self.OP_CLOSURE, bc)

    def test_closure_captures_upvalue(self):
        result = self._compile('(let ((x 10)) (fn (y) (+ x y)))')
        # Should have OP_CLOSURE with upvalue descriptors.
        self.assertIn(self.OP_CLOSURE, self._bytecode(result))
        code_objs = self._code_objects(result)
        self.assertGreater(len(code_objs), 0)
        code = code_objs[0]
        # The inner fn captures exactly one outer variable (x).
        self.assertEqual(code.get('upvalue-count', 0), 1)
        self.assertIn(self.OP_UPVALUE_GET, self._bytecode(code))

    def test_cond_compiles(self):
        result = self._compile('(cond (= x 1) "a" :else "b")')
        self.assertGreater(len(self._bytecode(result)), 0)

    def test_case_compiles(self):
        result = self._compile('(case x 1 "one" :else "other")')
        self.assertGreater(len(self._bytecode(result)), 0)

    def test_thread_first_compiles(self):
        result = self._compile('(-> x (+ 1))')
        self.assertGreater(len(self._bytecode(result)), 0)

    def test_begin_compiles(self):
        bc = self._bytecode(self._compile('(do (+ 1 2) (+ 3 4))'))
        # The first expression's value is discarded with OP_POP.
        self.assertIn(self.OP_POP, bc)

    def test_define_compiles(self):
        bc = self._bytecode(self._compile('(define x 42)'))
        self.assertIn(self.OP_DEFINE, bc)

    def test_nested_let_shares_frame(self):
        """Nested lets should use incrementing slot numbers, not restart at 0."""
        bc = self._bytecode(
            self._compile('(let ((a 1)) (let ((b 2)) (+ a b)))'))
        # The byte after each OP_LOCAL_SET is its slot operand.
        set_slots = [bc[i + 1] for i, op in enumerate(bc[:-1])
                     if op == self.OP_LOCAL_SET]
        self.assertEqual(set_slots, [0, 1])

    def test_tail_call(self):
        """Calls in tail position should use OP_TAIL_CALL."""
        result = self._compile('(fn (x) (if (> x 0) (foo (- x 1)) 0))')
        code_objs = self._code_objects(result)
        # Fail cleanly (not with IndexError) if no nested code object exists.
        self.assertGreater(len(code_objs), 0,
                           "expected the fn body as a nested code object")
        self.assertIn(self.OP_TAIL_CALL, self._bytecode(code_objs[0]))
|
|
|
|
|
|
class _OcamlBridgeTestCase(unittest.IsolatedAsyncioTestCase):
    """Shared fixture: a fresh OCaml bridge per test plus evaluation helpers.

    Subclasses are skipped when the OCaml binary is not available.
    """

    @classmethod
    def setUpClass(cls):
        _skip_if_no_binary()
        super().setUpClass()

    async def asyncSetUp(self):
        self.bridge = OcamlBridge()
        await self.bridge.start()

    async def asyncTearDown(self):
        await self.bridge.stop()

    async def _vm_eval(self, src):
        """Compile SX source and execute it on the VM; return the raw result."""
        compiled = _compile(_compiler_env, src)
        code_sx = serialize(compiled)
        # Send and read back under the bridge lock as one request/response pair.
        async with self.bridge._lock:
            await self.bridge._send(f'(vm-exec {code_sx})')
            return await self.bridge._read_until_ok(ctx=None)

    async def _cek_eval(self, src):
        """Evaluate SX source on the CEK machine; return the raw result."""
        async with self.bridge._lock:
            await self.bridge._send(f'(eval "{_escape_for_ocaml(src)}")')
            return await self.bridge._read_until_ok(ctx=None)


class TestVMExecution(_OcamlBridgeTestCase):
    """Test that compiled bytecode executes correctly on the OCaml VM."""

    async def test_arithmetic(self):
        result = await self._vm_eval('(+ 1 2)')
        self.assertEqual(result.strip(), '3')

    async def test_nested_arithmetic(self):
        result = await self._vm_eval('(+ (* 3 4) (- 10 5))')
        self.assertEqual(result.strip(), '17')

    async def test_if_true(self):
        result = await self._vm_eval('(if true "yes" "no")')
        self.assertIn('yes', result)

    async def test_if_false(self):
        result = await self._vm_eval('(if false "yes" "no")')
        self.assertIn('no', result)

    async def test_let_binding(self):
        result = await self._vm_eval('(let ((x 10) (y 20)) (+ x y))')
        self.assertEqual(result.strip(), '30')

    async def test_nested_let(self):
        result = await self._vm_eval('(let ((a 1)) (let ((b 2)) (+ a b)))')
        self.assertEqual(result.strip(), '3')

    async def test_closure_captures_variable(self):
        result = await self._vm_eval(
            '(let ((x 10)) (let ((f (fn (y) (+ x y)))) (f 5)))')
        self.assertEqual(result.strip(), '15')

    async def test_closure_nested_capture(self):
        result = await self._vm_eval(
            '(let ((a 1) (b 2)) (let ((f (fn (c) (+ a (+ b c))))) (f 3)))')
        self.assertEqual(result.strip(), '6')

    async def test_and_short_circuit(self):
        result = await self._vm_eval('(and false (error "should not reach"))')
        self.assertEqual(result.strip(), 'false')

    async def test_or_short_circuit(self):
        result = await self._vm_eval('(or 42 (error "should not reach"))')
        self.assertEqual(result.strip(), '42')

    async def test_when_true(self):
        result = await self._vm_eval('(when true "yes")')
        self.assertIn('yes', result)

    async def test_when_false(self):
        result = await self._vm_eval('(when false "yes")')
        self.assertIn('nil', result.lower())

    async def test_cond(self):
        result = await self._vm_eval(
            '(let ((x 2)) (cond (= x 1) "one" (= x 2) "two" :else "other"))')
        self.assertIn('two', result)

    async def test_string_primitives(self):
        result = await self._vm_eval('(str "hello" " " "world")')
        self.assertIn('hello world', result)

    async def test_list_construction(self):
        result = await self._vm_eval('(list 1 2 3)')
        self.assertIn('1', result)
        self.assertIn('2', result)
        self.assertIn('3', result)

    async def test_define_and_call(self):
        result = await self._vm_eval(
            '(do (define double (fn (x) (* x 2))) (double 21))')
        self.assertEqual(result.strip(), '42')

    async def test_higher_order_call(self):
        """A function that takes another function as argument."""
        result = await self._vm_eval(
            '(let ((apply-fn (fn (f x) (f x)))) (apply-fn (fn (n) (* n 3)) 7))')
        self.assertEqual(result.strip(), '21')

    async def test_vm_matches_cek(self):
        """VM result must match CEK result for numeric expressions."""
        test_exprs = [
            ('(+ 1 2)', '3'),
            ('(* 3 (+ 4 5))', '27'),
            ('(let ((x 10)) (+ x 1))', '11'),
            ('(- 100 42)', '58'),
        ]
        for src, expected in test_exprs:
            # subTest reports every failing expression, not just the first.
            with self.subTest(src=src):
                vm_result = await self._vm_eval(src)
                self.assertEqual(vm_result.strip(), expected,
                                 f"VM wrong for {src}: got {vm_result}, expected {expected}")
                # Cross-check against the CEK evaluator, as the docstring promises.
                cek_result = await self._cek_eval(src)
                self.assertEqual(vm_result.strip(), cek_result.strip(),
                                 f"VM/CEK mismatch for {src}: vm={vm_result} cek={cek_result}")


class TestVMAutoCompile(_OcamlBridgeTestCase):
    """Test patterns that auto-compile needs to handle.

    These represent the 111 functions that currently fail."""

    async def test_for_each_via_primitive(self):
        """for-each should work as a primitive call."""
        result = await self._vm_eval(
            '(let ((sum 0)) (for-each (fn (x) (set! sum (+ sum x))) (list 1 2 3)) sum)')
        self.assertEqual(result.strip(), '6')

    async def test_map_via_primitive(self):
        """map should work as a primitive call."""
        result = await self._vm_eval(
            '(map (fn (x) (* x 2)) (list 1 2 3))')
        self.assertIn('2', result)
        self.assertIn('4', result)
        self.assertIn('6', result)

    async def test_filter_via_primitive(self):
        """filter should work as a primitive call."""
        result = await self._vm_eval(
            '(filter (fn (x) (> x 2)) (list 1 2 3 4 5))')
        self.assertIn('3', result)
        self.assertIn('4', result)
        self.assertIn('5', result)

    async def test_closure_over_mutable(self):
        """Closure capturing a set! target must share the mutation."""
        result = await self._vm_eval(
            '(let ((count 0)) (let ((inc (fn () (set! count (+ count 1))))) (inc) (inc) (inc) count))')
        self.assertEqual(result.strip(), '3')

    async def test_recursive_function(self):
        """Recursive function via define."""
        result = await self._vm_eval(
            '(do (define fact (fn (n) (if (<= n 1) 1 (* n (fact (- n 1)))))) (fact 5))')
        self.assertEqual(result.strip(), '120')

    async def test_string_building(self):
        """String concatenation — hot path for aser."""
        result = await self._vm_eval(
            '(str "(" "div" " " ":class" ")")')
        self.assertIn('div', result)
        self.assertIn(':class', result)

    async def test_type_dispatch(self):
        """type-of dispatch — used heavily by aser."""
        result = await self._vm_eval(
            '(cond (= (type-of 42) "number") "num" (= (type-of "x") "string") "str" :else "other")')
        self.assertIn('num', result)

    async def test_type_of_number(self):
        """type-of on a number — foundation for aser."""
        result = await self._vm_eval('(type-of 42)')
        self.assertIn('number', result)

    async def test_empty_list_check(self):
        result = await self._vm_eval('(empty? (list))')
        self.assertEqual(result.strip(), 'true')

    async def test_multiple_closures_same_scope(self):
        """Multiple closures capturing from the same let."""
        result = await self._vm_eval('''
            (let ((base 100))
              (let ((add (fn (x) (+ base x)))
                    (sub (fn (x) (- base x))))
                (+ (add 10) (sub 10))))''')
        self.assertEqual(result.strip(), '200')
|
|
|
|
|
|
def _escape_for_ocaml(s):
    """Return ``s`` escaped for use inside a double-quoted SX string literal.

    Backslashes, double quotes and newlines become backslash escapes. This
    lets arbitrary source be embedded in a command such as ``(eval "...")``
    sent to the OCaml bridge.
    """
    # Single pass: each character is mapped once, so a backslash introduced
    # for a quote is never itself re-escaped.
    table = str.maketrans({'\\': '\\\\', '"': '\\"', '\n': '\\n'})
    return s.translate(table)


if __name__ == "__main__":
    # Allow running this file directly with the stdlib test runner.
    unittest.main()
|