diff --git a/hosts/ocaml/bootstrap.py b/hosts/ocaml/bootstrap.py index 0435e085..eae4c71a 100644 --- a/hosts/ocaml/bootstrap.py +++ b/hosts/ocaml/bootstrap.py @@ -49,20 +49,14 @@ let trampoline v = !trampoline_fn v -(* === Mutable state for strict mode === *) -(* These are defined as top-level refs because the transpiler cannot handle - global set! mutation (it creates local refs that shadow the global). *) +(* === Mutable globals — backing refs for transpiler's !_ref / _ref := === *) let _strict_ref = ref (Bool false) let _prim_param_types_ref = ref Nil +let _last_error_kont_ref = ref Nil -(* JIT call hook — cek_call checks this before CEK dispatch for named - lambdas. Registered by sx_server.ml after compiler loads. Tests - run with hook = None (pure CEK, no compilation dependency). *) +(* JIT call hook — platform-level optimization, registered by sx_server.ml *) let jit_call_hook : (value -> value list -> value option) option ref = ref None -(* Component trace — captures kont from last CEK error for diagnostics *) -let _last_error_kont : value ref = ref Nil - """ @@ -88,7 +82,7 @@ let cek_run_iterative state = done; cek_value !s with Eval_error msg -> - _last_error_kont := cek_kont !s; + _last_error_kont_ref := cek_kont !s; raise (Eval_error msg)) (* Collect component trace from a kont value *) @@ -127,8 +121,8 @@ let format_comp_trace trace = (* Enhance an error message with component trace *) let enhance_error_with_trace msg = - let trace = collect_comp_trace !_last_error_kont in - _last_error_kont := Nil; + let trace = collect_comp_trace !_last_error_kont_ref in + _last_error_kont_ref := Nil; msg ^ (format_comp_trace trace) @@ -215,90 +209,6 @@ def compile_spec_to_ml(spec_dir: str | None = None) -> str: # the transpiler directly — it emits !_ref for reads, _ref := for writes. import re - # Fix cek_call: the spec passes (make-env) as the env arg to - # continue_with_call, but the transpiler evaluates make-env at - # transpile time (it's a primitive), producing Dict instead of Env. - output = output.replace( - "((Dict (Hashtbl.create 0))) (a) ((List []))", - "(Env (Sx_types.make_env ())) (a) ((List []))", - ) - - # Inject JIT dispatch + &rest handling into continue_with_call's lambda branch. - # Replace the entire lambda binding + make_cek_state section. - cwc_lambda_old = ( - 'else (if sx_truthy ((is_lambda (f))) then ' - '(let params = (lambda_params (f)) in let local = (env_merge ((lambda_closure (f))) (env)) in ' - '(if sx_truthy ((prim_call ">" [(len (args)); (len (params))])) then ' - '(raise (Eval_error (value_to_str (String (sx_str [' - '(let _or = (lambda_name (f)) in if sx_truthy _or then _or else (String "lambda")); ' - '(String " expects "); (len (params)); (String " args, got "); (len (args))])))))' - ' else (let () = ignore ((List.iter (fun pair -> ignore (' - '(env_bind local (sx_to_string (first (pair))) (nth (pair) ((Number 1.0))))))' - ' (sx_to_list (prim_call "zip" [params; args])); Nil)) in ' - '(let () = ignore ((List.iter (fun p -> ignore ((env_bind local (sx_to_string p) Nil)))' - ' (sx_to_list (prim_call "slice" [params; (len (args))])); Nil)) in ' - '(make_cek_state ((lambda_body (f))) (local) (kont))))))' - ) - cwc_lambda_new = ( - 'else (if sx_truthy ((is_lambda (f))) then ' - '(let params = (lambda_params (f)) in let local = (env_merge ((lambda_closure (f))) (env)) in ' - '(if not (bind_lambda_with_rest params args local) then begin ' - 'let pl = sx_to_list params and al = sx_to_list args in ' - 'if List.length al > List.length pl then ' - 'raise (Eval_error (Printf.sprintf "%s expects %d args, got %d" ' - '(match lambda_name f with String s -> s | _ -> "lambda") ' - '(List.length pl) (List.length al))); ' - 'List.iter (fun pair -> ignore (env_bind local (sx_to_string (first pair)) (nth pair (Number 1.0)))) ' - '(sx_to_list (prim_call "zip" [params; args])); ' - 'List.iter (fun p -> ignore (env_bind local (sx_to_string p) Nil)) ' - '(sx_to_list (prim_call "slice" [params; len args])) end; ' - '(match !jit_call_hook, f with ' - '| Some hook, Lambda l when l.l_name <> None -> ' - 'let args_list = match args with List a | ListRef { contents = a } -> a | _ -> [] in ' - '(match hook f args_list with ' - 'Some result -> make_cek_value result local kont ' - '| None -> make_cek_state (lambda_body f) local kont) ' - '| _ -> make_cek_state ((lambda_body (f))) (local) (kont))))' - ) - if cwc_lambda_old in output: - output = output.replace(cwc_lambda_old, cwc_lambda_new, 1) - else: - import sys - print("WARNING: Could not find continue_with_call lambda pattern for &rest+JIT injection", file=sys.stderr) - - # Patch call_lambda and continue_with_call to handle &rest in lambda params. - # The transpiler can't handle the index-of-based approach, so we inject it. - REST_HELPER = """ -(* &rest lambda param binding — injected by bootstrap.py *) -and bind_lambda_with_rest (params : value) (args : value) (local_val : value) : bool = - let local = match local_val with Env e -> e | _ -> failwith "bind_lambda_with_rest: expected env" in - let param_list = sx_to_list params in - let arg_list = sx_to_list args in - let rec find_rest i = function - | [] -> None - | h :: rp :: _ when value_to_str h = "&rest" -> Some (i, value_to_str rp) - | _ :: tl -> find_rest (i + 1) tl - in - match find_rest 0 param_list with - | Some (pos, rest_name) -> - let positional = List.filteri (fun i _ -> i < pos) param_list in - List.iteri (fun i p -> - let v = if i < List.length arg_list then List.nth arg_list i else Nil in - ignore (Sx_types.env_bind local (value_to_str p) v) - ) positional; - let rest_args = if List.length arg_list > pos - then List (List.filteri (fun i _ -> i >= pos) arg_list) - else List [] in - ignore (Sx_types.env_bind local rest_name rest_args); - true - | None -> false -""" - # Inject the helper before call_lambda - output = output.replace( - "(* call-lambda *)\nand call_lambda", - REST_HELPER + "\n(* call-lambda *)\nand call_lambda", - ) - # Inject make_raise_guard_frame if missing (transpiler merge bug drops it) if "and make_raise_guard_frame" not in output: RAISE_GUARD_FRAME = """ @@ -311,38 +221,9 @@ and make_raise_guard_frame env saved_kont = RAISE_GUARD_FRAME + "\n(* make-signal-return-frame *)\nand make_signal_return_frame", ) - # Patch call_lambda to use &rest-aware binding - call_lambda_marker = "(* call-lambda *)\nand call_lambda f args caller_env =\n" - call_comp_marker = "\n(* call-component *)" - if call_lambda_marker in output and call_comp_marker in output: - start = output.index(call_lambda_marker) - end = output.index(call_comp_marker) - new_call_lambda = """(* call-lambda *) -and call_lambda f args caller_env = - let params = lambda_params f in - let local = env_merge (lambda_closure f) caller_env in - if not (bind_lambda_with_rest params args local) then begin - let pl = sx_to_list params and al = sx_to_list args in - if List.length al > List.length pl then - raise (Eval_error (Printf.sprintf "%s expects %d args, got %d" - (match lambda_name f with String s -> s | _ -> "lambda") - (List.length pl) (List.length al))); - List.iter (fun pair -> - ignore (env_bind local (sx_to_string (first pair)) (nth pair (Number 1.0))) - ) (sx_to_list (prim_call "zip" [params; args])); - List.iter (fun p -> - ignore (env_bind local (sx_to_string p) Nil) - ) (sx_to_list (prim_call "slice" [params; len args])) - end; - make_thunk (lambda_body f) local -""" - output = output[:start] + new_call_lambda + output[end:] - else: - print("WARNING: Could not find call_lambda for &rest injection", file=sys.stderr) + # === Platform-level patches (not spec concerns) === # Instrument recursive cek_run to capture kont on error (for comp-trace). - # The iterative cek_run_iterative already does this, but cek_call uses - # the recursive cek_run. cek_run_old = ( 'and cek_run state =\n' ' (if sx_truthy ((cek_terminal_p (state))) then (cek_value (state)) else (cek_run ((cek_step (state)))))' @@ -352,12 +233,26 @@ and call_lambda f args caller_env = ' (if sx_truthy ((cek_terminal_p (state))) then (cek_value (state)) else\n' ' try cek_run ((cek_step (state)))\n' ' with Eval_error msg ->\n' - ' (if !_last_error_kont = Nil then _last_error_kont := cek_kont state);\n' + ' (if !_last_error_kont_ref = Nil then _last_error_kont_ref := cek_kont state);\n' ' raise (Eval_error msg))' ) if cek_run_old in output: output = output.replace(cek_run_old, cek_run_new, 1) + # Inject JIT dispatch into continue_with_call's lambda branch. + # Replace final make_cek_state in the lambda branch with JIT check. + jit_old = "(make_cek_state ((lambda_body (f))) (local) (kont))))))" + jit_new = ( + "(match !jit_call_hook, f with " + "| Some hook, Lambda l when l.l_name <> None -> " + "let args_list = match args with List a | ListRef { contents = a } -> a | _ -> [] in " + "(match hook f args_list with " + "Some result -> make_cek_value result local kont " + "| None -> make_cek_state (lambda_body f) local kont) " + "| _ -> make_cek_state ((lambda_body (f))) (local) (kont))))))" + ) + output = output.replace(jit_old, jit_new, 1) + return output diff --git a/hosts/ocaml/transpiler.sx b/hosts/ocaml/transpiler.sx index a98089cc..31b6516a 100644 --- a/hosts/ocaml/transpiler.sx +++ b/hosts/ocaml/transpiler.sx @@ -268,7 +268,9 @@ (define ml-dynamic-globals (list "*render-check*" "*render-fn*")) -(define ml-mutable-globals (list "*strict*" "*prim-param-types*")) +(define + ml-mutable-globals + (list "*strict*" "*prim-param-types*" "*last-error-kont*")) (define ml-is-mutable-global? diff --git a/spec/evaluator.sx b/spec/evaluator.sx index 2d9efce3..66d2caf9 100644 --- a/spec/evaluator.sx +++ b/spec/evaluator.sx @@ -418,6 +418,30 @@ (define eval-expr (fn (expr (env :as dict)) nil)) +(define + bind-lambda-params + (fn + (params args local) + (let + ((rest-idx (index-of params "&rest"))) + (if + (and (number? rest-idx) (< rest-idx (len params))) + (let + ((positional (slice params 0 rest-idx)) + (rest-name (nth params (+ rest-idx 1)))) + (do + (for-each-indexed + (fn + (i p) + (env-bind! local p (if (< i (len args)) (nth args i) nil))) + positional) + (env-bind! + local + rest-name + (if (> (len args) rest-idx) (slice args rest-idx) (quote ()))) + true)) + false)))) + (define call-lambda (fn @@ -425,23 +449,24 @@ (let ((params (lambda-params f)) (local (env-merge (lambda-closure f) caller-env))) - (if - (> (len args) (len params)) - (error - (str - (or (lambda-name f) "lambda") - " expects " - (len params) - " args, got " - (len args))) - (do - (for-each - (fn (pair) (env-bind! local (first pair) (nth pair 1))) - (zip params args)) - (for-each - (fn (p) (env-bind! local p nil)) - (slice params (len args))) - (make-thunk (lambda-body f) local)))))) + (when + (not (bind-lambda-params params args local)) + (when + (> (len args) (len params)) + (error + (str + (or (lambda-name f) "lambda") + " expects " + (len params) + " args, got " + (len args)))) + (for-each + (fn (pair) (env-bind! local (first pair) (nth pair 1))) + (zip params args)) + (for-each + (fn (p) (env-bind! local p nil)) + (slice params (len args)))) + (make-thunk (lambda-body f) local)))) (define call-component @@ -2691,23 +2716,24 @@ (let ((params (lambda-params f)) (local (env-merge (lambda-closure f) env))) - (if - (> (len args) (len params)) - (error - (str - (or (lambda-name f) "lambda") - " expects " - (len params) - " args, got " - (len args))) - (do - (for-each - (fn (pair) (env-bind! local (first pair) (nth pair 1))) - (zip params args)) - (for-each - (fn (p) (env-bind! local p nil)) - (slice params (len args))) - (make-cek-state (lambda-body f) local kont)))) + (when + (not (bind-lambda-params params args local)) + (when + (> (len args) (len params)) + (error + (str + (or (lambda-name f) "lambda") + " expects " + (len params) + " args, got " + (len args)))) + (for-each + (fn (pair) (env-bind! local (first pair) (nth pair 1))) + (zip params args)) + (for-each + (fn (p) (env-bind! local p nil)) + (slice params (len args)))) + (make-cek-state (lambda-body f) local kont)) (or (component? f) (island? f)) (let ((parsed (parse-keyword-args raw-args env))