"""
Async I/O primitives for the s-expression resolver.

These primitives wrap rose-ash's inter-service communication layer. They let
s-expressions fetch fragments, query data, call actions and read request
context.

Pure primitives (see primitives.py) are run by the evaluator. These ones are
**async** and are run by the resolver instead: the resolver spots them by name
while walking the tree and dispatches them together via ``asyncio.gather()``.

Examples of use inside s-expressions::

    (frag "blog" "link-card" :slug "apple")
    (query "market" "products-by-ids" :ids "1,2,3")
    (action "market" "create-marketplace" :name "Farm Shop" :slug "farm")
    (current-user)
    (htmx-request?)
"""
from __future__ import annotations

import contextvars
from typing import Any

# ---------------------------------------------------------------------------
# Names of the async primitives
# ---------------------------------------------------------------------------

# The resolver treats every name in this set as an I/O node that needs async
# resolution. It collects those nodes during the tree-walk, groups them and
# dispatches them in parallel. Each name should have a matching entry in the
# ``_IO_HANDLERS`` dispatch table later in this module; ``execute_io`` raises
# RuntimeError for names that are missing from it.
IO_PRIMITIVES: frozenset[str] = frozenset(
    [
        # Inter-service communication
        "frag",
        "query",
        "action",
        "service",
        # Request state
        "current-user",
        "htmx-request?",
        "request-arg",
        "request-path",
        "g",
        "csrf-token",
        # Navigation / relationship lookups
        "nav-tree",
        "get-children",
        # Routing and control flow
        "url-for",
        "route-prefix",
        "abort",
    ]
)

# ---------------------------------------------------------------------------
# Request context (set per-request by the resolver)
# ---------------------------------------------------------------------------

# ContextVar for the handler's domain service object.
# Set by the handler blueprint before executing a defhandler.
_handler_service: contextvars.ContextVar[Any] = contextvars.ContextVar(
    "_handler_service", default=None
)


def set_handler_service(service_obj: Any) -> None:
    """Bind the local domain service for ``(service ...)`` primitive calls."""
    _handler_service.set(service_obj)


def get_handler_service() -> Any:
    """Get the currently bound handler service, or None."""
    return _handler_service.get(None)


class RequestContext:
    """Per-request context provided to I/O primitives.

    Populated by the resolver from the Quart request before resolution begins.

    Attributes:
        user: Returned as-is by ``(current-user)``; ``None`` when absent.
        is_htmx: Returned as-is by ``(htmx-request?)``.
        extras: Extra per-request values; a falsy argument is replaced with a
            new empty dict.
    """

    __slots__ = ("user", "is_htmx", "extras")

    def __init__(
        self,
        user: dict[str, Any] | None = None,
        is_htmx: bool = False,
        extras: dict[str, Any] | None = None,
    ):
        self.user = user
        self.is_htmx = is_htmx
        self.extras = extras or {}


# ---------------------------------------------------------------------------
# I/O dispatch
# ---------------------------------------------------------------------------

async def execute_io(
    name: str,
    args: list[Any],
    kwargs: dict[str, Any],
    ctx: RequestContext,
) -> Any:
    """Execute an I/O primitive by name.

    Called by the resolver after collecting and grouping I/O nodes.
    Returns the result to be substituted back into the tree.

    Raises:
        RuntimeError: if *name* has no entry in ``_IO_HANDLERS``.
    """
    handler = _IO_HANDLERS.get(name)
    if handler is None:
        raise RuntimeError(f"Unknown I/O primitive: {name}")
    return await handler(args, kwargs, ctx)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _clean_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Strip None and NIL values from kwargs for Python interop."""
    from .types import NIL
    return {k: v for k, v in kwargs.items() if v is not None and v is not NIL}


def _python_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Prepare sx kwargs for a direct Python call.

    Keys are converted from kebab-case to snake_case and NIL values become
    None. Unlike ``_clean_kwargs``, no key is dropped, so an explicit nil
    reaches the callee as ``None``.
    """
    from .types import NIL
    return {
        k.replace("-", "_"): (None if v is NIL else v)
        for k, v in kwargs.items()
    }


def _is_dto(obj: Any) -> bool:
    """True for dataclass instances/classes and namedtuple-like objects."""
    return hasattr(obj, "__dataclass_fields__") or hasattr(obj, "_asdict")


def _expand_date_fields(d: dict[str, Any]) -> None:
    """Add ``{key}_year/_month/_day`` entries in place for date-like values.

    A value counts as date-like if it has both ``year`` and ``strftime``.
    """
    # Snapshot the items because new keys are added while iterating.
    for key, val in list(d.items()):
        if hasattr(val, "year") and hasattr(val, "strftime"):
            d[f"{key}_year"] = val.year
            d[f"{key}_month"] = val.month
            d[f"{key}_day"] = val.day


def _coerce_url_value(value: Any) -> Any:
    """Turn a decimal-digit string into an int (for int URL converters).

    ``str.isdigit`` also accepts characters such as ``"²"`` that ``int()``
    rejects with ValueError. ``str.isdecimal`` accepts only the characters
    ``int()`` can parse. Any other value is returned unchanged.
    """
    if isinstance(value, str) and value.isdecimal():
        return int(value)
    return value


# ---------------------------------------------------------------------------
# Individual handlers
# ---------------------------------------------------------------------------

async def _io_frag(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> str:
    """``(frag "service" "type" :key val ...)`` → fetch_fragment.

    Raises ValueError when fewer than two positional args are given.
    """
    if len(args) < 2:
        raise ValueError("frag requires service and fragment type")
    service = str(args[0])
    frag_type = str(args[1])
    params = _clean_kwargs(kwargs)
    from shared.infrastructure.fragments import fetch_fragment
    return await fetch_fragment(service, frag_type, params=params or None)


async def _io_query(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(query "service" "query-name" :key val ...)`` → fetch_data.

    Raises ValueError when fewer than two positional args are given.
    """
    if len(args) < 2:
        raise ValueError("query requires service and query name")
    service = str(args[0])
    query_name = str(args[1])
    params = _clean_kwargs(kwargs)
    from shared.infrastructure.data_client import fetch_data
    return await fetch_data(service, query_name, params=params or None)


async def _io_action(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(action "service" "action-name" :key val ...)`` → call_action.

    Raises ValueError when fewer than two positional args are given.
    """
    if len(args) < 2:
        raise ValueError("action requires service and action name")
    service = str(args[0])
    action_name = str(args[1])
    payload = _clean_kwargs(kwargs)
    from shared.infrastructure.actions import call_action
    return await call_action(service, action_name, payload=payload or None)


async def _io_current_user(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> dict[str, Any] | None:
    """``(current-user)`` → user dict from request context."""
    return ctx.user


async def _io_htmx_request(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> bool:
    """``(htmx-request?)`` → True if HX-Request header present."""
    return ctx.is_htmx


async def _io_service(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(service "svc-name" "method-name" :key val ...)`` → call domain service.

    Looks up the service from the shared registry by name, then calls the
    named method with ``g.s`` (async session) + keyword args. Falls back to
    the bound handler service if only one positional arg is given.

    Service names, method names and kwarg keys go from kebab-case to
    snake_case, and NIL kwarg values become None. The awaited result goes
    through ``_convert_result``.

    Raises:
        ValueError: if no positional args are given.
        RuntimeError: if the service, the bound handler service, or the
            method cannot be found.
    """
    if not args:
        raise ValueError("service requires at least a method name")

    if len(args) >= 2:
        # (service "calendar" "associated-entries" :key val ...)
        from shared.services.registry import services as svc_registry
        svc_name = str(args[0]).replace("-", "_")
        svc = getattr(svc_registry, svc_name, None)
        if svc is None:
            raise RuntimeError(f"No service registered as: {svc_name}")
        method_name = str(args[1]).replace("-", "_")
    else:
        # (service "method-name" :key val ...) — legacy / bound service
        svc = get_handler_service()
        if svc is None:
            raise RuntimeError(
                "No handler service bound — cannot call (service ...)")
        method_name = str(args[0]).replace("-", "_")

    method = getattr(svc, method_name, None)
    if method is None:
        raise RuntimeError(f"Service has no method: {method_name}")

    call_kwargs = _python_kwargs(kwargs)
    from quart import g
    result = await method(g.s, **call_kwargs)
    return _convert_result(result)


def _dto_to_dict(obj: Any) -> dict[str, Any]:
    """Convert a DTO/dataclass/namedtuple to a plain dict.

    Adds ``{field}_year``, ``{field}_month``, ``{field}_day`` convenience
    keys for any datetime-valued field, so sx handlers can build URL paths
    without parsing date strings. This now applies to dataclasses too;
    before, they returned early and skipped the expansion. An object that
    is none of the supported kinds is wrapped as ``{"value": obj}``.
    """
    if hasattr(obj, "__dataclass_fields__"):
        from shared.contracts.dtos import dto_to_dict
        # Copy before the in-place expansion below so we never mutate a dict
        # that dto_to_dict hands back (defensive; its ownership isn't visible
        # from here).
        d = dict(dto_to_dict(obj))
    elif hasattr(obj, "_asdict"):
        d = dict(obj._asdict())
    elif hasattr(obj, "__dict__"):
        d = {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}
    else:
        return {"value": obj}
    _expand_date_fields(d)
    return d


def _convert_result(result: Any) -> Any:
    """Convert a service method result for sx consumption.

    - ``None`` → NIL
    - dicts → same keys, values converted recursively
    - DTOs (dataclasses, namedtuples) → dicts via ``_dto_to_dict``. This is
      checked before the tuple case because a namedtuple *is* a tuple and
      would otherwise lose its field names.
    - tuples (e.g. ``(entries, has_more)``) and lists → lists, with each
      item converted recursively
    - anything else is returned unchanged
    """
    if result is None:
        from .types import NIL
        return NIL
    if isinstance(result, dict):
        return {k: _convert_result(v) for k, v in result.items()}
    if _is_dto(result):
        return _dto_to_dict(result)
    if isinstance(result, (tuple, list)):
        return [_convert_result(item) for item in result]
    return result


async def _io_request_arg(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(request-arg "name" default?)`` → request.args.get(name, default)."""
    if not args:
        raise ValueError("request-arg requires a name")
    from quart import request
    name = str(args[0])
    default = args[1] if len(args) > 1 else None
    return request.args.get(name, default)


async def _io_request_path(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> str:
    """``(request-path)`` → request.path."""
    from quart import request
    return request.path


async def _io_nav_tree(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> list[dict[str, Any]]:
    """``(nav-tree)`` → list of navigation menu node dicts."""
    from quart import g
    from shared.services.navigation import get_navigation_tree
    nodes = await get_navigation_tree(g.s)
    return [_dto_to_dict(node) for node in nodes]


async def _io_get_children(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> list[dict[str, Any]]:
    """``(get-children :parent-type "page" :parent-id 1 ...)``

    Kwarg keys go from kebab-case to snake_case and NIL values become None
    (the same convention as ``(service ...)``) before they are passed to
    ``get_children``. Each returned child is converted with ``_dto_to_dict``.
    """
    from quart import g
    from shared.services.relationships import get_children
    clean = _python_kwargs(kwargs)
    children = await get_children(g.s, **clean)
    return [_dto_to_dict(child) for child in children]


async def _io_g(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(g "key")`` → getattr(g, key, None).

    Reads a value from the Quart request-local ``g`` object. Kebab-case keys
    are converted to snake_case automatically. With no key, returns None.
    """
    from quart import g
    key = str(args[0]).replace("-", "_") if args else ""
    return getattr(g, key, None)


async def _io_csrf_token(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> str:
    """``(csrf-token)`` → current CSRF token string.

    Calls the ``csrf_token`` Jinja global if it is callable; otherwise
    returns ``""``.
    """
    from quart import current_app
    csrf = current_app.jinja_env.globals.get("csrf_token")
    if callable(csrf):
        return csrf()
    return ""


async def _io_abort(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> Any:
    """``(abort 403 "message")`` — raise HTTP error from SX.

    Allows defpages to abort with HTTP error codes for auth/ownership checks
    without needing a Python page helper. This call never returns normally.

    Raises:
        ValueError: if no status code is given or it is not int-convertible.
    """
    if not args:
        raise ValueError("abort requires a status code")
    from quart import abort
    status = int(args[0])
    if len(args) > 1:
        abort(status, str(args[1]))
    # No message given: don't pass "" — an explicit empty description would
    # replace the HTTP exception's default one.
    abort(status)


async def _io_url_for(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> str:
    """``(url-for "endpoint" :key val ...)`` → url_for(endpoint, **kwargs).

    Generates a URL for the given endpoint. Keyword args become URL
    parameters (kebab-case converted to snake_case; None/NIL dropped).
    Decimal-digit strings are converted to int so they match int URL
    converters.
    """
    if not args:
        raise ValueError("url-for requires an endpoint name")
    from quart import url_for
    endpoint = str(args[0])
    params = {
        k.replace("-", "_"): _coerce_url_value(v)
        for k, v in _clean_kwargs(kwargs).items()
    }
    return url_for(endpoint, **params)


async def _io_route_prefix(
    args: list[Any], kwargs: dict[str, Any], ctx: RequestContext
) -> str:
    """``(route-prefix)`` → current route prefix string."""
    from shared.utils import route_prefix
    return route_prefix()


# ---------------------------------------------------------------------------
# Handler registry
# ---------------------------------------------------------------------------

# Maps each sx primitive name to its async handler; used by execute_io.
_IO_HANDLERS: dict[str, Any] = {
    "frag": _io_frag,
    "query": _io_query,
    "action": _io_action,
    "current-user": _io_current_user,
    "htmx-request?": _io_htmx_request,
    "service": _io_service,
    "request-arg": _io_request_arg,
    "request-path": _io_request_path,
    "nav-tree": _io_nav_tree,
    "get-children": _io_get_children,
    "g": _io_g,
    "csrf-token": _io_csrf_token,
    "abort": _io_abort,
    "url-for": _io_url_for,
    "route-prefix": _io_route_prefix,
}