Files
rose-ash/shared/sx/primitives.py
giles 6772f1141f Register append! and dict-set! as proper primitives
Previously these mutating operations were internal helpers in the JS
bootstrapper but not declared in primitives.sx or registered in the
Python evaluator. Now properly specced and available in both hosts.

Removes mock injections from cache tests — they use real primitives.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-07 00:21:17 +00:00

471 lines
12 KiB
Python

"""
Primitive registry and built-in pure functions.
All primitives here are pure (no I/O). Async / I/O primitives live in
separate modules and are registered at app startup.
"""
from __future__ import annotations
import math
from typing import Any, Callable
from .types import Keyword, NIL
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
# Module-level table mapping each primitive's name to its implementation;
# entries are added by the ``register_primitive`` decorator.
_PRIMITIVES: dict[str, Callable] = {}
def register_primitive(name: str):
    """Build a decorator that files a callable in the registry under *name*.

    At decoration time *name* is first passed to ``validate_primitive``
    (from ``.boundary``); the callable is then stored in ``_PRIMITIVES``
    and handed back unchanged.

    Example::

        @register_primitive("str")
        def prim_str(*args):
            return "".join(str(a) for a in args)
    """
    def _register(func: Callable) -> Callable:
        # Imported at decoration time rather than at module import
        # (presumably to avoid an import cycle with .boundary — confirm).
        from .boundary import validate_primitive

        validate_primitive(name)
        _PRIMITIVES[name] = func
        return func

    return _register
def get_primitive(name: str) -> Callable | None:
    """Look up *name* in the registry; ``None`` when it is not registered."""
    try:
        return _PRIMITIVES[name]
    except KeyError:
        return None
def all_primitives() -> dict[str, Callable]:
    """Return a shallow copy of the registry, keyed by primitive name.

    Mutating the returned dict does not affect the registry itself.
    """
    snapshot = _PRIMITIVES.copy()
    return snapshot
# ---------------------------------------------------------------------------
# Arithmetic
# ---------------------------------------------------------------------------
@register_primitive("+")
def prim_add(*args: Any) -> Any:
    """Add all arguments together; with no arguments the result is ``0``."""
    total = sum(args)
    return total
@register_primitive("-")
def prim_sub(a: Any, b: Any = None) -> Any:
    """Subtract *b* from *a*, or negate *a* when *b* is omitted (``None``)."""
    if b is None:
        return -a
    return a - b
@register_primitive("*")
def prim_mul(*args: Any) -> Any:
    """Multiply all arguments; with no arguments the result is ``1``."""
    product = 1
    for factor in args:
        product *= factor
    return product
@register_primitive("/")
def prim_div(a: Any, b: Any) -> Any:
    """Divide *a* by *b* using Python's true-division operator (``/``)."""
    quotient = a / b
    return quotient
@register_primitive("mod")
def prim_mod(a: Any, b: Any) -> Any:
    """Return ``a % b`` (Python modulo semantics)."""
    remainder = a % b
    return remainder
@register_primitive("sqrt")
def prim_sqrt(x: Any) -> float:
    """Return the square root of *x* as computed by ``math.sqrt``."""
    root = math.sqrt(x)
    return root
@register_primitive("pow")
def prim_pow(x: Any, n: Any) -> Any:
    """Raise *x* to the power *n* with the ``**`` operator."""
    powered = x ** n
    return powered
@register_primitive("abs")
def prim_abs(x: Any) -> Any:
    """Return the absolute value of *x* via the built-in ``abs``."""
    magnitude = abs(x)
    return magnitude
@register_primitive("floor")
def prim_floor(x: Any) -> int:
    """Round *x* down using ``math.floor``."""
    floored = math.floor(x)
    return floored
@register_primitive("ceil")
def prim_ceil(x: Any) -> int:
    """Round *x* up using ``math.ceil``."""
    ceiled = math.ceil(x)
    return ceiled
@register_primitive("round")
def prim_round(x: Any, ndigits: Any = 0) -> Any:
    """Round *x* to *ndigits* decimal places (default ``0``).

    *ndigits* is coerced with ``int()`` before being handed to the
    built-in ``round``.
    """
    digits = int(ndigits)
    return round(x, digits)
@register_primitive("min")
def prim_min(*args: Any) -> Any:
    """Return the smallest value.

    A lone list/tuple argument is treated as the collection to search;
    otherwise the arguments themselves are compared.
    """
    lone_sequence = len(args) == 1 and isinstance(args[0], (list, tuple))
    candidates = args[0] if lone_sequence else args
    return min(candidates)
@register_primitive("max")
def prim_max(*args: Any) -> Any:
    """Return the largest value.

    A lone list/tuple argument is treated as the collection to search;
    otherwise the arguments themselves are compared.
    """
    lone_sequence = len(args) == 1 and isinstance(args[0], (list, tuple))
    candidates = args[0] if lone_sequence else args
    return max(candidates)
@register_primitive("clamp")
def prim_clamp(x: Any, lo: Any, hi: Any) -> Any:
    """Limit *x* to the range bounded below by *lo* and above by *hi*."""
    capped = min(hi, x)
    return max(lo, capped)
@register_primitive("inc")
def prim_inc(n: Any) -> Any:
    """Return *n* plus one."""
    successor = n + 1
    return successor
@register_primitive("dec")
def prim_dec(n: Any) -> Any:
    """Return *n* minus one."""
    predecessor = n - 1
    return predecessor
# ---------------------------------------------------------------------------
# Comparison
# ---------------------------------------------------------------------------
@register_primitive("=")
def prim_eq(a: Any, b: Any) -> bool:
    """Compare *a* and *b* with Python's ``==`` operator."""
    same = a == b
    return same
@register_primitive("!=")
def prim_neq(a: Any, b: Any) -> bool:
    """Compare *a* and *b* with Python's ``!=`` operator."""
    differs = a != b
    return differs
@register_primitive("eq?")
def prim_eq_identity(a: Any, b: Any) -> bool:
    """Object identity test: true only when *a* and *b* are one object."""
    same_object = a is b
    return same_object
@register_primitive("eqv?")
def prim_eqv(a: Any, b: Any) -> bool:
    """Equivalence: identity for compound values, value comparison for atoms.

    Checks, in order: object identity; ``==`` when *a* is an int, float,
    str or bool and *b* is an instance of *a*'s type; and finally whether
    both sides are nil (``None`` or ``NIL``).
    """
    if a is b:
        return True

    atom_types = (int, float, str, bool)
    if isinstance(a, atom_types) and isinstance(b, type(a)):
        return a == b

    # None and NIL are interchangeable representations of nil here.
    a_is_nil = a is None or a is NIL
    b_is_nil = b is None or b is NIL
    return a_is_nil and b_is_nil
@register_primitive("equal?")
def prim_equal(a: Any, b: Any) -> bool:
    """Structural equality; delegates to ``==`` exactly like ``=``."""
    same = a == b
    return same
@register_primitive("<")
def prim_lt(a: Any, b: Any) -> bool:
    """Return the result of ``a < b``."""
    is_less = a < b
    return is_less
@register_primitive(">")
def prim_gt(a: Any, b: Any) -> bool:
    """Return the result of ``a > b``."""
    is_greater = a > b
    return is_greater
@register_primitive("<=")
def prim_lte(a: Any, b: Any) -> bool:
    """Return the result of ``a <= b``."""
    at_most = a <= b
    return at_most
@register_primitive(">=")
def prim_gte(a: Any, b: Any) -> bool:
    """Return the result of ``a >= b``."""
    at_least = a >= b
    return at_least
# ---------------------------------------------------------------------------
# Predicates
# ---------------------------------------------------------------------------
@register_primitive("odd?")
def prim_is_odd(n: Any) -> bool:
    """True when *n* leaves a remainder of 1 on division by 2."""
    remainder = n % 2
    return remainder == 1
@register_primitive("even?")
def prim_is_even(n: Any) -> bool:
    """True when *n* leaves no remainder on division by 2."""
    remainder = n % 2
    return remainder == 0
@register_primitive("zero?")
def prim_is_zero(n: Any) -> bool:
    """True when *n* compares equal to ``0``."""
    is_zero = n == 0
    return is_zero
@register_primitive("nil?")
def prim_is_nil(x: Any) -> bool:
    """True when *x* is Python ``None`` or the ``NIL`` sentinel."""
    if x is None:
        return True
    return x is NIL
@register_primitive("number?")
def prim_is_number(x: Any) -> bool:
    """True when *x* is an int or float, but not a boolean.

    Python's ``bool`` is a subclass of ``int``, so a bare
    ``isinstance(x, (int, float))`` would classify ``True``/``False`` as
    numbers. Booleans are rendered as ``true``/``false`` by the ``str``
    primitive, i.e. they are a distinct type in the language, so they are
    excluded explicitly here.
    """
    if isinstance(x, bool):
        return False
    return isinstance(x, (int, float))
@register_primitive("string?")
def prim_is_string(x: Any) -> bool:
    """True when *x* is a ``str`` instance."""
    is_text = isinstance(x, str)
    return is_text
@register_primitive("list?")
def prim_is_list(x: Any) -> bool:
    """True when *x* is a ``list`` instance (tuples do not count)."""
    is_list = isinstance(x, list)
    return is_list
@register_primitive("dict?")
def prim_is_dict(x: Any) -> bool:
    """True when *x* is a ``dict`` instance."""
    is_mapping = isinstance(x, dict)
    return is_mapping
@register_primitive("continuation?")
def prim_is_continuation(x: Any) -> bool:
    """True when *x* is an instance of ``Continuation`` from ``.types``."""
    # Imported at call time, as in the original implementation.
    from .types import Continuation

    matches = isinstance(x, Continuation)
    return matches
@register_primitive("empty?")
def prim_is_empty(coll: Any) -> bool:
    """True for nil (``None``/``NIL``) and for anything whose length is 0.

    Values that do not support ``len()`` are reported as not empty.
    """
    if coll is None or coll is NIL:
        return True
    try:
        size = len(coll)
    except TypeError:
        # Unsized values (e.g. numbers) are never considered empty.
        return False
    return size == 0
@register_primitive("contains?")
def prim_contains(coll: Any, key: Any) -> bool:
    """Membership test across strings, dicts, lists and tuples.

    * str: substring test against ``str(key)``.
    * dict: key lookup, with a ``Keyword`` reduced to its ``.name`` first.
    * list / tuple: element membership.
    * anything else: ``False``.
    """
    if isinstance(coll, str):
        needle = str(key)
        return needle in coll
    elif isinstance(coll, dict):
        lookup = key.name if isinstance(key, Keyword) else key
        return lookup in coll
    elif isinstance(coll, (list, tuple)):
        return key in coll
    else:
        return False
# ---------------------------------------------------------------------------
# Logic (non-short-circuit versions; and/or are special forms)
# ---------------------------------------------------------------------------
@register_primitive("not")
def prim_not(x: Any) -> bool:
    """Logical negation based on Python truthiness of *x*."""
    negated = not x
    return negated
# ---------------------------------------------------------------------------
# Strings
# ---------------------------------------------------------------------------
@register_primitive("str")
def prim_str(*args: Any) -> str:
    """Concatenate the string form of every argument.

    nil (``None``/``NIL``) renders as the empty string, booleans render as
    ``true``/``false``, and everything else goes through ``str()``.
    """
    def _render(value: Any) -> str:
        if value is None or value is NIL:
            return ""
        # bool is checked explicitly so it never falls through to str().
        if isinstance(value, bool):
            return "true" if value else "false"
        return str(value)

    return "".join(_render(value) for value in args)
@register_primitive("concat")
def prim_concat(*colls: Any) -> list:
    """Flatten the given collections, in order, into one new list.

    nil arguments (``None``/``NIL``) are skipped.
    """
    return [
        item
        for part in colls
        if not (part is None or part is NIL)
        for item in part
    ]
@register_primitive("upper")
def prim_upper(s: str) -> str:
    """Return *s* converted by ``str.upper``."""
    shouted = s.upper()
    return shouted
@register_primitive("lower")
def prim_lower(s: str) -> str:
    """Return *s* converted by ``str.lower``."""
    lowered = s.lower()
    return lowered
@register_primitive("trim")
def prim_trim(s: str) -> str:
    """Return *s* with leading and trailing whitespace removed."""
    stripped = s.strip()
    return stripped
@register_primitive("split")
def prim_split(s: str, sep: str = " ") -> list[str]:
    """Split *s* on *sep* (a single space by default) via ``str.split``."""
    pieces = s.split(sep)
    return pieces
@register_primitive("join")
def prim_join(sep: str, coll: list) -> str:
    """Join the ``str()`` form of each item in *coll*, separated by *sep*."""
    rendered = map(str, coll)
    return sep.join(rendered)
@register_primitive("replace")
def prim_replace(s: str, old: str, new: str) -> str:
    """Return *s* with every occurrence of *old* replaced by *new*."""
    substituted = s.replace(old, new)
    return substituted
@register_primitive("slice")
def prim_slice(coll: Any, start: int, end: Any = None) -> Any:
    """Slice a string or list: ``(slice coll start end?)``.

    *start* (and *end*, when given) are coerced with ``int()``. A nil
    *end* (``None``/``NIL``) means "through to the end of *coll*".
    """
    begin = int(start)
    open_ended = end is None or end is NIL
    stop = None if open_ended else int(end)
    return coll[begin:stop]
@register_primitive("index-of")
def prim_index_of(s: str, needle: str, start: int = 0) -> int:
    """Return the index of *needle* in ``str(s)`` at or after *start*.

    Uses ``str.find``, so ``-1`` signals that *needle* was not found.
    """
    haystack = str(s)
    offset = int(start)
    return haystack.find(needle, offset)
@register_primitive("starts-with?")
def prim_starts_with(s, prefix: str) -> bool:
    """True when *s* is a string that begins with *prefix*.

    A non-string *s* yields ``False`` instead of raising.
    """
    return isinstance(s, str) and s.startswith(prefix)
@register_primitive("ends-with?")
def prim_ends_with(s: Any, suffix: str) -> bool:
    """True when *s* is a string that ends with *suffix*.

    A non-string *s* (nil, numbers, lists, ...) yields ``False`` rather
    than raising ``AttributeError``. This matches the guard in
    ``starts-with?`` so the two predicates treat non-string input alike.
    """
    if not isinstance(s, str):
        return False
    return s.endswith(suffix)
# ---------------------------------------------------------------------------
# Collections — construction
# ---------------------------------------------------------------------------
@register_primitive("list")
def prim_list(*args: Any) -> list:
    """Collect the arguments, in order, into a new list."""
    return [*args]
@register_primitive("dict")
def prim_dict(*pairs: Any) -> dict:
    """Build a dict from alternating key/value arguments.

    ``Keyword`` keys are stored under their ``.name``. A trailing key with
    no matching value is ignored.
    """
    built = {}
    # zip stops at the shorter slice, which drops an unpaired trailing key.
    for raw_key, value in zip(pairs[::2], pairs[1::2]):
        name = raw_key.name if isinstance(raw_key, Keyword) else raw_key
        built[name] = value
    return built
@register_primitive("range")
def prim_range(start: Any, end: Any, step: Any = 1) -> list[int]:
    """Return ``range(start, end, step)`` materialised as a list.

    All three bounds are coerced with ``int()`` first.
    """
    lo, hi, stride = int(start), int(end), int(step)
    return list(range(lo, hi, stride))
# ---------------------------------------------------------------------------
# Collections — access
# ---------------------------------------------------------------------------
@register_primitive("get")
def prim_get(coll: Any, key: Any, default: Any = None) -> Any:
    """Fetch *key* from a dict or list, falling back to *default*.

    * dict: *key* is tried as given; if that yields ``None`` and *key* is a
      ``Keyword``, its ``.name`` is tried next. A stored ``None`` is treated
      the same as a missing entry.
    * list: *key* is used as an index when ``0 <= key < len(coll)``.
    * anything else: *default*.
    """
    if isinstance(coll, dict):
        found = coll.get(key)
        if found is None and isinstance(key, Keyword):
            found = coll.get(key.name)
        return default if found is None else found

    if isinstance(coll, list):
        in_bounds = 0 <= key < len(coll)
        return coll[key] if in_bounds else default

    return default
@register_primitive("len")
def prim_len(coll: Any) -> int:
    """Return ``len(coll)``."""
    size = len(coll)
    return size
@register_primitive("first")
def prim_first(coll: Any) -> Any:
    """Return the first element of *coll*, or ``NIL`` when it is falsy."""
    if not coll:
        return NIL
    return coll[0]
@register_primitive("last")
def prim_last(coll: Any) -> Any:
    """Return the final element of *coll*, or ``NIL`` when it is falsy."""
    if not coll:
        return NIL
    return coll[-1]
@register_primitive("rest")
def prim_rest(coll: Any) -> list:
    """Return everything after the first element; ``[]`` for a falsy *coll*."""
    if not coll:
        return []
    return coll[1:]
@register_primitive("nth")
def prim_nth(coll: Any, n: Any) -> Any:
    """Return ``coll[n]`` when ``0 <= n < len(coll)``, otherwise ``NIL``."""
    if 0 <= n < len(coll):
        return coll[n]
    return NIL
@register_primitive("cons")
def prim_cons(x: Any, coll: Any) -> list:
    """Return a new list with *x* placed in front of the items of *coll*.

    A falsy *coll* (nil, empty) yields the one-element list ``[x]``.
    """
    if not coll:
        return [x]
    return [x, *coll]
@register_primitive("append")
def prim_append(coll: Any, x: Any) -> list:
    """Return a new list with *x* added after the items of *coll*.

    A falsy *coll* (nil, empty) yields the one-element list ``[x]``. The
    input collection is not modified.
    """
    if not coll:
        return [x]
    return [*coll, x]
@register_primitive("append!")
def prim_append_mut(coll: Any, x: Any) -> list:
    """Append *x* to *coll* in place and return that same collection."""
    # Mutates the caller's object via its own ``append`` method.
    coll.append(x)
    return coll
@register_primitive("chunk-every")
def prim_chunk_every(coll: Any, n: Any) -> list:
    """Split *coll* into consecutive slices of length ``int(n)``.

    The final slice is shorter when ``len(coll)`` is not a multiple of
    the chunk size.
    """
    size = int(n)
    chunks = []
    for offset in range(0, len(coll), size):
        chunks.append(coll[offset:offset + size])
    return chunks
@register_primitive("zip-pairs")
def prim_zip_pairs(coll: Any) -> list:
    """Return overlapping adjacent pairs ``[coll[i], coll[i + 1]]``.

    A falsy *coll* or one with fewer than two items yields ``[]``.
    """
    if not coll or len(coll) < 2:
        return []
    adjacent = []
    for i in range(len(coll) - 1):
        adjacent.append([coll[i], coll[i + 1]])
    return adjacent
# ---------------------------------------------------------------------------
# Collections — dict operations
# ---------------------------------------------------------------------------
@register_primitive("keys")
def prim_keys(d: dict) -> list:
    """Return the keys of *d* as a new list."""
    return [*d.keys()]
@register_primitive("vals")
def prim_vals(d: dict) -> list:
    """Return the values of *d* as a new list."""
    return [*d.values()]
@register_primitive("merge")
def prim_merge(*dicts: Any) -> dict:
    """Merge the given dicts left to right into a new dict.

    Later arguments win on key clashes; nil arguments (``None``/``NIL``)
    are skipped.
    """
    merged = {}
    for mapping in dicts:
        if mapping is None or mapping is NIL:
            continue
        merged.update(mapping)
    return merged
@register_primitive("assoc")
def prim_assoc(d: Any, *pairs: Any) -> dict:
    """Return a copy of *d* with the given key/value pairs set.

    *d* itself is not modified; a falsy or ``NIL`` *d* starts from an empty
    dict. ``Keyword`` keys are stored under their ``.name``, and a trailing
    key with no matching value is ignored.
    """
    updated = {} if (not d or d is NIL) else dict(d)
    # zip stops at the shorter slice, which drops an unpaired trailing key.
    for raw_key, value in zip(pairs[::2], pairs[1::2]):
        name = raw_key.name if isinstance(raw_key, Keyword) else raw_key
        updated[name] = value
    return updated
@register_primitive("dissoc")
def prim_dissoc(d: Any, *keys_to_remove: Any) -> dict:
    """Return a copy of *d* without the listed keys.

    *d* itself is not modified; a falsy or ``NIL`` *d* starts from an empty
    dict. ``Keyword`` keys are matched by their ``.name``; keys that are
    absent are silently ignored.
    """
    remaining = {} if (not d or d is NIL) else dict(d)
    for raw_key in keys_to_remove:
        name = raw_key.name if isinstance(raw_key, Keyword) else raw_key
        remaining.pop(name, None)
    return remaining
@register_primitive("dict-set!")
def prim_dict_set_mut(d: Any, key: Any, val: Any) -> Any:
    """Store *val* under *key* in *d* in place and return *val*.

    A ``Keyword`` key is stored under its ``.name``.
    """
    slot = key.name if isinstance(key, Keyword) else key
    d[slot] = val
    return val
@register_primitive("into")
def prim_into(target: Any, coll: Any) -> Any:
    """Convert *coll* into a new collection of the same kind as *target*.

    Only the type of *target* matters; its contents are not used.

    * list target: a dict becomes a list of ``[key, value]`` lists, anything
      else is passed through ``list()``.
    * dict target: a dict is shallow-copied; otherwise each list/tuple item
      with at least two elements contributes ``item[0] -> item[1]`` (with a
      ``Keyword`` key reduced to its ``.name``) and other items are skipped.

    Raises:
        ValueError: if *target* is neither a list nor a dict.
    """
    if isinstance(target, list):
        if isinstance(coll, dict):
            return [list(entry) for entry in coll.items()]
        return list(coll)

    if isinstance(target, dict):
        if isinstance(coll, dict):
            return dict(coll)
        built = {}
        for entry in coll:
            is_pair = isinstance(entry, (list, tuple)) and len(entry) >= 2
            if not is_pair:
                continue
            head = entry[0]
            name = head.name if isinstance(head, Keyword) else head
            built[name] = entry[1]
        return built

    raise ValueError(f"into: unsupported target type {type(target).__name__}")