sx: step 14 — inline JIT primitives (-69% fib, -62% loop, -50% sum on bench_vm)

The bytecode compiler emitted OP_CALL_PRIM (52) for every primitive call, even
for arithmetic and comparison hot-paths. The VM had specialized opcodes
(OP_ADD, OP_SUB, OP_EQ, etc.) defined but unused.

- lib/compiler.sx (compile-call): emit specialized 1-byte opcode when the
  primitive name + arity matches one of {+, -, *, /, =, <, >, cons, not, len,
  first, rest}. Falls back to CALL_PRIM otherwise. fib bytecode: 50 → 38 bytes.
- hosts/ocaml/lib/sx_compiler.ml: mirror change in the auto-generated OCaml
  compiler so SXBC export from mcp_tree uses the same emission.
- hosts/ocaml/lib/sx_vm.ml: extend OP_ADD/SUB/MUL/DIV to handle Integer+Integer
  (not just Number+Number). Inline OP_EQ via Sx_runtime._fast_eq. Inline
  OP_LT/GT mixed-numeric comparisons. Avoids Hashtbl lookup on the fallback
  path for the common integer cases that dominate tight loops.
- hosts/ocaml/bin/bench_vm.ml: VM-only benchmark — loads compiler.sx via CEK,
  JIT-compiles each fn, measures Sx_vm.call_closure throughput.

Median improvements (best of 3 runs of 9-min, bench_vm.exe):
  fib(22)         107.87ms →  33.13ms   -69%
  loop(200000)    429.64ms → 161.16ms   -62%
  sum-to(50000)    72.85ms →  36.74ms   -50%
  count-lt(20000)  28.44ms →  17.58ms   -38%
  count-eq(20000)  37.23ms →  15.46ms   -58%

Tests: 4550/4550 OCaml passing (unchanged). Zero regressions.

Last step in the sx-improvements roadmap — all 14 steps complete.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-07 02:38:47 +00:00
parent 4cb5302232
commit 6c171d4906
6 changed files with 262 additions and 10 deletions

155
hosts/ocaml/bin/bench_vm.ml Normal file
View File

@@ -0,0 +1,155 @@
(** VM bytecode benchmark — measures throughput of the VM (compiled bytecode).
Loads the SX compiler via CEK, then for each test:
1. Define the function via CEK (as a Lambda).
2. Trigger JIT compilation via Sx_vm.jit_compile_lambda.
3. Call the compiled VmClosure repeatedly via Sx_vm.call_closure.
This measures pure VM execution time on the JIT path. *)
open Sx_types
let load_compiler env globals =
let compiler_path =
if Sys.file_exists "lib/compiler.sx" then "lib/compiler.sx"
else if Sys.file_exists "../../lib/compiler.sx" then "../../lib/compiler.sx"
else if Sys.file_exists "../../../lib/compiler.sx" then "../../../lib/compiler.sx"
else failwith "compiler.sx not found"
in
let ic = open_in compiler_path in
let src = really_input_string ic (in_channel_length ic) in
close_in ic;
let exprs = Sx_parser.parse_all src in
List.iter (fun e -> ignore (Sx_ref.eval_expr e (Env env))) exprs;
let rec sync e =
Hashtbl.iter (fun id v ->
let name = Sx_types.unintern id in
Hashtbl.replace globals name v) e.bindings;
match e.parent with Some p -> sync p | None -> ()
in
sync env
let _make_globals env =
let g = Hashtbl.create 512 in
Hashtbl.iter (fun name fn ->
Hashtbl.replace g name (NativeFn (name, fn))
) Sx_primitives.primitives;
let rec sync e =
Hashtbl.iter (fun id v ->
let name = Sx_types.unintern id in
if not (Hashtbl.mem g name) then Hashtbl.replace g name v) e.bindings;
match e.parent with Some p -> sync p | None -> ()
in
sync env;
g
let define_fn env globals name params body_src =
(* Define via CEK so we get a Lambda value with proper closure. *)
let body_expr = match Sx_parser.parse_all body_src with
| [e] -> e
| _ -> failwith "expected one body expression"
in
let param_syms = List (List.map (fun p -> Symbol p) params) in
let define_expr = List [Symbol "define"; Symbol name; List [Symbol "fn"; param_syms; body_expr]] in
ignore (Sx_ref.eval_expr define_expr (Env env));
(* Sync env to globals so JIT can resolve free vars. *)
let rec sync e =
Hashtbl.iter (fun id v ->
let n = Sx_types.unintern id in
Hashtbl.replace globals n v) e.bindings;
match e.parent with Some p -> sync p | None -> ()
in
sync env;
(* Now find the Lambda and JIT-compile it. *)
let lam_val = Hashtbl.find globals name in
match lam_val with
| Lambda l ->
(match Sx_vm.jit_compile_lambda l globals with
| Some cl ->
l.l_compiled <- Some cl;
Hashtbl.replace globals name (NativeFn (name, fun args ->
Sx_vm.call_closure cl args globals));
cl
| None ->
failwith (Printf.sprintf "JIT failed for %s" name))
| _ -> failwith (Printf.sprintf "%s is not a Lambda after define" name)
let bench_call name cl globals args iters =
let times = ref [] in
for _ = 1 to iters do
Gc.full_major ();
let t0 = Unix.gettimeofday () in
let _r = Sx_vm.call_closure cl args globals in
let t1 = Unix.gettimeofday () in
times := (t1 -. t0) :: !times
done;
let sorted = List.sort compare !times in
let median = List.nth sorted (iters / 2) in
let min_t = List.nth sorted 0 in
let max_t = List.nth sorted (iters - 1) in
Printf.printf " %-22s min=%8.2fms median=%8.2fms max=%8.2fms\n%!"
name (min_t *. 1000.0) (median *. 1000.0) (max_t *. 1000.0);
median
let () =
let iters =
if Array.length Sys.argv > 1
then int_of_string Sys.argv.(1)
else 7
in
Printf.printf "VM (bytecode/JIT) benchmark (%d runs each, taking median)\n%!" iters;
Printf.printf "========================================================\n%!";
let env = Sx_types.make_env () in
let bind n fn = ignore (Sx_types.env_bind env n (NativeFn (n, fn))) in
(* Seed env with primitives as NativeFn so CEK lookups work. *)
Hashtbl.iter (fun name fn ->
Hashtbl.replace env.bindings (Sx_types.intern name) (NativeFn (name, fn))
) Sx_primitives.primitives;
(* Helpers the SX compiler relies on but aren't kernel primitives. *)
bind "symbol-name" (fun args -> match args with
| [Symbol s] -> String s | _ -> raise (Eval_error "symbol-name"));
bind "keyword-name" (fun args -> match args with
| [Keyword k] -> String k | _ -> raise (Eval_error "keyword-name"));
bind "make-symbol" (fun args -> match args with
| [String s] -> Symbol s
| [v] -> Symbol (Sx_types.value_to_string v)
| _ -> raise (Eval_error "make-symbol"));
bind "sx-serialize" (fun args -> match args with
| [v] -> String (Sx_types.inspect v)
| _ -> raise (Eval_error "sx-serialize"));
let globals = Hashtbl.create 1024 in
Hashtbl.iter (fun name fn ->
Hashtbl.replace globals name (NativeFn (name, fn))
) Sx_primitives.primitives;
Printf.printf "Loading compiler.sx ... %!";
let t0 = Unix.gettimeofday () in
load_compiler env globals;
Printf.printf "%.0fms\n%!" ((Unix.gettimeofday () -. t0) *. 1000.0);
(* fib(22) — recursive call benchmark *)
let fib_cl = define_fn env globals "fib" ["n"]
"(if (< n 2) n (+ (fib (- n 1)) (fib (- n 2))))" in
let _ = bench_call "fib(22)" fib_cl globals [Number 22.0] iters in
(* tight loop *)
let loop_cl = define_fn env globals "loop" ["n"; "acc"]
"(if (= n 0) acc (loop (- n 1) (+ acc 1)))" in
let _ = bench_call "loop(200000)" loop_cl globals [Number 200000.0; Number 0.0] iters in
(* sum-to *)
let sum_cl = define_fn env globals "sum_to" ["n"; "acc"]
"(if (= n 0) acc (sum_to (- n 1) (+ acc n)))" in
let _ = bench_call "sum-to(50000)" sum_cl globals [Number 50000.0; Number 0.0] iters in
(* count-lt: comparison-heavy *)
let cnt_cl = define_fn env globals "count_lt" ["n"; "acc"]
"(if (= n 0) acc (count_lt (- n 1) (if (< n 10000) (+ acc 1) acc)))" in
let _ = bench_call "count-lt(20000)" cnt_cl globals [Number 20000.0; Number 0.0] iters in
(* count-eq: equality-heavy on multiples of 7 *)
let eq_cl = define_fn env globals "count_eq" ["n"; "acc"]
"(if (= n 0) acc (count_eq (- n 1) (if (= 0 (- n (* 7 (/ n 7)))) (+ acc 1) acc)))" in
let _ = bench_call "count-eq(20000)" eq_cl globals [Number 20000.0; Number 0.0] iters in
Printf.printf "\nDone.\n%!"

View File

@@ -1,5 +1,5 @@
(executables
(names run_tests debug_set sx_server integration_tests bench_cek bench_inspect)
(names run_tests debug_set sx_server integration_tests bench_cek bench_inspect bench_vm)
(libraries sx unix threads.posix otfm yojson))
(executable

View File

@@ -200,7 +200,30 @@ and compile_qq_list em items scope =
(* compile-call *)
and compile_call em head args scope tail_p =
(let is_prim = (let _and = (prim_call "=" [(type_of (head)); (String "symbol")]) in if not (sx_truthy _and) then _and else (let name = (symbol_name (head)) in (let _and = (Bool (not (sx_truthy ((prim_call "=" [(get ((scope_resolve (scope) (name))) ((String "type"))); (String "local")]))))) in if not (sx_truthy _and) then _and else (let _and = (Bool (not (sx_truthy ((prim_call "=" [(get ((scope_resolve (scope) (name))) ((String "type"))); (String "upvalue")]))))) in if not (sx_truthy _and) then _and else (is_primitive (name)))))) in (if sx_truthy (is_prim) then (let name = (symbol_name (head)) in let argc = (len (args)) in let name_idx = (pool_add ((get (em) ((String "pool")))) (name)) in (let () = ignore ((List.iter (fun a -> ignore ((compile_expr (em) (a) (scope) ((Bool false))))) (sx_to_list args); Nil)) in (let () = ignore ((emit_op (em) ((Number 52.0)))) in (let () = ignore ((emit_u16 (em) (name_idx))) in (emit_byte (em) (argc)))))) else (let () = ignore ((compile_expr (em) (head) (scope) ((Bool false)))) in (let () = ignore ((List.iter (fun a -> ignore ((compile_expr (em) (a) (scope) ((Bool false))))) (sx_to_list args); Nil)) in (if sx_truthy (tail_p) then (let () = ignore ((emit_op (em) ((Number 49.0)))) in (emit_byte (em) ((len (args))))) else (let () = ignore ((emit_op (em) ((Number 48.0)))) in (emit_byte (em) ((len (args))))))))))
(let is_prim = (let _and = (prim_call "=" [(type_of (head)); (String "symbol")]) in if not (sx_truthy _and) then _and else (let name = (symbol_name (head)) in (let _and = (Bool (not (sx_truthy ((prim_call "=" [(get ((scope_resolve (scope) (name))) ((String "type"))); (String "local")]))))) in if not (sx_truthy _and) then _and else (let _and = (Bool (not (sx_truthy ((prim_call "=" [(get ((scope_resolve (scope) (name))) ((String "type"))); (String "upvalue")]))))) in if not (sx_truthy _and) then _and else (is_primitive (name)))))) in (if sx_truthy (is_prim) then (let name = (symbol_name (head)) in let argc = (len (args)) in
(* Specialized opcode for hot 2-arg / 1-arg primitives. *)
let specialized_op = (match name, argc with
| String "+", Number 2.0 -> Some 160
| String "-", Number 2.0 -> Some 161
| String "*", Number 2.0 -> Some 162
| String "/", Number 2.0 -> Some 163
| String "=", Number 2.0 -> Some 164
| String "<", Number 2.0 -> Some 165
| String ">", Number 2.0 -> Some 166
| String "cons", Number 2.0 -> Some 172
| String "not", Number 1.0 -> Some 167
| String "len", Number 1.0 -> Some 168
| String "first", Number 1.0 -> Some 169
| String "rest", Number 1.0 -> Some 170
| _ -> None) in
(let () = ignore ((List.iter (fun a -> ignore ((compile_expr (em) (a) (scope) ((Bool false))))) (sx_to_list args); Nil)) in
(match specialized_op with
| Some op -> emit_op em (Number (float_of_int op))
| None ->
let name_idx = (pool_add ((get (em) ((String "pool")))) (name)) in
let () = ignore ((emit_op (em) ((Number 52.0)))) in
let () = ignore ((emit_u16 (em) (name_idx))) in
emit_byte (em) (argc)))) else (let () = ignore ((compile_expr (em) (head) (scope) ((Bool false)))) in (let () = ignore ((List.iter (fun a -> ignore ((compile_expr (em) (a) (scope) ((Bool false))))) (sx_to_list args); Nil)) in (if sx_truthy (tail_p) then (let () = ignore ((emit_op (em) ((Number 49.0)))) in (emit_byte (em) ((len (args))))) else (let () = ignore ((emit_op (em) ((Number 48.0)))) in (emit_byte (em) ((len (args))))))))))
(* compile *)
and compile expr =

View File

@@ -742,38 +742,57 @@ and run vm =
| 160 (* OP_ADD *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y -> Integer (x + y)
| Number x, Number y -> Number (x +. y)
| Integer x, Number y -> Number (float_of_int x +. y)
| Number x, Integer y -> Number (x +. float_of_int y)
| _ -> (Hashtbl.find Sx_primitives.primitives "+") [a; b])
| 161 (* OP_SUB *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y -> Integer (x - y)
| Number x, Number y -> Number (x -. y)
| Integer x, Number y -> Number (float_of_int x -. y)
| Number x, Integer y -> Number (x -. float_of_int y)
| _ -> (Hashtbl.find Sx_primitives.primitives "-") [a; b])
| 162 (* OP_MUL *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y -> Integer (x * y)
| Number x, Number y -> Number (x *. y)
| Integer x, Number y -> Number (float_of_int x *. y)
| Number x, Integer y -> Number (x *. float_of_int y)
| _ -> (Hashtbl.find Sx_primitives.primitives "*") [a; b])
| 163 (* OP_DIV *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y when y <> 0 && x mod y = 0 -> Integer (x / y)
| Integer x, Integer y -> Number (float_of_int x /. float_of_int y)
| Number x, Number y -> Number (x /. y)
| Integer x, Number y -> Number (float_of_int x /. y)
| Number x, Integer y -> Number (x /. float_of_int y)
| _ -> (Hashtbl.find Sx_primitives.primitives "/") [a; b])
| 164 (* OP_EQ *) ->
let b = pop vm and a = pop vm in
push vm ((Hashtbl.find Sx_primitives.primitives "=") [a; b])
push vm (Bool (Sx_runtime._fast_eq a b))
| 165 (* OP_LT *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y -> Bool (x < y)
| Number x, Number y -> Bool (x < y)
| Integer x, Number y -> Bool (float_of_int x < y)
| Number x, Integer y -> Bool (x < float_of_int y)
| String x, String y -> Bool (x < y)
| _ -> (Hashtbl.find Sx_primitives.primitives "<") [a; b])
| _ -> Sx_runtime.prim_call "<" [a; b])
| 166 (* OP_GT *) ->
let b = pop vm and a = pop vm in
push vm (match a, b with
| Integer x, Integer y -> Bool (x > y)
| Number x, Number y -> Bool (x > y)
| Integer x, Number y -> Bool (float_of_int x > y)
| Number x, Integer y -> Bool (x > float_of_int y)
| String x, String y -> Bool (x > y)
| _ -> (Hashtbl.find Sx_primitives.primitives ">") [a; b])
| _ -> Sx_runtime.prim_call ">" [a; b])
| 167 (* OP_NOT *) ->
let v = pop vm in
push vm (Bool (not (sx_truthy v)))