""" #z3 reader macro — translates SX spec declarations to SMT-LIB format. Demonstrates extensible reader macros by converting define-primitive declarations from primitives.sx into Z3 SMT-LIB verification conditions. Usage: from shared.sx.ref.reader_z3 import z3_translate, register_z3_macro # Register as reader macro (enables #z3 in parser) register_z3_macro() # Or call directly smtlib = z3_translate(parse('(define-primitive "inc" :params (n) ...)')) """ from __future__ import annotations from typing import Any from shared.sx.types import Symbol, Keyword # --------------------------------------------------------------------------- # Type mapping # --------------------------------------------------------------------------- _SX_TO_SORT = { "number": "Int", "boolean": "Bool", "string": "String", "any": "Value", "list": "(List Value)", "dict": "(Array String Value)", } def _sort(sx_type: str) -> str: return _SX_TO_SORT.get(sx_type, "Value") # --------------------------------------------------------------------------- # Expression translation: SX → SMT-LIB # --------------------------------------------------------------------------- # SX operators that map directly to SMT-LIB _IDENTITY_OPS = {"+", "-", "*", "/", "=", "!=", "<", ">", "<=", ">=", "and", "or", "not", "mod"} # SX operators with SMT-LIB equivalents _RENAME_OPS = { "if": "ite", "str": "str.++", } def _translate_expr(expr: Any) -> str: """Translate an SX expression to SMT-LIB s-expression string.""" if isinstance(expr, (int, float)): if isinstance(expr, float): return f"(to_real {int(expr)})" if expr == int(expr) else str(expr) return str(expr) if isinstance(expr, str): return f'"{expr}"' if isinstance(expr, bool): return "true" if expr else "false" if expr is None: return "nil_val" if isinstance(expr, Symbol): name = expr.name # Translate SX predicate names to SMT-LIB if name.endswith("?"): return "is_" + name[:-1].replace("-", "_") return name.replace("-", "_").replace("!", "_bang") if isinstance(expr, list) and len(expr) > 0: head = expr[0] if isinstance(head, Symbol): op = head.name args = expr[1:] # Direct identity ops if op in _IDENTITY_OPS: smt_args = " ".join(_translate_expr(a) for a in args) return f"({op} {smt_args})" # Renamed ops if op in _RENAME_OPS: smt_op = _RENAME_OPS[op] smt_args = " ".join(_translate_expr(a) for a in args) return f"({smt_op} {smt_args})" # max/min → ite if op == "max" and len(args) == 2: a, b = _translate_expr(args[0]), _translate_expr(args[1]) return f"(ite (>= {a} {b}) {a} {b})" if op == "min" and len(args) == 2: a, b = _translate_expr(args[0]), _translate_expr(args[1]) return f"(ite (<= {a} {b}) {a} {b})" # empty? → length check if op == "empty?": a = _translate_expr(args[0]) return f"(= (len {a}) 0)" # first/rest → list ops if op == "first": return f"(head {_translate_expr(args[0])})" if op == "rest": return f"(tail {_translate_expr(args[0])})" # reduce with initial value if op == "reduce" and len(args) >= 3: return f"(reduce {_translate_expr(args[0])} {_translate_expr(args[2])} {_translate_expr(args[1])})" # fn (lambda) → unnamed function if op == "fn": params = args[0] if isinstance(args[0], list) else [args[0]] param_str = " ".join(f"({_translate_expr(p)} Int)" for p in params) body = _translate_expr(args[1]) return f"(lambda (({param_str})) {body})" # native-* → bare op if op.startswith("native-"): bare = op[7:] # strip "native-" smt_args = " ".join(_translate_expr(a) for a in args) return f"({bare} {smt_args})" # Generic function call smt_name = op.replace("-", "_").replace("?", "_p").replace("!", "_bang") smt_args = " ".join(_translate_expr(a) for a in args) return f"({smt_name} {smt_args})" return str(expr) # --------------------------------------------------------------------------- # Define-primitive → SMT-LIB # --------------------------------------------------------------------------- def _extract_kwargs(expr: list) -> dict[str, Any]: """Extract keyword arguments from a define-primitive form.""" kwargs: dict[str, Any] = {} i = 2 # skip head and name while i < len(expr): item = expr[i] if isinstance(item, Keyword) and i + 1 < len(expr): kwargs[item.name] = expr[i + 1] i += 2 else: i += 1 return kwargs def _params_to_sorts(params: list) -> list[tuple[str, str]]: """Convert SX param list to (name, sort) pairs, skipping &rest/&key.""" result = [] skip_next = False for p in params: if isinstance(p, Symbol) and p.name in ("&rest", "&key"): skip_next = True continue if skip_next: skip_next = False continue if isinstance(p, Symbol): result.append((p.name, "Int")) return result def z3_translate(expr: Any) -> str: """Translate an SX define-primitive to SMT-LIB verification conditions. Input: parsed (define-primitive "name" :params (...) :returns "type" ...) Output: SMT-LIB string with declare-fun and assert/check-sat. """ if not isinstance(expr, list) or len(expr) < 2: return f"; Cannot translate: not a list form" head = expr[0] if not isinstance(head, Symbol): return f"; Cannot translate: head is not a symbol" form = head.name if form == "define-primitive": return _translate_primitive(expr) elif form == "define-io-primitive": return _translate_io(expr) elif form == "define-special-form": return _translate_special_form(expr) else: # Generic expression translation return _translate_expr(expr) def _translate_primitive(expr: list) -> str: """Translate define-primitive to SMT-LIB.""" name = expr[1] if len(expr) > 1 else "?" kwargs = _extract_kwargs(expr) params = kwargs.get("params", []) returns = kwargs.get("returns", "any") doc = kwargs.get("doc", "") body = kwargs.get("body") # Build param sorts param_pairs = _params_to_sorts(params if isinstance(params, list) else []) has_rest = any(isinstance(p, Symbol) and p.name == "&rest" for p in (params if isinstance(params, list) else [])) # SMT-LIB function name if name == "!=": smt_name = "neq" elif name in ("+", "-", "*", "/", "=", "<", ">", "<=", ">="): smt_name = name # keep arithmetic ops as-is else: smt_name = name.replace("-", "_").replace("?", "_p").replace("!", "_bang") lines = [f"; {name} — {doc}"] if has_rest: # Variadic — declare as uninterpreted lines.append(f"; (variadic — modeled as uninterpreted)") lines.append(f"(declare-fun {smt_name} (Int Int) {_sort(returns)})") else: param_sorts = " ".join(s for _, s in param_pairs) lines.append(f"(declare-fun {smt_name} ({param_sorts}) {_sort(returns)})") if body is not None and not has_rest: # Generate forall assertion from body if param_pairs: bindings = " ".join(f"({p} Int)" for p, _ in param_pairs) call_args = " ".join(p for p, _ in param_pairs) smt_body = _translate_expr(body) lines.append(f"(assert (forall (({bindings}))") lines.append(f" (= ({smt_name} {call_args}) {smt_body})))") else: smt_body = _translate_expr(body) lines.append(f"(assert (= ({smt_name}) {smt_body}))") lines.append("(check-sat)") return "\n".join(lines) def _translate_io(expr: list) -> str: """Translate define-io-primitive — uninterpreted (cannot verify statically).""" name = expr[1] if len(expr) > 1 else "?" kwargs = _extract_kwargs(expr) doc = kwargs.get("doc", "") smt_name = name.replace("-", "_").replace("?", "_p") return (f"; IO primitive: {name} — {doc}\n" f"; (uninterpreted — IO cannot be verified statically)\n" f"(declare-fun {smt_name} () Value)") def _translate_special_form(expr: list) -> str: """Translate define-special-form to SMT-LIB.""" name = expr[1] if len(expr) > 1 else "?" kwargs = _extract_kwargs(expr) doc = kwargs.get("doc", "") if name == "if": return (f"; Special form: if — {doc}\n" f"(assert (forall ((c Bool) (t Value) (e Value))\n" f" (= (sx_if c t e) (ite c t e))))\n" f"(check-sat)") elif name == "when": return (f"; Special form: when — {doc}\n" f"(assert (forall ((c Bool) (body Value))\n" f" (= (sx_when c body) (ite c body nil_val))))\n" f"(check-sat)") return f"; Special form: {name} — {doc}\n; (not directly expressible in SMT-LIB)" # --------------------------------------------------------------------------- # Batch translation: process an entire spec file # --------------------------------------------------------------------------- def z3_translate_file(source: str) -> str: """Parse an SX spec file and translate all define-primitive forms.""" from shared.sx.parser import parse_all exprs = parse_all(source) results = [] for expr in exprs: if (isinstance(expr, list) and len(expr) >= 2 and isinstance(expr[0], Symbol) and expr[0].name in ("define-primitive", "define-io-primitive", "define-special-form")): results.append(z3_translate(expr)) return "\n\n".join(results) # --------------------------------------------------------------------------- # Reader macro registration # --------------------------------------------------------------------------- def register_z3_macro(): """Register #z3 as a reader macro in the SX parser.""" from shared.sx.parser import register_reader_macro register_reader_macro("z3", z3_translate)