From cc2dcbddd46da50449eb817eac79295dcda245aa Mon Sep 17 00:00:00 2001 From: giles Date: Tue, 24 Feb 2026 23:09:39 +0000 Subject: [PATCH] Squashed 'core/' content from commit 4957443 git-subtree-dir: core git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07 --- .gitignore | 47 + README.md | 110 ++ artdag/__init__.py | 61 + artdag/activities.py | 371 ++++ artdag/activitypub/__init__.py | 33 + artdag/activitypub/activity.py | 203 ++ artdag/activitypub/actor.py | 206 +++ artdag/activitypub/ownership.py | 226 +++ artdag/activitypub/signatures.py | 163 ++ artdag/analysis/__init__.py | 26 + artdag/analysis/analyzer.py | 282 +++ artdag/analysis/audio.py | 336 ++++ artdag/analysis/schema.py | 352 ++++ artdag/analysis/video.py | 266 +++ artdag/cache.py | 464 +++++ artdag/cli.py | 724 ++++++++ artdag/client.py | 201 ++ artdag/dag.py | 344 ++++ artdag/effects/__init__.py | 55 + artdag/effects/binding.py | 311 ++++ artdag/effects/frame_processor.py | 347 ++++ artdag/effects/loader.py | 455 +++++ artdag/effects/meta.py | 247 +++ artdag/effects/runner.py | 259 +++ artdag/effects/sandbox.py | 431 +++++ artdag/engine.py | 246 +++ artdag/executor.py | 106 ++ artdag/nodes/__init__.py | 11 + artdag/nodes/compose.py | 548 ++++++ artdag/nodes/effect.py | 520 ++++++ artdag/nodes/encoding.py | 50 + artdag/nodes/source.py | 62 + artdag/nodes/transform.py | 224 +++ artdag/planning/__init__.py | 29 + artdag/planning/planner.py | 756 ++++++++ artdag/planning/schema.py | 594 ++++++ artdag/planning/tree_reduction.py | 231 +++ artdag/registry/__init__.py | 20 + artdag/registry/registry.py | 294 +++ artdag/server.py | 253 +++ artdag/sexp/__init__.py | 75 + artdag/sexp/compiler.py | 2463 +++++++++++++++++++++++++ artdag/sexp/effect_loader.py | 337 ++++ artdag/sexp/evaluator.py | 869 +++++++++ artdag/sexp/external_tools.py | 292 +++ artdag/sexp/ffmpeg_compiler.py | 616 +++++++ artdag/sexp/parser.py | 425 +++++ artdag/sexp/planner.py | 2187 ++++++++++++++++++++++ artdag/sexp/primitives.py | 
620 +++++++ artdag/sexp/scheduler.py | 779 ++++++++ artdag/sexp/stage_cache.py | 412 +++++ artdag/sexp/test_ffmpeg_compiler.py | 146 ++ artdag/sexp/test_primitives.py | 201 ++ artdag/sexp/test_stage_cache.py | 324 ++++ artdag/sexp/test_stage_compiler.py | 286 +++ artdag/sexp/test_stage_integration.py | 739 ++++++++ artdag/sexp/test_stage_planner.py | 228 +++ artdag/sexp/test_stage_scheduler.py | 323 ++++ docs/EXECUTION_MODEL.md | 384 ++++ docs/IPFS_PRIMARY_ARCHITECTURE.md | 443 +++++ docs/L1_STORAGE.md | 181 ++ docs/OFFLINE_TESTING.md | 211 +++ effects/identity/README.md | 35 + effects/identity/requirements.txt | 2 + examples/simple_sequence.yaml | 42 + examples/test_local.sh | 54 + examples/test_plan.py | 93 + pyproject.toml | 62 + scripts/compute_repo_hash.py | 67 + scripts/install-ffglitch.sh | 82 + scripts/register_identity_effect.py | 83 + scripts/setup_actor.py | 120 ++ scripts/sign_assets.py | 143 ++ tests/__init__.py | 1 + tests/test_activities.py | 613 ++++++ tests/test_cache.py | 163 ++ tests/test_dag.py | 271 +++ tests/test_engine.py | 464 +++++ tests/test_executor.py | 110 ++ tests/test_ipfs_access.py | 301 +++ 80 files changed, 25711 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 artdag/__init__.py create mode 100644 artdag/activities.py create mode 100644 artdag/activitypub/__init__.py create mode 100644 artdag/activitypub/activity.py create mode 100644 artdag/activitypub/actor.py create mode 100644 artdag/activitypub/ownership.py create mode 100644 artdag/activitypub/signatures.py create mode 100644 artdag/analysis/__init__.py create mode 100644 artdag/analysis/analyzer.py create mode 100644 artdag/analysis/audio.py create mode 100644 artdag/analysis/schema.py create mode 100644 artdag/analysis/video.py create mode 100644 artdag/cache.py create mode 100644 artdag/cli.py create mode 100644 artdag/client.py create mode 100644 artdag/dag.py create mode 100644 artdag/effects/__init__.py create mode 100644 
artdag/effects/binding.py create mode 100644 artdag/effects/frame_processor.py create mode 100644 artdag/effects/loader.py create mode 100644 artdag/effects/meta.py create mode 100644 artdag/effects/runner.py create mode 100644 artdag/effects/sandbox.py create mode 100644 artdag/engine.py create mode 100644 artdag/executor.py create mode 100644 artdag/nodes/__init__.py create mode 100644 artdag/nodes/compose.py create mode 100644 artdag/nodes/effect.py create mode 100644 artdag/nodes/encoding.py create mode 100644 artdag/nodes/source.py create mode 100644 artdag/nodes/transform.py create mode 100644 artdag/planning/__init__.py create mode 100644 artdag/planning/planner.py create mode 100644 artdag/planning/schema.py create mode 100644 artdag/planning/tree_reduction.py create mode 100644 artdag/registry/__init__.py create mode 100644 artdag/registry/registry.py create mode 100644 artdag/server.py create mode 100644 artdag/sexp/__init__.py create mode 100644 artdag/sexp/compiler.py create mode 100644 artdag/sexp/effect_loader.py create mode 100644 artdag/sexp/evaluator.py create mode 100644 artdag/sexp/external_tools.py create mode 100644 artdag/sexp/ffmpeg_compiler.py create mode 100644 artdag/sexp/parser.py create mode 100644 artdag/sexp/planner.py create mode 100644 artdag/sexp/primitives.py create mode 100644 artdag/sexp/scheduler.py create mode 100644 artdag/sexp/stage_cache.py create mode 100644 artdag/sexp/test_ffmpeg_compiler.py create mode 100644 artdag/sexp/test_primitives.py create mode 100644 artdag/sexp/test_stage_cache.py create mode 100644 artdag/sexp/test_stage_compiler.py create mode 100644 artdag/sexp/test_stage_integration.py create mode 100644 artdag/sexp/test_stage_planner.py create mode 100644 artdag/sexp/test_stage_scheduler.py create mode 100644 docs/EXECUTION_MODEL.md create mode 100644 docs/IPFS_PRIMARY_ARCHITECTURE.md create mode 100644 docs/L1_STORAGE.md create mode 100644 docs/OFFLINE_TESTING.md create mode 100644 
effects/identity/README.md create mode 100644 effects/identity/requirements.txt create mode 100644 examples/simple_sequence.yaml create mode 100755 examples/test_local.sh create mode 100755 examples/test_plan.py create mode 100644 pyproject.toml create mode 100644 scripts/compute_repo_hash.py create mode 100755 scripts/install-ffglitch.sh create mode 100644 scripts/register_identity_effect.py create mode 100644 scripts/setup_actor.py create mode 100644 scripts/sign_assets.py create mode 100644 tests/__init__.py create mode 100644 tests/test_activities.py create mode 100644 tests/test_cache.py create mode 100644 tests/test_dag.py create mode 100644 tests/test_engine.py create mode 100644 tests/test_executor.py create mode 100644 tests/test_ipfs_access.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2d2f90c --- /dev/null +++ b/.gitignore @@ -0,0 +1,47 @@ +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Environment +.env +.venv +env/ +venv/ + +# Private keys (ActivityPub secrets) +.cache/ + +# Test outputs +test_cache/ +test_plan_output.json +analysis.json +plan.json +plan_with_analysis.json diff --git a/README.md b/README.md new file mode 100644 index 0000000..27602a4 --- /dev/null +++ b/README.md @@ -0,0 +1,110 @@ +# artdag + +Content-addressed DAG execution engine with ActivityPub ownership. 
+ +## Features + +- **Content-addressed nodes**: `node_id = SHA3-256(type + config + inputs)` for automatic deduplication +- **Quantum-resistant hashing**: SHA-3 throughout for future-proof integrity +- **ActivityPub ownership**: Cryptographically signed ownership claims +- **Federated identity**: `@user@artdag.rose-ash.com` style identities +- **Pluggable executors**: Register custom node types +- **Built-in video primitives**: SOURCE, SEGMENT, RESIZE, TRANSFORM, SEQUENCE, MUX, BLEND + +## Installation + +```bash +pip install -e . +``` + +### Optional: External Effect Tools + +Some effects can use external tools for better performance: + +**Pixelsort** (glitch art pixel sorting): +```bash +# Rust CLI (recommended - fast) +cargo install --git https://github.com/Void-ux/pixelsort.git pixelsort + +# Or Python CLI +pip install git+https://github.com/Blotz/pixelsort-cli +``` + +**Datamosh** (video glitch/corruption): +```bash +# FFglitch (recommended) +./scripts/install-ffglitch.sh + +# Or Python CLI +pip install git+https://github.com/tiberiuiancu/datamoshing +``` + +Check available tools: +```bash +python -m artdag.sexp.external_tools +``` + +## Quick Start + +```python +from artdag import Engine, DAGBuilder, Registry +from artdag.activitypub import OwnershipManager + +# Create ownership manager +manager = OwnershipManager("./my_registry") + +# Create your identity +actor = manager.create_actor("alice", "Alice") +print(f"Created: {actor.handle}") # @alice@artdag.rose-ash.com + +# Register an asset with ownership +asset, activity = manager.register_asset( + actor=actor, + name="my_image", + path="/path/to/image.jpg", + tags=["photo", "art"], +) +print(f"Owned: {asset.name} (hash: {asset.content_hash})") + +# Build and execute a DAG +engine = Engine("./cache") +builder = DAGBuilder() + +source = builder.source(str(asset.path)) +resized = builder.resize(source, width=1920, height=1080) +builder.set_output(resized) + +result = engine.execute(builder.build()) 
+print(f"Output: {result.output_path}") +``` + +## Architecture + +``` +artdag/ +├── dag.py # Node, DAG, DAGBuilder +├── cache.py # Content-addressed file cache +├── executor.py # Base executor + registry +├── engine.py # DAG execution engine +├── activitypub/ # Identity + ownership +│ ├── actor.py # Actor identity with RSA keys +│ ├── activity.py # Create, Announce activities +│ ├── signatures.py # RSA signing/verification +│ └── ownership.py # Links actors to assets +├── nodes/ # Built-in executors +│ ├── source.py # SOURCE +│ ├── transform.py # SEGMENT, RESIZE, TRANSFORM +│ ├── compose.py # SEQUENCE, LAYER, MUX, BLEND +│ └── effect.py # EFFECT (identity, etc.) +└── effects/ # Effect implementations + └── identity/ # The foundational identity effect +``` + +## Related Repos + +- **Registry**: https://git.rose-ash.com/art-dag/registry - Asset registry with ownership proofs +- **Recipes**: https://git.rose-ash.com/art-dag/recipes - DAG recipes using effects + +## License + +MIT diff --git a/artdag/__init__.py b/artdag/__init__.py new file mode 100644 index 0000000..4b8abe2 --- /dev/null +++ b/artdag/__init__.py @@ -0,0 +1,61 @@ +# artdag - Content-addressed DAG execution engine with ActivityPub ownership +# +# A standalone execution engine that processes directed acyclic graphs (DAGs) +# where each node represents an operation. Nodes are content-addressed for +# automatic caching and deduplication. 
+# +# Core concepts: +# - Node: An operation with type, config, and inputs +# - DAG: A graph of nodes with a designated output node +# - Executor: Implements the actual operation for a node type +# - Engine: Executes DAGs by resolving dependencies and running executors + +from .dag import Node, DAG, DAGBuilder, NodeType +from .cache import Cache, CacheEntry +from .executor import Executor, register_executor, get_executor +from .engine import Engine +from .registry import Registry, Asset +from .activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + +# Analysis and planning modules (optional, require extra dependencies) +try: + from .analysis import Analyzer, AnalysisResult +except ImportError: + Analyzer = None + AnalysisResult = None + +try: + from .planning import RecipePlanner, ExecutionPlan, ExecutionStep +except ImportError: + RecipePlanner = None + ExecutionPlan = None + ExecutionStep = None + +__all__ = [ + # Core + "Node", + "DAG", + "DAGBuilder", + "NodeType", + "Cache", + "CacheEntry", + "Executor", + "register_executor", + "get_executor", + "Engine", + "Registry", + "Asset", + "Activity", + "ActivityStore", + "ActivityManager", + "make_is_shared_fn", + # Analysis (optional) + "Analyzer", + "AnalysisResult", + # Planning (optional) + "RecipePlanner", + "ExecutionPlan", + "ExecutionStep", +] + +__version__ = "0.1.0" diff --git a/artdag/activities.py b/artdag/activities.py new file mode 100644 index 0000000..0919ee7 --- /dev/null +++ b/artdag/activities.py @@ -0,0 +1,371 @@ +# artdag/activities.py +""" +Persistent activity (job) tracking for cache management. + +Activities represent executions of DAGs. 
They track: +- Input node IDs (sources) +- Output node ID (terminal node) +- Intermediate node IDs (everything in between) + +This enables deletion rules: +- Shared items (ActivityPub published) cannot be deleted +- Inputs/outputs of activities cannot be deleted +- Intermediates can be deleted (reconstructible) +- Activities can only be discarded if no items are shared +""" + +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Set + +from .cache import Cache, CacheEntry +from .dag import DAG + +logger = logging.getLogger(__name__) + + +def make_is_shared_fn(activitypub_store: "ActivityStore") -> Callable[[str], bool]: + """ + Create an is_shared function from an ActivityPub ActivityStore. + + Args: + activitypub_store: The ActivityPub activity store + (from artdag.activitypub.activity) + + Returns: + Function that checks if a cid has been published + """ + def is_shared(cid: str) -> bool: + activities = activitypub_store.find_by_object_hash(cid) + return any(a.activity_type == "Create" for a in activities) + return is_shared + + +@dataclass +class Activity: + """ + A recorded execution of a DAG. + + Tracks which cache entries are inputs, outputs, and intermediates + to enforce deletion rules. 
+ """ + activity_id: str + input_ids: List[str] # Source node cache IDs + output_id: str # Terminal node cache ID + intermediate_ids: List[str] # Everything in between + created_at: float + status: str = "completed" # pending|running|completed|failed + dag_snapshot: Optional[Dict[str, Any]] = None # Serialized DAG for reconstruction + + def to_dict(self) -> Dict[str, Any]: + return { + "activity_id": self.activity_id, + "input_ids": self.input_ids, + "output_id": self.output_id, + "intermediate_ids": self.intermediate_ids, + "created_at": self.created_at, + "status": self.status, + "dag_snapshot": self.dag_snapshot, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Activity": + return cls( + activity_id=data["activity_id"], + input_ids=data["input_ids"], + output_id=data["output_id"], + intermediate_ids=data["intermediate_ids"], + created_at=data["created_at"], + status=data.get("status", "completed"), + dag_snapshot=data.get("dag_snapshot"), + ) + + @classmethod + def from_dag(cls, dag: DAG, activity_id: str = None) -> "Activity": + """ + Create an Activity from a DAG. + + Classifies nodes as inputs, output, or intermediates. 
+ """ + if activity_id is None: + activity_id = str(uuid.uuid4()) + + # Find input nodes (nodes with no inputs - sources) + input_ids = [] + for node_id, node in dag.nodes.items(): + if not node.inputs: + input_ids.append(node_id) + + # Output is the terminal node + output_id = dag.output_id + + # Intermediates are everything else + intermediate_ids = [] + for node_id in dag.nodes: + if node_id not in input_ids and node_id != output_id: + intermediate_ids.append(node_id) + + return cls( + activity_id=activity_id, + input_ids=sorted(input_ids), + output_id=output_id, + intermediate_ids=sorted(intermediate_ids), + created_at=time.time(), + status="completed", + dag_snapshot=dag.to_dict(), + ) + + @property + def all_node_ids(self) -> List[str]: + """All node IDs involved in this activity.""" + return self.input_ids + [self.output_id] + self.intermediate_ids + + +class ActivityStore: + """ + Persistent storage for activities. + + Provides methods to check deletion eligibility and perform deletions. 
+ """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._activities: Dict[str, Activity] = {} + self._load() + + def _index_path(self) -> Path: + return self.store_dir / "activities.json" + + def _load(self): + """Load activities from disk.""" + index_path = self._index_path() + if index_path.exists(): + try: + with open(index_path) as f: + data = json.load(f) + self._activities = { + a["activity_id"]: Activity.from_dict(a) + for a in data.get("activities", []) + } + except (json.JSONDecodeError, KeyError) as e: + logger.warning(f"Failed to load activities: {e}") + self._activities = {} + + def _save(self): + """Save activities to disk.""" + data = { + "version": "1.0", + "activities": [a.to_dict() for a in self._activities.values()], + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def add(self, activity: Activity) -> None: + """Add an activity.""" + self._activities[activity.activity_id] = activity + self._save() + + def get(self, activity_id: str) -> Optional[Activity]: + """Get an activity by ID.""" + return self._activities.get(activity_id) + + def remove(self, activity_id: str) -> bool: + """Remove an activity record (does not delete cache entries).""" + if activity_id not in self._activities: + return False + del self._activities[activity_id] + self._save() + return True + + def list(self) -> List[Activity]: + """List all activities.""" + return list(self._activities.values()) + + def find_by_input_ids(self, input_ids: List[str]) -> List[Activity]: + """Find activities with the same inputs (for UI grouping).""" + sorted_inputs = sorted(input_ids) + return [ + a for a in self._activities.values() + if sorted(a.input_ids) == sorted_inputs + ] + + def find_using_node(self, node_id: str) -> List[Activity]: + """Find all activities that reference a node ID.""" + return [ + a for a in self._activities.values() + if node_id in a.all_node_ids + 
] + + def __len__(self) -> int: + return len(self._activities) + + +class ActivityManager: + """ + Manages activities and cache deletion with sharing rules. + + Deletion rules: + 1. Shared items (ActivityPub published) cannot be deleted + 2. Inputs/outputs of activities cannot be deleted + 3. Intermediates can be deleted (reconstructible) + 4. Activities can only be discarded if no items are shared + """ + + def __init__( + self, + cache: Cache, + activity_store: ActivityStore, + is_shared_fn: Callable[[str], bool], + ): + """ + Args: + cache: The L1 cache + activity_store: Activity persistence + is_shared_fn: Function that checks if a cid is shared + (published via ActivityPub) + """ + self.cache = cache + self.activities = activity_store + self._is_shared = is_shared_fn + + def record_activity(self, dag: DAG) -> Activity: + """Record a completed DAG execution as an activity.""" + activity = Activity.from_dag(dag) + self.activities.add(activity) + return activity + + def is_shared(self, node_id: str) -> bool: + """Check if a cache entry is shared (published via ActivityPub).""" + entry = self.cache.get_entry(node_id) + if not entry or not entry.cid: + return False + return self._is_shared(entry.cid) + + def can_delete_cache_entry(self, node_id: str) -> bool: + """ + Check if a cache entry can be deleted. + + Returns False if: + - Entry is shared (ActivityPub published) + - Entry is an input or output of any activity + """ + # Check if shared + if self.is_shared(node_id): + return False + + # Check if it's an input or output of any activity + for activity in self.activities.list(): + if node_id in activity.input_ids: + return False + if node_id == activity.output_id: + return False + + # It's either an intermediate or orphaned - can delete + return True + + def can_discard_activity(self, activity_id: str) -> bool: + """ + Check if an activity can be discarded. + + Returns False if any cache entry (input, output, or intermediate) + is shared via ActivityPub. 
+ """ + activity = self.activities.get(activity_id) + if not activity: + return False + + # Check if any item is shared + for node_id in activity.all_node_ids: + if self.is_shared(node_id): + return False + + return True + + def discard_activity(self, activity_id: str) -> bool: + """ + Discard an activity and delete its intermediate cache entries. + + Returns False if the activity cannot be discarded (has shared items). + + When discarded: + - Intermediate cache entries are deleted + - The activity record is removed + - Inputs remain (may be used by other activities) + - Output is deleted if orphaned (not shared, not used elsewhere) + """ + if not self.can_discard_activity(activity_id): + return False + + activity = self.activities.get(activity_id) + if not activity: + return False + + output_id = activity.output_id + intermediate_ids = list(activity.intermediate_ids) + + # Remove the activity record first + self.activities.remove(activity_id) + + # Delete intermediates + for node_id in intermediate_ids: + self.cache.remove(node_id) + logger.debug(f"Deleted intermediate: {node_id}") + + # Check if output is now orphaned + if self._is_orphaned(output_id) and not self.is_shared(output_id): + self.cache.remove(output_id) + logger.debug(f"Deleted orphaned output: {output_id}") + + # Inputs remain - they may be used by other activities + # But check if any are orphaned now + for input_id in activity.input_ids: + if self._is_orphaned(input_id) and not self.is_shared(input_id): + self.cache.remove(input_id) + logger.debug(f"Deleted orphaned input: {input_id}") + + return True + + def _is_orphaned(self, node_id: str) -> bool: + """Check if a node is not referenced by any activity.""" + for activity in self.activities.list(): + if node_id in activity.all_node_ids: + return False + return True + + def get_deletable_entries(self) -> List[CacheEntry]: + """Get all cache entries that can be deleted.""" + deletable = [] + for entry in self.cache.list_entries(): + if 
self.can_delete_cache_entry(entry.node_id): + deletable.append(entry) + return deletable + + def get_discardable_activities(self) -> List[Activity]: + """Get all activities that can be discarded.""" + return [ + a for a in self.activities.list() + if self.can_discard_activity(a.activity_id) + ] + + def cleanup_intermediates(self) -> int: + """ + Delete all intermediate cache entries. + + Intermediates are safe to delete as they can be reconstructed + from inputs using the DAG. + + Returns: + Number of entries deleted + """ + deleted = 0 + for activity in self.activities.list(): + for node_id in activity.intermediate_ids: + if self.cache.has(node_id): + self.cache.remove(node_id) + deleted += 1 + return deleted diff --git a/artdag/activitypub/__init__.py b/artdag/activitypub/__init__.py new file mode 100644 index 0000000..e9abbdc --- /dev/null +++ b/artdag/activitypub/__init__.py @@ -0,0 +1,33 @@ +# primitive/activitypub/__init__.py +""" +ActivityPub implementation for Art DAG. + +Provides decentralized identity and ownership for assets. +Domain: artdag.rose-ash.com + +Core concepts: +- Actor: A user identity with cryptographic keys +- Object: An asset (image, video, etc.) +- Activity: An action (Create, Announce, Like, etc.) 
+- Signature: Cryptographic proof of authorship +""" + +from .actor import Actor, ActorStore +from .activity import Activity, CreateActivity, ActivityStore +from .signatures import sign_activity, verify_signature, verify_activity_ownership +from .ownership import OwnershipManager, OwnershipRecord + +__all__ = [ + "Actor", + "ActorStore", + "Activity", + "CreateActivity", + "ActivityStore", + "sign_activity", + "verify_signature", + "verify_activity_ownership", + "OwnershipManager", + "OwnershipRecord", +] + +DOMAIN = "artdag.rose-ash.com" diff --git a/artdag/activitypub/activity.py b/artdag/activitypub/activity.py new file mode 100644 index 0000000..d7ab9b8 --- /dev/null +++ b/artdag/activitypub/activity.py @@ -0,0 +1,203 @@ +# primitive/activitypub/activity.py +""" +ActivityPub Activity types. + +Activities represent actions taken by actors on objects. +Key activity types for Art DAG: +- Create: Actor creates/claims ownership of an object +- Announce: Actor shares/boosts an object +- Like: Actor endorses an object +""" + +import json +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .actor import Actor, DOMAIN + + +def _generate_id() -> str: + """Generate unique activity ID.""" + return str(uuid.uuid4()) + + +@dataclass +class Activity: + """ + Base ActivityPub Activity. + + Attributes: + activity_id: Unique identifier + activity_type: Type (Create, Announce, Like, etc.) 
+ actor_id: ID of the actor performing the activity + object_data: The object of the activity + published: ISO timestamp + signature: Cryptographic signature (added after signing) + """ + activity_id: str + activity_type: str + actor_id: str + object_data: Dict[str, Any] + published: str = field(default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())) + signature: Optional[Dict[str, Any]] = None + + def to_activitypub(self) -> Dict[str, Any]: + """Return ActivityPub JSON-LD representation.""" + activity = { + "@context": "https://www.w3.org/ns/activitystreams", + "type": self.activity_type, + "id": f"https://{DOMAIN}/activities/{self.activity_id}", + "actor": self.actor_id, + "object": self.object_data, + "published": self.published, + } + if self.signature: + activity["signature"] = self.signature + return activity + + def to_dict(self) -> Dict[str, Any]: + """Serialize for storage.""" + return { + "activity_id": self.activity_id, + "activity_type": self.activity_type, + "actor_id": self.actor_id, + "object_data": self.object_data, + "published": self.published, + "signature": self.signature, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Activity": + """Deserialize from storage.""" + return cls( + activity_id=data["activity_id"], + activity_type=data["activity_type"], + actor_id=data["actor_id"], + object_data=data["object_data"], + published=data.get("published", ""), + signature=data.get("signature"), + ) + + +@dataclass +class CreateActivity(Activity): + """ + Create activity - establishes ownership of an object. + + Used when an actor creates or claims an asset. + """ + activity_type: str = field(default="Create", init=False) + + @classmethod + def for_asset( + cls, + actor: Actor, + asset_name: str, + cid: str, + asset_type: str = "Image", + metadata: Dict[str, Any] = None, + ) -> "CreateActivity": + """ + Create a Create activity for an asset. 
+ + Args: + actor: The actor claiming ownership + asset_name: Name of the asset + cid: SHA-3 hash of the asset content + asset_type: ActivityPub object type (Image, Video, Audio, etc.) + metadata: Additional metadata + + Returns: + CreateActivity establishing ownership + """ + object_data = { + "type": asset_type, + "name": asset_name, + "id": f"https://{DOMAIN}/objects/{cid}", + "contentHash": { + "algorithm": "sha3-256", + "value": cid, + }, + "attributedTo": actor.id, + } + if metadata: + object_data["metadata"] = metadata + + return cls( + activity_id=_generate_id(), + actor_id=actor.id, + object_data=object_data, + ) + + +class ActivityStore: + """ + Persistent storage for activities. + + Activities are stored as an append-only log for auditability. + """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._activities: List[Activity] = [] + self._load() + + def _log_path(self) -> Path: + return self.store_dir / "activities.json" + + def _load(self): + """Load activities from disk.""" + log_path = self._log_path() + if log_path.exists(): + with open(log_path) as f: + data = json.load(f) + self._activities = [ + Activity.from_dict(a) for a in data.get("activities", []) + ] + + def _save(self): + """Save activities to disk.""" + data = { + "version": "1.0", + "activities": [a.to_dict() for a in self._activities], + } + with open(self._log_path(), "w") as f: + json.dump(data, f, indent=2) + + def add(self, activity: Activity) -> None: + """Add an activity to the log.""" + self._activities.append(activity) + self._save() + + def get(self, activity_id: str) -> Optional[Activity]: + """Get an activity by ID.""" + for a in self._activities: + if a.activity_id == activity_id: + return a + return None + + def list(self) -> List[Activity]: + """List all activities.""" + return list(self._activities) + + def find_by_actor(self, actor_id: str) -> List[Activity]: + """Find activities by 
actor.""" + return [a for a in self._activities if a.actor_id == actor_id] + + def find_by_object_hash(self, cid: str) -> List[Activity]: + """Find activities referencing an object by hash.""" + results = [] + for a in self._activities: + obj_hash = a.object_data.get("contentHash", {}) + if isinstance(obj_hash, dict) and obj_hash.get("value") == cid: + results.append(a) + elif a.object_data.get("contentHash") == cid: + results.append(a) + return results + + def __len__(self) -> int: + return len(self._activities) diff --git a/artdag/activitypub/actor.py b/artdag/activitypub/actor.py new file mode 100644 index 0000000..8e0deed --- /dev/null +++ b/artdag/activitypub/actor.py @@ -0,0 +1,206 @@ +# primitive/activitypub/actor.py +""" +ActivityPub Actor management. + +An Actor is an identity with: +- Username and display name +- RSA key pair for signing +- ActivityPub-compliant JSON-LD representation +""" + +import json +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa, padding + +DOMAIN = "artdag.rose-ash.com" + + +def _generate_keypair() -> tuple[bytes, bytes]: + """Generate RSA key pair for signing.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + private_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + public_pem = private_key.public_key().public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return private_pem, public_pem + + +@dataclass +class Actor: + """ + An ActivityPub Actor (identity). 
+ + Attributes: + username: Unique username (e.g., "giles") + display_name: Human-readable name + public_key: PEM-encoded public key + private_key: PEM-encoded private key (kept secret) + created_at: Timestamp of creation + """ + username: str + display_name: str + public_key: bytes + private_key: bytes + created_at: float = field(default_factory=time.time) + domain: str = DOMAIN + + @property + def id(self) -> str: + """ActivityPub actor ID (URL).""" + return f"https://{self.domain}/users/{self.username}" + + @property + def handle(self) -> str: + """Fediverse handle.""" + return f"@{self.username}@{self.domain}" + + @property + def inbox(self) -> str: + """ActivityPub inbox URL.""" + return f"{self.id}/inbox" + + @property + def outbox(self) -> str: + """ActivityPub outbox URL.""" + return f"{self.id}/outbox" + + @property + def key_id(self) -> str: + """Key ID for HTTP Signatures.""" + return f"{self.id}#main-key" + + def to_activitypub(self) -> Dict[str, Any]: + """Return ActivityPub JSON-LD representation.""" + return { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1", + ], + "type": "Person", + "id": self.id, + "preferredUsername": self.username, + "name": self.display_name, + "inbox": self.inbox, + "outbox": self.outbox, + "publicKey": { + "id": self.key_id, + "owner": self.id, + "publicKeyPem": self.public_key.decode("utf-8"), + }, + } + + def to_dict(self) -> Dict[str, Any]: + """Serialize for storage.""" + return { + "username": self.username, + "display_name": self.display_name, + "public_key": self.public_key.decode("utf-8"), + "private_key": self.private_key.decode("utf-8"), + "created_at": self.created_at, + "domain": self.domain, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Actor": + """Deserialize from storage.""" + return cls( + username=data["username"], + display_name=data["display_name"], + public_key=data["public_key"].encode("utf-8"), + 
private_key=data["private_key"].encode("utf-8"), + created_at=data.get("created_at", time.time()), + domain=data.get("domain", DOMAIN), + ) + + @classmethod + def create(cls, username: str, display_name: str = None) -> "Actor": + """Create a new actor with generated keys.""" + private_pem, public_pem = _generate_keypair() + return cls( + username=username, + display_name=display_name or username, + public_key=public_pem, + private_key=private_pem, + ) + + +class ActorStore: + """ + Persistent storage for actors. + + Structure: + store_dir/ + actors.json # Index of all actors + keys/ + .private.pem + .public.pem + """ + + def __init__(self, store_dir: Path | str): + self.store_dir = Path(store_dir) + self.store_dir.mkdir(parents=True, exist_ok=True) + self._actors: Dict[str, Actor] = {} + self._load() + + def _index_path(self) -> Path: + return self.store_dir / "actors.json" + + def _load(self): + """Load actors from disk.""" + index_path = self._index_path() + if index_path.exists(): + with open(index_path) as f: + data = json.load(f) + self._actors = { + username: Actor.from_dict(actor_data) + for username, actor_data in data.get("actors", {}).items() + } + + def _save(self): + """Save actors to disk.""" + data = { + "version": "1.0", + "domain": DOMAIN, + "actors": { + username: actor.to_dict() + for username, actor in self._actors.items() + }, + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def create(self, username: str, display_name: str = None) -> Actor: + """Create and store a new actor.""" + if username in self._actors: + raise ValueError(f"Actor {username} already exists") + + actor = Actor.create(username, display_name) + self._actors[username] = actor + self._save() + return actor + + def get(self, username: str) -> Optional[Actor]: + """Get an actor by username.""" + return self._actors.get(username) + + def list(self) -> list[Actor]: + """List all actors.""" + return list(self._actors.values()) + + def 
# primitive/activitypub/ownership.py
"""
Ownership integration between ActivityPub and Registry.

Connects actors, activities, and assets to establish provable ownership.
"""

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from .actor import Actor, ActorStore
from .activity import Activity, CreateActivity, ActivityStore
from .signatures import sign_activity, verify_activity_ownership
from ..registry import Registry, Asset


@dataclass
class OwnershipRecord:
    """
    A verified ownership record linking actor to asset.

    Attributes:
        actor_handle: The actor's fediverse handle
        asset_name: Name of the owned asset
        cid: SHA-3 hash of the asset
        activity_id: ID of the Create activity establishing ownership
        verified: Whether the signature has been verified
    """
    actor_handle: str
    asset_name: str
    cid: str
    activity_id: str
    verified: bool = False


class OwnershipManager:
    """
    Manages ownership relationships between actors and assets.

    Integrates:
    - ActorStore: Identity management
    - Registry: Asset storage
    - ActivityStore: Ownership activities
    """

    def __init__(self, base_dir: Path | str):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Each concern gets its own subdirectory under base_dir.
        self.actors = ActorStore(self.base_dir / "actors")
        self.activities = ActivityStore(self.base_dir / "activities")
        self.registry = Registry(self.base_dir / "registry")

    # ---- small shared helpers -------------------------------------------

    @staticmethod
    def _username_from_actor_id(actor_id: str) -> Optional[str]:
        """Extract the username from an actor URL like .../users/{username}."""
        if "/users/" in actor_id:
            return actor_id.split("/users/")[-1]
        return None

    @staticmethod
    def _content_hash_value(object_data: Dict[str, Any]) -> Optional[str]:
        """Unwrap contentHash, which may be {"value": ...} or a bare string."""
        raw = object_data.get("contentHash", {})
        if isinstance(raw, dict):
            return raw.get("value")
        return raw

    # ---- actor management -----------------------------------------------

    def create_actor(self, username: str, display_name: Optional[str] = None) -> Actor:
        """Create a new actor identity."""
        return self.actors.create(username, display_name)

    def get_actor(self, username: str) -> Optional[Actor]:
        """Get an actor by username."""
        return self.actors.get(username)

    # ---- asset registration ---------------------------------------------

    def register_asset(
        self,
        actor: Actor,
        name: str,
        cid: str,
        url: Optional[str] = None,
        local_path: Optional[Path | str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> tuple[Asset, Activity]:
        """
        Register an asset and establish ownership.

        Creates the asset in the registry and a signed Create activity
        proving the actor's ownership.

        Args:
            actor: The actor claiming ownership
            name: Name for the asset
            cid: SHA-3-256 hash of the content
            url: Public URL (canonical location)
            local_path: Optional local path
            tags: Optional tags
            metadata: Optional metadata

        Returns:
            Tuple of (Asset, signed CreateActivity)
        """
        # Add to registry
        asset = self.registry.add(
            name=name,
            cid=cid,
            url=url,
            local_path=local_path,
            tags=tags,
            metadata=metadata,
        )

        # Create the ownership activity for the stored asset.
        activity = CreateActivity.for_asset(
            actor=actor,
            asset_name=name,
            cid=asset.cid,
            asset_type=self._asset_type_to_ap(asset.asset_type),
            metadata=metadata,
        )

        # Sign and persist the activity.
        signed_activity = sign_activity(activity, actor)
        self.activities.add(signed_activity)

        return asset, signed_activity

    def _asset_type_to_ap(self, asset_type: str) -> str:
        """Convert registry asset type to ActivityPub object type."""
        type_map = {
            "image": "Image",
            "video": "Video",
            "audio": "Audio",
            "unknown": "Document",
        }
        return type_map.get(asset_type, "Document")

    # ---- ownership queries ----------------------------------------------

    def get_owner(self, asset_name: str) -> Optional[Actor]:
        """
        Get the owner of an asset.

        Finds the earliest Create activity for the asset and returns
        the actor if the signature is valid; None otherwise.
        """
        asset = self.registry.get(asset_name)
        if not asset:
            return None

        # Find Create activities referencing this asset's hash.
        activities = self.activities.find_by_object_hash(asset.cid)
        create_activities = [a for a in activities if a.activity_type == "Create"]
        if not create_activities:
            return None

        # The earliest Create wins: first claim is the owner.
        earliest = min(create_activities, key=lambda a: a.published)

        username = self._username_from_actor_id(earliest.actor_id)
        if username is not None:
            actor = self.actors.get(username)
            if actor and verify_activity_ownership(earliest, actor):
                return actor
        return None

    def verify_ownership(self, asset_name: str, actor: Actor) -> bool:
        """
        Verify that an actor owns an asset.

        Checks for a valid signed Create activity linking the actor
        to the asset.
        """
        asset = self.registry.get(asset_name)
        if not asset:
            return False

        activities = self.activities.find_by_object_hash(asset.cid)
        for activity in activities:
            if activity.activity_type == "Create" and activity.actor_id == actor.id:
                if verify_activity_ownership(activity, actor):
                    return True

        return False

    def list_owned_assets(self, actor: Actor) -> List[Asset]:
        """List all registry assets claimed by *actor* via Create activities."""
        activities = self.activities.find_by_actor(actor.id)
        owned = []

        for activity in activities:
            if activity.activity_type != "Create":
                continue
            hash_value = self._content_hash_value(activity.object_data)
            if hash_value:
                asset = self.registry.find_by_hash(hash_value)
                if asset:
                    owned.append(asset)

        return owned

    def get_ownership_records(self) -> List[OwnershipRecord]:
        """Build an OwnershipRecord for every stored Create activity."""
        records = []

        for activity in self.activities.list():
            if activity.activity_type != "Create":
                continue

            actor_id = activity.actor_id
            username = self._username_from_actor_id(actor_id)
            if username is None:
                username = "unknown"
            actor = self.actors.get(username)

            hash_value = self._content_hash_value(activity.object_data)

            records.append(OwnershipRecord(
                actor_handle=actor.handle if actor else f"@{username}@unknown",
                asset_name=activity.object_data.get("name", "unknown"),
                cid=hash_value or "unknown",
                activity_id=activity.activity_id,
                verified=verify_activity_ownership(activity, actor) if actor else False,
            ))

        return records
def verify_signature(activity: "Activity", public_key_pem: bytes) -> bool:
    """
    Check an activity's Linked Data Signature against a public key.

    Args:
        activity: The activity carrying a ``signature`` dict
        public_key_pem: PEM-encoded RSA public key

    Returns:
        True if the RsaSignature2017 signature verifies; False when the
        activity is unsigned, a signature field is missing, or the
        cryptographic check fails.
    """
    sig = activity.signature
    if not sig:
        return False

    try:
        key = serialization.load_pem_public_key(public_key_pem)

        # Rebuild the signature options exactly as they were at signing time.
        opts = {
            "@context": "https://w3id.org/security/v1",
            "type": sig["type"],
            "creator": sig["creator"],
            "created": sig["created"],
        }

        # Signed payload = hash(options) + hash(document-without-signature).
        doc = activity.to_activitypub()
        doc.pop("signature", None)
        payload = _hash_sha256(_canonicalize(opts)) + _hash_sha256(_canonicalize(doc))

        key.verify(
            base64.b64decode(sig["signatureValue"]),
            payload,
            padding.PKCS1v15(),
            hashes.SHA256(),
        )
    except (InvalidSignature, KeyError, ValueError):
        return False
    return True


def verify_activity_ownership(activity: "Activity", actor: "Actor") -> bool:
    """
    Check that *activity* was signed by the claimed *actor*.

    Args:
        activity: The activity to verify
        actor: The claimed actor

    Returns:
        False when the activity is unsigned, when the signature's creator
        is not this actor's key, or when the signature fails verification.
    """
    sig = activity.signature
    if not sig:
        return False

    # The signature must name this actor's key as its creator.
    if sig.get("creator") != actor.key_id:
        return False

    return verify_signature(activity, actor.public_key)
logger = logging.getLogger(__name__)


class AnalysisCache:
    """
    File-backed cache of analysis results.

    Each result is stored as one JSON file named ``{cache_id}.json``
    inside ``cache_dir``.
    """

    def __init__(self, cache_dir: Path):
        # Create the directory eagerly so every later file op can assume it.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _path_for(self, cache_id: str) -> Path:
        """Location of the JSON file backing *cache_id*."""
        return self.cache_dir / f"{cache_id}.json"

    def get(self, cache_id: str) -> Optional["AnalysisResult"]:
        """Load a cached result, or None if absent or unreadable."""
        path = self._path_for(cache_id)
        if not path.exists():
            return None

        try:
            with open(path, "r") as f:
                payload = json.load(f)
            return AnalysisResult.from_dict(payload)
        except (json.JSONDecodeError, KeyError) as e:
            # A corrupt entry behaves like a cache miss.
            logger.warning(f"Failed to load analysis cache {cache_id}: {e}")
            return None

    def put(self, result: "AnalysisResult") -> None:
        """Store *result* under its own cache_id."""
        with open(self._path_for(result.cache_id), "w") as f:
            json.dump(result.to_dict(), f, indent=2)

    def has(self, cache_id: str) -> bool:
        """True if a result file exists for *cache_id*."""
        return self._path_for(cache_id).exists()

    def remove(self, cache_id: str) -> bool:
        """Delete the cached result; True if something was removed."""
        path = self._path_for(cache_id)
        if not path.exists():
            return False
        path.unlink()
        return True
+ content_cache: artdag Cache for looking up inputs by hash + """ + self.cache = AnalysisCache(cache_dir) if cache_dir else None + self.content_cache = content_cache + + def get_input_path(self, input_hash: str, input_path: Optional[Path] = None) -> Path: + """ + Resolve input to a file path. + + Args: + input_hash: Content hash of the input + input_path: Optional direct path to file + + Returns: + Path to the input file + + Raises: + ValueError: If input cannot be resolved + """ + if input_path and input_path.exists(): + return input_path + + if self.content_cache: + entry = self.content_cache.get(input_hash) + if entry: + return Path(entry.output_path) + + raise ValueError(f"Cannot resolve input {input_hash}: no path provided and not in cache") + + def analyze( + self, + input_hash: str, + features: List[str], + input_path: Optional[Path] = None, + media_type: Optional[str] = None, + ) -> AnalysisResult: + """ + Analyze an input file and extract features. + + Args: + input_hash: Content hash of the input (for cache key) + features: List of features to extract: + Audio: "beats", "tempo", "energy", "spectrum", "onsets" + Video: "metadata", "motion_tempo", "scene_changes" + Meta: "all" (extracts all relevant features) + input_path: Optional direct path to file + media_type: Optional hint ("audio", "video", or None for auto-detect) + + Returns: + AnalysisResult with extracted features + """ + # Compute cache ID + temp_result = AnalysisResult( + input_hash=input_hash, + features_requested=sorted(features), + ) + cache_id = temp_result.cache_id + + # Check cache + if self.cache and self.cache.has(cache_id): + cached = self.cache.get(cache_id) + if cached: + logger.info(f"Analysis cache hit: {cache_id[:16]}...") + return cached + + # Resolve input path + path = self.get_input_path(input_hash, input_path) + logger.info(f"Analyzing {path} for features: {features}") + + # Detect media type if not specified + if media_type is None: + media_type = 
self._detect_media_type(path) + + # Extract features + audio_features = None + video_features = None + + # Normalize features + if "all" in features: + audio_features_list = [AUDIO_ALL] + video_features_list = [VIDEO_ALL] + else: + audio_features_list = [f for f in features if f in ("beats", "tempo", "energy", "spectrum", "onsets")] + video_features_list = [f for f in features if f in ("metadata", "motion_tempo", "scene_changes")] + + if media_type in ("audio", "video") and audio_features_list: + try: + audio_features = analyze_audio(path, features=audio_features_list) + except Exception as e: + logger.warning(f"Audio analysis failed: {e}") + + if media_type == "video" and video_features_list: + try: + video_features = analyze_video(path, features=video_features_list) + except Exception as e: + logger.warning(f"Video analysis failed: {e}") + + result = AnalysisResult( + input_hash=input_hash, + features_requested=sorted(features), + audio=audio_features, + video=video_features, + analyzed_at=datetime.now(timezone.utc).isoformat(), + ) + + # Cache result + if self.cache: + self.cache.put(result) + + return result + + def analyze_multiple( + self, + inputs: Dict[str, Path], + features: List[str], + ) -> Dict[str, AnalysisResult]: + """ + Analyze multiple inputs. + + Args: + inputs: Dict mapping input_hash to file path + features: Features to extract from all inputs + + Returns: + Dict mapping input_hash to AnalysisResult + """ + results = {} + for input_hash, input_path in inputs.items(): + try: + results[input_hash] = self.analyze( + input_hash=input_hash, + features=features, + input_path=input_path, + ) + except Exception as e: + logger.error(f"Analysis failed for {input_hash}: {e}") + raise + + return results + + def _detect_media_type(self, path: Path) -> str: + """ + Detect if file is audio or video. 
+ + Args: + path: Path to media file + + Returns: + "audio" or "video" + """ + import subprocess + import json + + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_streams", + str(path) + ] + + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data = json.loads(result.stdout) + streams = data.get("streams", []) + + has_video = any(s.get("codec_type") == "video" for s in streams) + has_audio = any(s.get("codec_type") == "audio" for s in streams) + + if has_video: + return "video" + elif has_audio: + return "audio" + else: + return "unknown" + + except (subprocess.CalledProcessError, json.JSONDecodeError): + # Fall back to extension-based detection + ext = path.suffix.lower() + if ext in (".mp4", ".mov", ".avi", ".mkv", ".webm"): + return "video" + elif ext in (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac"): + return "audio" + return "unknown" diff --git a/artdag/analysis/audio.py b/artdag/analysis/audio.py new file mode 100644 index 0000000..4ee034b --- /dev/null +++ b/artdag/analysis/audio.py @@ -0,0 +1,336 @@ +# artdag/analysis/audio.py +""" +Audio feature extraction. + +Uses librosa for beat detection, energy analysis, and spectral features. +Falls back to basic ffprobe if librosa is not available. 
+""" + +import json +import logging +import subprocess +from pathlib import Path +from typing import List, Optional, Tuple + +from .schema import AudioFeatures, BeatInfo, EnergyEnvelope, SpectrumBands + +logger = logging.getLogger(__name__) + +# Feature names for requesting specific analysis +FEATURE_BEATS = "beats" +FEATURE_TEMPO = "tempo" +FEATURE_ENERGY = "energy" +FEATURE_SPECTRUM = "spectrum" +FEATURE_ONSETS = "onsets" +FEATURE_ALL = "all" + + +def _get_audio_info_ffprobe(path: Path) -> Tuple[float, int, int]: + """Get basic audio info using ffprobe.""" + cmd = [ + "ffprobe", "-v", "quiet", + "-print_format", "json", + "-show_streams", + "-select_streams", "a:0", + str(path) + ] + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + data = json.loads(result.stdout) + if not data.get("streams"): + raise ValueError("No audio stream found") + + stream = data["streams"][0] + duration = float(stream.get("duration", 0)) + sample_rate = int(stream.get("sample_rate", 44100)) + channels = int(stream.get("channels", 2)) + return duration, sample_rate, channels + except (subprocess.CalledProcessError, json.JSONDecodeError, KeyError) as e: + logger.warning(f"ffprobe failed: {e}") + raise ValueError(f"Could not read audio info: {e}") + + +def _extract_audio_to_wav(path: Path, duration: Optional[float] = None) -> Path: + """Extract audio to temporary WAV file for librosa processing.""" + import tempfile + wav_path = Path(tempfile.mktemp(suffix=".wav")) + + cmd = ["ffmpeg", "-y", "-i", str(path)] + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend([ + "-vn", # No video + "-acodec", "pcm_s16le", + "-ar", "22050", # Resample to 22050 Hz for librosa + "-ac", "1", # Mono + str(wav_path) + ]) + + try: + subprocess.run(cmd, capture_output=True, check=True) + return wav_path + except subprocess.CalledProcessError as e: + logger.error(f"Audio extraction failed: {e.stderr}") + raise ValueError(f"Could not extract audio: {e}") + + +def 
def analyze_beats(path: Path, sample_rate: int = 22050) -> "BeatInfo":
    """
    Detect beats and tempo with librosa.

    Args:
        path: Audio file (ideally the pre-extracted mono WAV)
        sample_rate: Sample rate used for analysis

    Returns:
        BeatInfo with beat times, tempo estimate, and a rough confidence

    Raises:
        ImportError: if librosa is not installed
    """
    try:
        import librosa
    except ImportError:
        raise ImportError("librosa required for beat detection. Install with: pip install librosa")

    samples, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    tempo, beat_frames = librosa.beat.beat_track(y=samples, sr=sr)
    beat_times = librosa.frames_to_time(beat_frames, sr=sr).tolist()

    # Confidence: strength of the onsets under the chosen beats relative
    # to the strongest onset anywhere in the clip.
    onset_env = librosa.onset.onset_strength(y=samples, sr=sr)
    strengths = onset_env[beat_frames] if len(beat_frames) > 0 else []
    if len(strengths) > 0 and onset_env.max() > 0:
        confidence = float(strengths.mean() / onset_env.max())
    else:
        confidence = 0.5

    # Downbeats: without meter analysis, assume 4/4 and take every 4th beat.
    downbeats = None
    if len(beat_times) >= 4:
        downbeats = beat_times[0::4]

    # librosa may return tempo as a scalar or as an array.
    if hasattr(tempo, '__float__'):
        bpm = float(tempo)
    elif len(tempo) > 0:
        bpm = float(tempo[0])
    else:
        bpm = 120.0

    return BeatInfo(
        beat_times=beat_times,
        tempo=bpm,
        confidence=min(1.0, max(0.0, confidence)),
        downbeat_times=downbeats,
        time_signature=4,
    )


def analyze_energy(path: Path, window_ms: float = 50.0, sample_rate: int = 22050) -> "EnergyEnvelope":
    """
    Extract a normalized loudness (RMS) envelope.

    Args:
        path: Path to audio file
        window_ms: Analysis window size in milliseconds
        sample_rate: Sample rate for analysis

    Returns:
        EnergyEnvelope with times and 0-1 normalized values
    """
    try:
        import librosa
        import numpy as np
    except ImportError:
        raise ImportError("librosa and numpy required. Install with: pip install librosa numpy")

    samples, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    # Hop length in samples derived from the requested window.
    hop = int(sr * window_ms / 1000)

    envelope = librosa.feature.rms(y=samples, hop_length=hop)[0]

    # Normalize to 0-1 (leave untouched when silent to avoid divide-by-zero).
    peak = envelope.max()
    if peak > 0:
        envelope = envelope / peak

    stamps = librosa.frames_to_time(np.arange(len(envelope)), sr=sr, hop_length=hop)

    return EnergyEnvelope(
        times=stamps.tolist(),
        values=envelope.tolist(),
        window_ms=window_ms,
    )


def analyze_spectrum(
    path: Path,
    band_ranges: Optional[dict] = None,
    window_ms: float = 50.0,
    sample_rate: int = 22050
) -> "SpectrumBands":
    """
    Extract per-band frequency envelopes (bass / mid / high).

    Args:
        path: Path to audio file
        band_ranges: Dict mapping band name to (low_hz, high_hz)
        window_ms: Analysis window size
        sample_rate: Sample rate

    Returns:
        SpectrumBands with normalized bass, mid, high envelopes
    """
    try:
        import librosa
        import numpy as np
    except ImportError:
        raise ImportError("librosa and numpy required")

    if band_ranges is None:
        band_ranges = {
            "bass": (20, 200),
            "mid": (200, 2000),
            "high": (2000, 20000),
        }

    samples, sr = librosa.load(str(path), sr=sample_rate, mono=True)
    hop = int(sr * window_ms / 1000)

    # Magnitude spectrogram.
    n_fft = 2048
    magnitudes = np.abs(librosa.stft(samples, n_fft=n_fft, hop_length=hop))
    bin_freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    def band_energy(low_hz: float, high_hz: float) -> List[float]:
        """Normalized per-frame energy summed over [low_hz, high_hz]."""
        selected = (bin_freqs >= low_hz) & (bin_freqs <= high_hz)
        if not selected.any():
            return [0.0] * magnitudes.shape[1]
        energy = magnitudes[selected, :].sum(axis=0)
        peak = energy.max()
        if peak > 0:
            energy = energy / peak
        return energy.tolist()

    stamps = librosa.frames_to_time(np.arange(magnitudes.shape[1]), sr=sr, hop_length=hop)

    return SpectrumBands(
        bass=band_energy(*band_ranges["bass"]),
        mid=band_energy(*band_ranges["mid"]),
        high=band_energy(*band_ranges["high"]),
        times=stamps.tolist(),
        band_ranges=band_ranges,
    )


def analyze_onsets(path: Path, sample_rate: int = 22050) -> List[float]:
    """
    Detect onset times (note/sound starts).

    Args:
        path: Path to audio file
        sample_rate: Sample rate

    Returns:
        List of onset times in seconds
    """
    try:
        import librosa
    except ImportError:
        raise ImportError("librosa required")

    samples, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    frames = librosa.onset.onset_detect(y=samples, sr=sr)
    return librosa.frames_to_time(frames, sr=sr).tolist()
audio to WAV for librosa + wav_path = None + try: + wav_path = _extract_audio_to_wav(path) + + if FEATURE_BEATS in features or FEATURE_TEMPO in features: + try: + result.beats = analyze_beats(wav_path) + except Exception as e: + logger.warning(f"Beat detection failed: {e}") + + if FEATURE_ENERGY in features: + try: + result.energy = analyze_energy(wav_path) + except Exception as e: + logger.warning(f"Energy analysis failed: {e}") + + if FEATURE_SPECTRUM in features: + try: + result.spectrum = analyze_spectrum(wav_path) + except Exception as e: + logger.warning(f"Spectrum analysis failed: {e}") + + if FEATURE_ONSETS in features: + try: + result.onsets = analyze_onsets(wav_path) + except Exception as e: + logger.warning(f"Onset detection failed: {e}") + + finally: + # Clean up temporary WAV file + if wav_path and wav_path.exists(): + wav_path.unlink() + + return result diff --git a/artdag/analysis/schema.py b/artdag/analysis/schema.py new file mode 100644 index 0000000..4b9825b --- /dev/null +++ b/artdag/analysis/schema.py @@ -0,0 +1,352 @@ +# artdag/analysis/schema.py +""" +Data structures for analysis results. + +Analysis extracts features from input media that inform downstream processing. +Results are cached by: analysis_cache_id = SHA3-256(input_hash + sorted(features)) +""" + +import hashlib +import json +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """Create stable hash from arbitrary data.""" + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class BeatInfo: + """ + Beat detection results. 
+ + Attributes: + beat_times: List of beat positions in seconds + tempo: Estimated tempo in BPM + confidence: Tempo detection confidence (0-1) + downbeat_times: First beat of each bar (if detected) + time_signature: Detected or assumed time signature (e.g., 4) + """ + beat_times: List[float] + tempo: float + confidence: float = 1.0 + downbeat_times: Optional[List[float]] = None + time_signature: int = 4 + + def to_dict(self) -> Dict[str, Any]: + return { + "beat_times": self.beat_times, + "tempo": self.tempo, + "confidence": self.confidence, + "downbeat_times": self.downbeat_times, + "time_signature": self.time_signature, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "BeatInfo": + return cls( + beat_times=data["beat_times"], + tempo=data["tempo"], + confidence=data.get("confidence", 1.0), + downbeat_times=data.get("downbeat_times"), + time_signature=data.get("time_signature", 4), + ) + + +@dataclass +class EnergyEnvelope: + """ + Energy (loudness) over time. + + Attributes: + times: Time points in seconds + values: Energy values (0-1, normalized) + window_ms: Analysis window size in milliseconds + """ + times: List[float] + values: List[float] + window_ms: float = 50.0 + + def to_dict(self) -> Dict[str, Any]: + return { + "times": self.times, + "values": self.values, + "window_ms": self.window_ms, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "EnergyEnvelope": + return cls( + times=data["times"], + values=data["values"], + window_ms=data.get("window_ms", 50.0), + ) + + def at_time(self, t: float) -> float: + """Interpolate energy value at given time.""" + if not self.times: + return 0.0 + if t <= self.times[0]: + return self.values[0] + if t >= self.times[-1]: + return self.values[-1] + + # Binary search for bracketing indices + lo, hi = 0, len(self.times) - 1 + while hi - lo > 1: + mid = (lo + hi) // 2 + if self.times[mid] <= t: + lo = mid + else: + hi = mid + + # Linear interpolation + t0, t1 = self.times[lo], 
@dataclass
class SpectrumBands:
    """
    Frequency band envelopes over time.

    Attributes:
        bass: Low frequency envelope (20-200 Hz typical)
        mid: Mid frequency envelope (200-2000 Hz typical)
        high: High frequency envelope (2000-20000 Hz typical)
        times: Time points in seconds
        band_ranges: Frequency ranges for each band in Hz
    """
    bass: List[float]
    mid: List[float]
    high: List[float]
    times: List[float]
    band_ranges: Dict[str, Tuple[float, float]] = field(default_factory=lambda: {
        "bass": (20, 200),
        "mid": (200, 2000),
        "high": (2000, 20000),
    })

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (field order preserved)."""
        keys = ("bass", "mid", "high", "times", "band_ranges")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SpectrumBands":
        """Reconstruct from a dict; missing band_ranges falls back to defaults."""
        default_ranges = {
            "bass": (20, 200),
            "mid": (200, 2000),
            "high": (2000, 20000),
        }
        return cls(
            bass=data["bass"],
            mid=data["mid"],
            high=data["high"],
            times=data["times"],
            band_ranges=data.get("band_ranges", default_ranges),
        )
+ + Attributes: + duration: Audio duration in seconds + sample_rate: Sample rate in Hz + channels: Number of audio channels + beats: Beat detection results + energy: Energy envelope + spectrum: Frequency band envelopes + onsets: Note/sound onset times + """ + duration: float + sample_rate: int + channels: int + beats: Optional[BeatInfo] = None + energy: Optional[EnergyEnvelope] = None + spectrum: Optional[SpectrumBands] = None + onsets: Optional[List[float]] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "duration": self.duration, + "sample_rate": self.sample_rate, + "channels": self.channels, + "beats": self.beats.to_dict() if self.beats else None, + "energy": self.energy.to_dict() if self.energy else None, + "spectrum": self.spectrum.to_dict() if self.spectrum else None, + "onsets": self.onsets, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AudioFeatures": + return cls( + duration=data["duration"], + sample_rate=data["sample_rate"], + channels=data["channels"], + beats=BeatInfo.from_dict(data["beats"]) if data.get("beats") else None, + energy=EnergyEnvelope.from_dict(data["energy"]) if data.get("energy") else None, + spectrum=SpectrumBands.from_dict(data["spectrum"]) if data.get("spectrum") else None, + onsets=data.get("onsets"), + ) + + +@dataclass +class VideoFeatures: + """ + Extracted video features. 
@dataclass
class VideoFeatures:
    """
    Extracted video features.

    Attributes:
        duration: Video duration in seconds
        frame_rate: Frames per second
        width: Frame width in pixels
        height: Frame height in pixels
        codec: Video codec name
        motion_tempo: Estimated tempo from motion analysis (optional)
        scene_changes: Times of detected scene changes
    """
    duration: float
    frame_rate: float
    width: int
    height: int
    codec: str = ""
    motion_tempo: Optional[float] = None
    scene_changes: Optional[List[float]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (field order preserved)."""
        keys = ("duration", "frame_rate", "width", "height",
                "codec", "motion_tempo", "scene_changes")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "VideoFeatures":
        """Reconstruct from a dict; optional fields take their defaults."""
        return cls(
            duration=data["duration"],
            frame_rate=data["frame_rate"],
            width=data["width"],
            height=data["height"],
            codec=data.get("codec", ""),
            motion_tempo=data.get("motion_tempo"),
            scene_changes=data.get("scene_changes"),
        )
+ + cache_id = SHA3-256(input_hash + sorted(features_requested)) + """ + content = { + "input_hash": self.input_hash, + "features": sorted(self.features_requested), + } + return _stable_hash(content) + + def to_dict(self) -> Dict[str, Any]: + return { + "input_hash": self.input_hash, + "features_requested": self.features_requested, + "audio": self.audio.to_dict() if self.audio else None, + "video": self.video.to_dict() if self.video else None, + "cache_id": self.cache_id, + "analyzed_at": self.analyzed_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnalysisResult": + return cls( + input_hash=data["input_hash"], + features_requested=data["features_requested"], + audio=AudioFeatures.from_dict(data["audio"]) if data.get("audio") else None, + video=VideoFeatures.from_dict(data["video"]) if data.get("video") else None, + cache_id=data.get("cache_id"), + analyzed_at=data.get("analyzed_at"), + ) + + def to_json(self) -> str: + """Serialize to JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "AnalysisResult": + """Deserialize from JSON string.""" + return cls.from_dict(json.loads(json_str)) + + # Convenience accessors + @property + def tempo(self) -> Optional[float]: + """Get tempo if beats were analyzed.""" + return self.audio.beats.tempo if self.audio and self.audio.beats else None + + @property + def beat_times(self) -> Optional[List[float]]: + """Get beat times if beats were analyzed.""" + return self.audio.beats.beat_times if self.audio and self.audio.beats else None + + @property + def downbeat_times(self) -> Optional[List[float]]: + """Get downbeat times if analyzed.""" + return self.audio.beats.downbeat_times if self.audio and self.audio.beats else None + + @property + def duration(self) -> float: + """Get duration from video or audio.""" + if self.video: + return self.video.duration + if self.audio: + return self.audio.duration + return 0.0 + + @property + def 
dimensions(self) -> Optional[Tuple[int, int]]: + """Get video dimensions if available.""" + if self.video: + return (self.video.width, self.video.height) + return None diff --git a/artdag/analysis/video.py b/artdag/analysis/video.py new file mode 100644 index 0000000..94d4152 --- /dev/null +++ b/artdag/analysis/video.py @@ -0,0 +1,266 @@ +# artdag/analysis/video.py +""" +Video feature extraction. + +Uses ffprobe for basic metadata and optional OpenCV for motion analysis. +""" + +import json +import logging +import subprocess +from fractions import Fraction +from pathlib import Path +from typing import List, Optional + +from .schema import VideoFeatures + +logger = logging.getLogger(__name__) + +# Feature names +FEATURE_METADATA = "metadata" +FEATURE_MOTION_TEMPO = "motion_tempo" +FEATURE_SCENE_CHANGES = "scene_changes" +FEATURE_ALL = "all" + + +def _parse_frame_rate(rate_str: str) -> float: + """Parse frame rate string like '30000/1001' or '30'.""" + try: + if "/" in rate_str: + frac = Fraction(rate_str) + return float(frac) + return float(rate_str) + except (ValueError, ZeroDivisionError): + return 30.0 # Default + + +def analyze_metadata(path: Path) -> VideoFeatures: + """ + Extract video metadata using ffprobe. 
def analyze_metadata(path: Path) -> VideoFeatures:
    """
    Extract video metadata using ffprobe.

    Args:
        path: Path to video file

    Returns:
        VideoFeatures with basic metadata

    Raises:
        ValueError: if ffprobe fails, emits invalid JSON, or no video stream exists.
    """
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_streams",
        "-show_format",
        "-select_streams", "v:0",
        str(path)
    ]

    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        info = json.loads(proc.stdout)
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        raise ValueError(f"Could not read video info: {e}")

    if not info.get("streams"):
        raise ValueError("No video stream found")

    stream = info["streams"][0]
    fmt = info.get("format", {})

    # Duration may live on the container (format) or on the stream itself.
    duration = float(fmt.get("duration", stream.get("duration", 0)))

    return VideoFeatures(
        duration=duration,
        frame_rate=_parse_frame_rate(stream.get("avg_frame_rate", "30")),
        width=int(stream.get("width", 0)),
        height=int(stream.get("height", 0)),
        codec=stream.get("codec_name", ""),
    )
def analyze_scene_changes(path: Path, threshold: float = 0.3) -> List[float]:
    """
    Detect scene changes using ffmpeg scene detection.

    Args:
        path: Path to video file
        threshold: Scene change threshold (0-1, lower = more sensitive)

    Returns:
        List of scene change times in seconds
    """
    cmd = [
        "ffmpeg", "-i", str(path),
        "-vf", f"select='gt(scene,{threshold})',showinfo",
        "-f", "null", "-"
    ]

    try:
        stderr = subprocess.run(cmd, capture_output=True, text=True).stderr
    except subprocess.CalledProcessError as e:
        logger.warning(f"Scene detection failed: {e}")
        return []

    # showinfo logs one line per selected frame; pull the pts_time value out.
    scene_times: List[float] = []
    for line in stderr.split("\n"):
        if "pts_time:" not in line:
            continue
        try:
            for token in line.split():
                if token.startswith("pts_time:"):
                    scene_times.append(float(token.split(":")[1]))
                    break
        except (ValueError, IndexError):
            continue

    return scene_times
def analyze_motion_tempo(path: Path, sample_duration: float = 30.0) -> Optional[float]:
    """
    Estimate tempo from video motion periodicity.

    Analyzes optical flow or frame differences to detect rhythmic motion.
    This is useful for matching video speed to audio tempo.

    Args:
        path: Path to video file
        sample_duration: Duration to analyze (seconds)

    Returns:
        Estimated motion tempo in BPM, or None if not detectable
        (OpenCV missing, unreadable file, too few frames, no significant
        periodicity, or BPM outside the 30-200 sanity range).
    """
    try:
        import cv2
        import numpy as np
    except ImportError:
        # OpenCV is an optional dependency; degrade gracefully.
        logger.warning("OpenCV not available, skipping motion tempo analysis")
        return None

    cap = cv2.VideoCapture(str(path))
    if not cap.isOpened():
        logger.warning(f"Could not open video: {path}")
        return None

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Some containers report 0/NaN fps; assume a common default.
            fps = 30.0

        max_frames = int(sample_duration * fps)
        frame_diffs = []
        prev_gray = None

        frame_count = 0
        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert to grayscale and resize for speed
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = cv2.resize(gray, (160, 90))

            if prev_gray is not None:
                # Calculate frame difference (mean abs pixel delta = motion proxy)
                diff = cv2.absdiff(gray, prev_gray)
                frame_diffs.append(np.mean(diff))

            prev_gray = gray
            frame_count += 1

        if len(frame_diffs) < 60:  # Need at least 2 seconds at 30fps
            return None

        # Convert to numpy array
        motion = np.array(frame_diffs)

        # Normalize (zero-mean, unit variance when possible)
        motion = motion - motion.mean()
        if motion.std() > 0:
            motion = motion / motion.std()

        # Autocorrelation to find periodicity
        n = len(motion)
        acf = np.correlate(motion, motion, mode="full")[n-1:]
        acf = acf / acf[0]  # Normalize so acf[0] == 1

        # Find peaks in autocorrelation (potential beat periods)
        # Look for periods between 0.3s (200 BPM) and 2s (30 BPM)
        min_lag = int(0.3 * fps)
        max_lag = min(int(2.0 * fps), len(acf) - 1)

        if max_lag <= min_lag:
            return None

        # Find the highest peak in the valid range
        search_range = acf[min_lag:max_lag]
        if len(search_range) == 0:
            return None

        peak_idx = np.argmax(search_range) + min_lag
        peak_value = acf[peak_idx]

        # Only report if peak is significant
        # NOTE(review): the 0.1 threshold looks empirically chosen — confirm.
        if peak_value < 0.1:
            return None

        # Convert lag to BPM
        period_seconds = peak_idx / fps
        bpm = 60.0 / period_seconds

        # Sanity check
        if 30 <= bpm <= 200:
            return round(bpm, 1)

        return None

    finally:
        # Always release the capture handle, even on early return.
        cap.release()
def analyze_video(
    path: Path,
    features: Optional[List[str]] = None,
) -> VideoFeatures:
    """
    Extract video features from file.

    Args:
        path: Path to video file
        features: List of features to extract. Options:
            - "metadata": Basic video info (always included)
            - "motion_tempo": Estimated tempo from motion
            - "scene_changes": Scene change detection
            - "all": All features

    Returns:
        VideoFeatures with requested analysis
    """
    requested = [FEATURE_METADATA] if features is None else features
    if FEATURE_ALL in requested:
        requested = [FEATURE_METADATA, FEATURE_MOTION_TEMPO, FEATURE_SCENE_CHANGES]

    # Metadata extraction is unconditional; optional analyses attach to it.
    result = analyze_metadata(path)

    optional_analyses = (
        (FEATURE_MOTION_TEMPO, "motion_tempo", analyze_motion_tempo,
         "Motion tempo analysis failed"),
        (FEATURE_SCENE_CHANGES, "scene_changes", analyze_scene_changes,
         "Scene change detection failed"),
    )
    for feature, attr, analyze, failure_msg in optional_analyses:
        if feature not in requested:
            continue
        try:
            setattr(result, attr, analyze(path))
        except Exception as e:
            # Optional features are best-effort: log and continue.
            logger.warning(f"{failure_msg}: {e}")

    return result
+ """ + import hashlib + hasher = hashlib.new(algorithm) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class CacheEntry: + """Metadata about a cached output.""" + node_id: str + output_path: Path + created_at: float + size_bytes: int + node_type: str + cid: str = "" # Content identifier (IPFS CID or local hash) + execution_time: float = 0.0 + + def to_dict(self) -> Dict: + return { + "node_id": self.node_id, + "output_path": str(self.output_path), + "created_at": self.created_at, + "size_bytes": self.size_bytes, + "node_type": self.node_type, + "cid": self.cid, + "execution_time": self.execution_time, + } + + @classmethod + def from_dict(cls, data: Dict) -> "CacheEntry": + # Support both "cid" and legacy "content_hash" + cid = data.get("cid") or data.get("content_hash", "") + return cls( + node_id=data["node_id"], + output_path=Path(data["output_path"]), + created_at=data["created_at"], + size_bytes=data["size_bytes"], + node_type=data["node_type"], + cid=cid, + execution_time=data.get("execution_time", 0.0), + ) + + +@dataclass +class CacheStats: + """Statistics about cache usage.""" + total_entries: int = 0 + total_size_bytes: int = 0 + hits: int = 0 + misses: int = 0 + hit_rate: float = 0.0 + + def record_hit(self): + self.hits += 1 + self._update_rate() + + def record_miss(self): + self.misses += 1 + self._update_rate() + + def _update_rate(self): + total = self.hits + self.misses + self.hit_rate = self.hits / total if total > 0 else 0.0 + + +class Cache: + """ + Code-addressed file cache. + + The filesystem IS the index - no JSON index files needed. + Each node's hash is its directory name. 
+ + Structure: + cache_dir/ + / + output.ext # Actual output file + metadata.json # Per-node metadata (optional) + """ + + def __init__(self, cache_dir: Path | str): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + self.stats = CacheStats() + + def _node_dir(self, node_id: str) -> Path: + """Get the cache directory for a node.""" + return self.cache_dir / node_id + + def _find_output_file(self, node_dir: Path) -> Optional[Path]: + """Find the output file in a node directory.""" + if not node_dir.exists() or not node_dir.is_dir(): + return None + for f in node_dir.iterdir(): + if f.is_file() and f.name.startswith("output."): + return f + return None + + def get(self, node_id: str) -> Optional[Path]: + """ + Get cached output path for a node. + + Checks filesystem directly - no in-memory index. + Returns the output path if cached, None otherwise. + """ + node_dir = self._node_dir(node_id) + output_file = self._find_output_file(node_dir) + + if output_file: + self.stats.record_hit() + logger.debug(f"Cache hit: {node_id[:16]}...") + return output_file + + self.stats.record_miss() + return None + + def put(self, node_id: str, source_path: Path, node_type: str, + execution_time: float = 0.0, move: bool = False) -> Path: + """ + Store a file in the cache. 
+ + Args: + node_id: The code-addressed node ID (hash) + source_path: Path to the file to cache + node_type: Type of the node (for metadata) + execution_time: How long the node took to execute + move: If True, move the file instead of copying + + Returns: + Path to the cached file + """ + node_dir = self._node_dir(node_id) + node_dir.mkdir(parents=True, exist_ok=True) + + # Preserve extension + ext = source_path.suffix or ".out" + output_path = node_dir / f"output{ext}" + + # Copy or move file (skip if already in place) + source_resolved = Path(source_path).resolve() + output_resolved = output_path.resolve() + if source_resolved != output_resolved: + if move: + shutil.move(source_path, output_path) + else: + shutil.copy2(source_path, output_path) + + # Compute content hash (IPFS CID of the result) + cid = _file_hash(output_path) + + # Store per-node metadata (optional, for stats/debugging) + metadata = { + "node_id": node_id, + "output_path": str(output_path), + "created_at": time.time(), + "size_bytes": output_path.stat().st_size, + "node_type": node_type, + "cid": cid, + "execution_time": execution_time, + } + metadata_path = node_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.debug(f"Cached: {node_id[:16]}... 
({metadata['size_bytes']} bytes)") + return output_path + + def has(self, node_id: str) -> bool: + """Check if a node is cached (without affecting stats).""" + return self._find_output_file(self._node_dir(node_id)) is not None + + def remove(self, node_id: str) -> bool: + """Remove a node from the cache.""" + node_dir = self._node_dir(node_id) + if node_dir.exists(): + shutil.rmtree(node_dir) + return True + return False + + def clear(self): + """Clear all cached entries.""" + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + shutil.rmtree(node_dir) + self.stats = CacheStats() + + def get_stats(self) -> CacheStats: + """Get cache statistics (scans filesystem).""" + stats = CacheStats() + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + output_file = self._find_output_file(node_dir) + if output_file: + stats.total_entries += 1 + stats.total_size_bytes += output_file.stat().st_size + stats.hits = self.stats.hits + stats.misses = self.stats.misses + stats.hit_rate = self.stats.hit_rate + return stats + + def list_entries(self) -> List[CacheEntry]: + """List all cache entries (scans filesystem).""" + entries = [] + for node_dir in self.cache_dir.iterdir(): + if node_dir.is_dir() and not node_dir.name.startswith("_"): + entry = self._load_entry_from_disk(node_dir.name) + if entry: + entries.append(entry) + return entries + + def _load_entry_from_disk(self, node_id: str) -> Optional[CacheEntry]: + """Load entry metadata from disk.""" + node_dir = self._node_dir(node_id) + metadata_path = node_dir / "metadata.json" + output_file = self._find_output_file(node_dir) + + if not output_file: + return None + + if metadata_path.exists(): + try: + with open(metadata_path) as f: + data = json.load(f) + return CacheEntry.from_dict(data) + except (json.JSONDecodeError, KeyError): + pass + + # Fallback: create entry from filesystem + return CacheEntry( + node_id=node_id, + 
output_path=output_file, + created_at=output_file.stat().st_mtime, + size_bytes=output_file.stat().st_size, + node_type="unknown", + cid=_file_hash(output_file), + ) + + def get_entry(self, node_id: str) -> Optional[CacheEntry]: + """Get cache entry metadata (without affecting stats).""" + return self._load_entry_from_disk(node_id) + + def find_by_cid(self, cid: str) -> Optional[CacheEntry]: + """Find a cache entry by its content hash (scans filesystem).""" + for entry in self.list_entries(): + if entry.cid == cid: + return entry + return None + + def prune(self, max_size_bytes: int = None, max_age_seconds: float = None) -> int: + """ + Prune cache based on size or age. + + Args: + max_size_bytes: Remove oldest entries until under this size + max_age_seconds: Remove entries older than this + + Returns: + Number of entries removed + """ + removed = 0 + now = time.time() + entries = self.list_entries() + + # Remove by age first + if max_age_seconds is not None: + for entry in entries: + if now - entry.created_at > max_age_seconds: + self.remove(entry.node_id) + removed += 1 + + # Then by size (remove oldest first) + if max_size_bytes is not None: + stats = self.get_stats() + if stats.total_size_bytes > max_size_bytes: + sorted_entries = sorted(entries, key=lambda e: e.created_at) + total_size = stats.total_size_bytes + for entry in sorted_entries: + if total_size <= max_size_bytes: + break + self.remove(entry.node_id) + total_size -= entry.size_bytes + removed += 1 + + return removed + + def get_output_path(self, node_id: str, extension: str = ".mkv") -> Path: + """Get the output path for a node (creates directory if needed).""" + node_dir = self._node_dir(node_id) + node_dir.mkdir(parents=True, exist_ok=True) + return node_dir / f"output{extension}" + + # Effect storage methods + + def _effects_dir(self) -> Path: + """Get the effects subdirectory.""" + effects_dir = self.cache_dir / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + return effects_dir + + 
def store_effect(self, source: str) -> str: + """ + Store an effect in the cache. + + Args: + source: Effect source code + + Returns: + Content hash (cache ID) of the effect + """ + import hashlib as _hashlib + + # Compute content hash + cid = _hashlib.sha3_256(source.encode("utf-8")).hexdigest() + + # Try to load full metadata if effects module available + try: + from .effects.loader import load_effect + loaded = load_effect(source) + meta_dict = loaded.meta.to_dict() + dependencies = loaded.dependencies + requires_python = loaded.requires_python + except ImportError: + # Fallback: store without parsed metadata + meta_dict = {} + dependencies = [] + requires_python = ">=3.10" + + effect_dir = self._effects_dir() / cid + effect_dir.mkdir(parents=True, exist_ok=True) + + # Store source + source_path = effect_dir / "effect.py" + source_path.write_text(source, encoding="utf-8") + + # Store metadata + metadata = { + "cid": cid, + "meta": meta_dict, + "dependencies": dependencies, + "requires_python": requires_python, + "stored_at": time.time(), + } + metadata_path = effect_dir / "metadata.json" + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logger.info(f"Stored effect '{loaded.meta.name}' with hash {cid[:16]}...") + return cid + + def get_effect(self, cid: str) -> Optional[str]: + """ + Get effect source by content hash. + + Args: + cid: SHA3-256 hash of effect source + + Returns: + Effect source code if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + source_path = effect_dir / "effect.py" + + if not source_path.exists(): + return None + + return source_path.read_text(encoding="utf-8") + + def get_effect_path(self, cid: str) -> Optional[Path]: + """ + Get path to effect source file. 
+ + Args: + cid: SHA3-256 hash of effect source + + Returns: + Path to effect.py if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + source_path = effect_dir / "effect.py" + + if not source_path.exists(): + return None + + return source_path + + def get_effect_metadata(self, cid: str) -> Optional[dict]: + """ + Get effect metadata by content hash. + + Args: + cid: SHA3-256 hash of effect source + + Returns: + Metadata dict if found, None otherwise + """ + effect_dir = self._effects_dir() / cid + metadata_path = effect_dir / "metadata.json" + + if not metadata_path.exists(): + return None + + try: + with open(metadata_path) as f: + return json.load(f) + except (json.JSONDecodeError, KeyError): + return None + + def has_effect(self, cid: str) -> bool: + """Check if an effect is cached.""" + effect_dir = self._effects_dir() / cid + return (effect_dir / "effect.py").exists() + + def list_effects(self) -> List[dict]: + """List all cached effects with their metadata.""" + effects = [] + effects_dir = self._effects_dir() + + if not effects_dir.exists(): + return effects + + for effect_dir in effects_dir.iterdir(): + if effect_dir.is_dir(): + metadata = self.get_effect_metadata(effect_dir.name) + if metadata: + effects.append(metadata) + + return effects + + def remove_effect(self, cid: str) -> bool: + """Remove an effect from the cache.""" + effect_dir = self._effects_dir() / cid + + if not effect_dir.exists(): + return False + + shutil.rmtree(effect_dir) + logger.info(f"Removed effect {cid[:16]}...") + return True diff --git a/artdag/cli.py b/artdag/cli.py new file mode 100644 index 0000000..9aa5c8c --- /dev/null +++ b/artdag/cli.py @@ -0,0 +1,724 @@ +#!/usr/bin/env python3 +""" +Art DAG CLI + +Command-line interface for the 3-phase execution model: + artdag analyze - Extract features from inputs + artdag plan - Generate execution plan + artdag execute - Run the plan + artdag run-recipe - Full pipeline + +Usage: + artdag analyze -i :[@] [--features 
def parse_input(input_str: str) -> Tuple[str, str, Optional[str]]:
    """
    Parse input specification: name:hash[@path]

    Returns (name, hash, path or None)

    Raises:
        ValueError: if the spec lacks a ':' separator.
    """
    # Split off an optional trailing @path (the last '@' wins).
    spec, at_sep, path_part = input_str.rpartition("@")
    if at_sep:
        name_hash: str = spec
        path: Optional[str] = path_part
    else:
        name_hash = input_str
        path = None

    if ":" not in name_hash:
        raise ValueError(f"Invalid input format: {input_str}. Expected name:hash[@path]")

    name, hash_value = name_hash.split(":", 1)
    return name, hash_value, path
def cmd_plan(args):
    """Run planning phase: load recipe, plan, report, and save the plan."""
    from .analysis import AnalysisResult
    from .planning import RecipePlanner, Recipe

    # Load recipe
    recipe = Recipe.from_file(Path(args.recipe))
    print(f"Recipe: {recipe.name} v{recipe.version}")

    # Parse inputs (paths are irrelevant for planning)
    input_hashes, _ = parse_inputs(args.input)

    # Load prior analysis results if provided
    analysis = {}
    if args.analysis:
        with open(args.analysis, "r") as f:
            analysis_data = json.load(f)
        analysis = {
            hash_value: AnalysisResult.from_dict(data)
            for hash_value, data in analysis_data.items()
        }

    # Create planner
    planner = RecipePlanner(use_tree_reduction=not args.no_tree_reduction)

    # Generate plan
    print("Generating execution plan...")
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis,
    )

    # Print summary
    print(f"\nPlan ID: {plan.plan_id[:16]}...")
    print(f"Steps: {len(plan.steps)}")

    steps_by_level = plan.get_steps_by_level()
    top_level = max(steps_by_level.keys()) if steps_by_level else 0
    print(f"Levels: {top_level + 1}")

    for level in sorted(steps_by_level.keys()):
        print(f"  Level {level}: {len(steps_by_level[level])} steps (parallel)")

    # Write output
    output_path = Path(args.output) if args.output else Path("plan.json")
    with open(output_path, "w") as f:
        f.write(plan.to_json())

    print(f"\nPlan saved to: {output_path}")
def cmd_execute(args):
    """Run the execution phase.

    Loads a plan JSON and either reports cache status (--dry-run) or
    executes steps level by level in the local process.  Production
    deployments dispatch steps via Celery; this path exists for testing.
    """
    from .planning import ExecutionPlan
    from .cache import Cache
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Load plan
    with open(args.plan, "r") as f:
        plan = ExecutionPlan.from_json(f.read())

    print(f"Executing plan: {plan.plan_id[:16]}...")
    print(f"Steps: {len(plan.steps)}")

    if args.dry_run:
        print("\n=== DRY RUN ===")

        # Check cache status without executing anything.
        cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache"))
        steps_by_level = plan.get_steps_by_level()

        cached_count = 0
        pending_count = 0

        for level in sorted(steps_by_level.keys()):
            steps = steps_by_level[level]
            print(f"\nLevel {level}:")
            for step in steps:
                if cache.has(step.cache_id):
                    print(f"  [CACHED] {step.step_id}: {step.node_type}")
                    cached_count += 1
                else:
                    print(f"  [PENDING] {step.step_id}: {step.node_type}")
                    pending_count += 1

        print(f"\nSummary: {cached_count} cached, {pending_count} pending")
        return

    # Execute locally (for testing - production uses Celery)
    cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache"))

    # Seed the path table with inputs already present in the cache.
    cache_paths = {}
    for name, hash_value in plan.input_hashes.items():
        if cache.has(hash_value):
            entry = cache.get(hash_value)
            cache_paths[hash_value] = str(entry.output_path)

    steps_by_level = plan.get_steps_by_level()
    executed = 0
    cached = 0

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        print(f"\nLevel {level}: {len(steps)} steps")

        for step in steps:
            if cache.has(step.cache_id):
                # NOTE(review): cache.get() is stringified directly here,
                # while above it is entry.output_path — confirm Cache.get()'s
                # return type; one of the two usages looks inconsistent.
                cached_path = cache.get(step.cache_id)
                cache_paths[step.cache_id] = str(cached_path)
                cache_paths[step.step_id] = str(cached_path)
                print(f"  [CACHED] {step.step_id}")
                cached += 1
                continue

            print(f"  [RUNNING] {step.step_id}: {step.node_type}...")

            # Map the step's type to the enum; custom types stay strings.
            try:
                node_type = NodeType[step.node_type]
            except KeyError:
                node_type = step.node_type

            executor = get_executor(node_type)
            if executor is None:
                print(f"  ERROR: No executor for {step.node_type}")
                continue

            # Resolve inputs: look up by step id first, then by the
            # producing step's cache id.
            input_paths = []
            for input_id in step.input_steps:
                if input_id in cache_paths:
                    input_paths.append(Path(cache_paths[input_id]))
                else:
                    input_step = plan.get_step(input_id)
                    if input_step and input_step.cache_id in cache_paths:
                        input_paths.append(Path(cache_paths[input_step.cache_id]))

            if len(input_paths) != len(step.input_steps):
                print(f"  ERROR: Missing inputs")
                continue

            # Execute and record the result under both ids so downstream
            # steps can resolve it either way.
            output_path = cache.get_output_path(step.cache_id)
            try:
                result_path = executor.execute(step.config, input_paths, output_path)
                cache.put(step.cache_id, result_path, node_type=step.node_type)
                cache_paths[step.cache_id] = str(result_path)
                cache_paths[step.step_id] = str(result_path)
                print(f"  [DONE] -> {result_path}")
                executed += 1
            except Exception as e:
                print(f"  [FAILED] {e}")

    # Final output
    output_step = plan.get_step(plan.output_step)
    output_path = cache_paths.get(output_step.cache_id) if output_step else None

    print(f"\n=== Complete ===")
    print(f"Cached: {cached}")
    print(f"Executed: {executed}")
    if output_path:
        print(f"Output: {output_path}")
def cmd_run_recipe(args):
    """Run complete pipeline: analyze → plan → execute.

    Local-cache variant of the full pipeline; see cmd_run_recipe_ipfs for
    the IPFS-primary variant.
    """
    from .analysis import Analyzer, AnalysisResult
    from .planning import RecipePlanner, Recipe
    from .cache import Cache
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Load recipe
    recipe = Recipe.from_file(Path(args.recipe))
    print(f"Recipe: {recipe.name} v{recipe.version}")

    # Parse inputs
    input_hashes, input_paths = parse_inputs(args.input)

    # Parse features; this command defaults to a small, fast feature set.
    features = args.features.split(",") if args.features else ["beats", "energy"]

    cache_dir = Path(args.cache_dir) if args.cache_dir else Path("./cache")

    # Phase 1: Analyze
    print("\n=== Phase 1: Analysis ===")
    analyzer = Analyzer(cache_dir=cache_dir / "analysis")

    analysis = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            path = Path(path)
        print(f"Analyzing {name}...")

        result = analyzer.analyze(
            input_hash=hash_value,
            features=features,
            input_path=path,
        )
        analysis[hash_value] = result

        if result.audio and result.audio.beats:
            print(f"  Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats")

    # Phase 2: Plan
    print("\n=== Phase 2: Planning ===")

    # Check for cached plan
    plans_dir = cache_dir / "plans"
    plans_dir.mkdir(parents=True, exist_ok=True)

    # Generate the plan first to obtain plan_id (a deterministic hash),
    # which is also the cache key for the serialized plan.
    planner = RecipePlanner(use_tree_reduction=True)
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis,
    )

    plan_cache_path = plans_dir / f"{plan.plan_id}.json"

    if plan_cache_path.exists():
        print(f"Plan cached: {plan.plan_id[:16]}...")
        from .planning import ExecutionPlan
        with open(plan_cache_path, "r") as f:
            plan = ExecutionPlan.from_json(f.read())
    else:
        # Save plan to cache
        with open(plan_cache_path, "w") as f:
            f.write(plan.to_json())
        print(f"Plan saved: {plan.plan_id[:16]}...")

    print(f"Plan: {len(plan.steps)} steps")
    steps_by_level = plan.get_steps_by_level()
    print(f"Levels: {len(steps_by_level)}")

    # Phase 3: Execute
    print("\n=== Phase 3: Execution ===")

    cache = Cache(cache_dir)

    # Build initial cache paths: inputs with local paths are addressable
    # both by content hash and by input name.
    cache_paths = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            cache_paths[hash_value] = path
            cache_paths[name] = path

    executed = 0
    cached = 0

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        print(f"\nLevel {level}: {len(steps)} steps")

        for step in steps:
            if cache.has(step.cache_id):
                # NOTE(review): cache.get() stringified directly here; the
                # input-seeding loop in cmd_execute uses entry.output_path —
                # confirm Cache.get()'s return type.
                cached_path = cache.get(step.cache_id)
                cache_paths[step.cache_id] = str(cached_path)
                cache_paths[step.step_id] = str(cached_path)
                print(f"  [CACHED] {step.step_id}")
                cached += 1
                continue

            # Handle SOURCE specially: it just aliases an already-known
            # input path rather than running an executor.
            if step.node_type == "SOURCE":
                cid = step.config.get("cid")
                if cid in cache_paths:
                    cache_paths[step.cache_id] = cache_paths[cid]
                    cache_paths[step.step_id] = cache_paths[cid]
                print(f"  [SOURCE] {step.step_id}")
                continue

            print(f"  [RUNNING] {step.step_id}: {step.node_type}...")

            try:
                node_type = NodeType[step.node_type]
            except KeyError:
                node_type = step.node_type

            executor = get_executor(node_type)
            if executor is None:
                print(f"  SKIP: No executor for {step.node_type}")
                continue

            # Resolve inputs by step id, falling back to cache id.
            input_paths_list = []
            for input_id in step.input_steps:
                if input_id in cache_paths:
                    input_paths_list.append(Path(cache_paths[input_id]))
                else:
                    input_step = plan.get_step(input_id)
                    if input_step and input_step.cache_id in cache_paths:
                        input_paths_list.append(Path(cache_paths[input_step.cache_id]))

            if len(input_paths_list) != len(step.input_steps):
                print(f"  ERROR: Missing inputs for {step.step_id}")
                continue

            output_path = cache.get_output_path(step.cache_id)
            try:
                result_path = executor.execute(step.config, input_paths_list, output_path)
                cache.put(step.cache_id, result_path, node_type=step.node_type)
                cache_paths[step.cache_id] = str(result_path)
                cache_paths[step.step_id] = str(result_path)
                print(f"  [DONE]")
                executed += 1
            except Exception as e:
                print(f"  [FAILED] {e}")

    # Final output
    output_step = plan.get_step(plan.output_step)
    output_path = cache_paths.get(output_step.cache_id) if output_step else None

    print(f"\n=== Complete ===")
    print(f"Cached: {cached}")
    print(f"Executed: {executed}")
    if output_path:
        print(f"Output: {output_path}")
def cmd_run_recipe_ipfs(args):
    """Run complete pipeline with IPFS-primary mode.

    Everything stored on IPFS:
    - Inputs (media files)
    - Analysis results (JSON)
    - Execution plans (JSON)
    - Step outputs (media files)

    Intermediate files live in a temp dir that is removed on exit.
    """
    import hashlib  # NOTE(review): imported but unused in this function
    import shutil
    import tempfile

    from .analysis import Analyzer, AnalysisResult
    from .planning import RecipePlanner, Recipe, ExecutionPlan
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Check for ipfs_client
    try:
        from art_celery import ipfs_client
    except ImportError:
        # Try relative import for when running from art-celery
        try:
            import ipfs_client
        except ImportError:
            print("Error: ipfs_client not available. Install art-celery or run from art-celery directory.")
            sys.exit(1)

    # Check IPFS availability
    if not ipfs_client.is_available():
        print("Error: IPFS daemon not available. Start IPFS with 'ipfs daemon'")
        sys.exit(1)

    print("=== IPFS-Primary Mode ===")
    print(f"IPFS Node: {ipfs_client.get_node_id()[:16]}...")

    # Load recipe
    recipe_path = Path(args.recipe)
    recipe = Recipe.from_file(recipe_path)
    print(f"\nRecipe: {recipe.name} v{recipe.version}")

    # Parse inputs
    input_hashes, input_paths = parse_inputs(args.input)

    # Parse features
    features = args.features.split(",") if args.features else ["beats", "energy"]

    # Phase 0: Register on IPFS
    print("\n=== Phase 0: Register on IPFS ===")

    # Register recipe
    recipe_bytes = recipe_path.read_bytes()
    recipe_cid = ipfs_client.add_bytes(recipe_bytes)
    print(f"Recipe CID: {recipe_cid}")

    # Register inputs; a failed add is fatal since nothing downstream
    # could fetch the missing content.
    input_cids = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            cid = ipfs_client.add_file(Path(path))
            if cid:
                input_cids[name] = cid
                print(f"Input '{name}': {cid}")
            else:
                print(f"Error: Failed to add input '{name}' to IPFS")
                sys.exit(1)

    # Phase 1: Analyze
    print("\n=== Phase 1: Analysis ===")

    # Create temp dir for analysis
    work_dir = Path(tempfile.mkdtemp(prefix="artdag_ipfs_"))
    analysis_cids = {}
    analysis = {}

    try:
        for name, hash_value in input_hashes.items():
            input_cid = input_cids.get(name)
            if not input_cid:
                continue

            print(f"Analyzing {name}...")

            # Fetch from IPFS to temp
            temp_input = work_dir / f"input_{name}.mkv"
            if not ipfs_client.get_file(input_cid, temp_input):
                print(f"  Error: Failed to fetch from IPFS")
                continue

            # Run analysis (cache_dir=None: no local analysis cache in
            # IPFS-primary mode)
            analyzer = Analyzer(cache_dir=None)
            result = analyzer.analyze(
                input_hash=hash_value,
                features=features,
                input_path=temp_input,
            )

            if result.audio and result.audio.beats:
                print(f"  Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats")

            # Store analysis on IPFS
            analysis_cid = ipfs_client.add_json(result.to_dict())
            if analysis_cid:
                analysis_cids[hash_value] = analysis_cid
                analysis[hash_value] = result
                print(f"  Analysis CID: {analysis_cid}")

        # Phase 2: Plan
        print("\n=== Phase 2: Planning ===")

        planner = RecipePlanner(use_tree_reduction=True)
        plan = planner.plan(
            recipe=recipe,
            input_hashes=input_hashes,
            analysis=analysis if analysis else None,
        )

        # Store plan on IPFS
        import json
        plan_dict = json.loads(plan.to_json())
        plan_cid = ipfs_client.add_json(plan_dict)
        print(f"Plan ID: {plan.plan_id[:16]}...")
        print(f"Plan CID: {plan_cid}")
        print(f"Steps: {len(plan.steps)}")

        steps_by_level = plan.get_steps_by_level()
        print(f"Levels: {len(steps_by_level)}")

        # Phase 3: Execute
        print("\n=== Phase 3: Execution ===")

        # CID results: inputs are addressable by name, steps by step_id.
        cid_results = dict(input_cids)
        step_cids = {}

        executed = 0
        cached = 0  # NOTE(review): never incremented in IPFS mode

        for level in sorted(steps_by_level.keys()):
            steps = steps_by_level[level]
            print(f"\nLevel {level}: {len(steps)} steps")

            for step in steps:
                # Handle SOURCE: alias the already-registered input CID.
                if step.node_type == "SOURCE":
                    source_name = step.config.get("name") or step.step_id
                    cid = cid_results.get(source_name)
                    if cid:
                        step_cids[step.step_id] = cid
                    print(f"  [SOURCE] {step.step_id}")
                    continue

                print(f"  [RUNNING] {step.step_id}: {step.node_type}...")

                try:
                    node_type = NodeType[step.node_type]
                except KeyError:
                    node_type = step.node_type

                executor = get_executor(node_type)
                if executor is None:
                    print(f"  SKIP: No executor for {step.node_type}")
                    continue

                # Fetch inputs from IPFS into the temp work dir.
                input_paths_list = []
                for i, input_step_id in enumerate(step.input_steps):
                    input_cid = step_cids.get(input_step_id) or cid_results.get(input_step_id)
                    if not input_cid:
                        print(f"  ERROR: Missing input CID for {input_step_id}")
                        continue

                    temp_path = work_dir / f"step_{step.step_id}_input_{i}.mkv"
                    if not ipfs_client.get_file(input_cid, temp_path):
                        print(f"  ERROR: Failed to fetch {input_cid}")
                        continue
                    input_paths_list.append(temp_path)

                if len(input_paths_list) != len(step.input_steps):
                    print(f"  ERROR: Missing inputs")
                    continue

                # Execute, then publish the output to IPFS.
                output_path = work_dir / f"step_{step.step_id}_output.mkv"
                try:
                    result_path = executor.execute(step.config, input_paths_list, output_path)

                    # Add to IPFS
                    output_cid = ipfs_client.add_file(result_path)
                    if output_cid:
                        step_cids[step.step_id] = output_cid
                        print(f"  [DONE] CID: {output_cid}")
                        executed += 1
                    else:
                        print(f"  [FAILED] Could not add to IPFS")
                except Exception as e:
                    print(f"  [FAILED] {e}")

        # Final output
        output_step = plan.get_step(plan.output_step)
        output_cid = step_cids.get(output_step.step_id) if output_step else None

        print(f"\n=== Complete ===")
        print(f"Executed: {executed}")
        if output_cid:
            print(f"Output CID: {output_cid}")
            print(f"Fetch with: ipfs get {output_cid}")

        # Summary of all CIDs
        print(f"\n=== All CIDs ===")
        print(f"Recipe: {recipe_cid}")
        print(f"Plan: {plan_cid}")
        for name, cid in input_cids.items():
            print(f"Input '{name}': {cid}")
        for hash_val, cid in analysis_cids.items():
            print(f"Analysis '{hash_val[:16]}...': {cid}")
        if output_cid:
            print(f"Output: {output_cid}")

    finally:
        # Cleanup temp
        shutil.rmtree(work_dir, ignore_errors=True)
def main():
    """CLI entry point: build the argument parser and dispatch a subcommand."""
    parser = argparse.ArgumentParser(
        prog="artdag",
        description="Art DAG - Declarative media composition",
    )
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # --- analyze -----------------------------------------------------
    sub = subparsers.add_parser("analyze", help="Extract features from inputs")
    sub.add_argument("recipe", help="Recipe YAML file")
    sub.add_argument("-i", "--input", action="append", required=True,
                     help="Input: name:hash[@path]")
    sub.add_argument("--features", help="Features to extract (comma-separated)")
    sub.add_argument("-o", "--output", help="Output file (default: analysis.json)")
    sub.add_argument("--cache-dir", help="Analysis cache directory")

    # --- plan --------------------------------------------------------
    sub = subparsers.add_parser("plan", help="Generate execution plan")
    sub.add_argument("recipe", help="Recipe YAML file")
    sub.add_argument("-i", "--input", action="append", required=True,
                     help="Input: name:hash")
    sub.add_argument("--analysis", help="Analysis JSON file")
    sub.add_argument("-o", "--output", help="Output file (default: plan.json)")
    sub.add_argument("--no-tree-reduction", action="store_true",
                     help="Disable tree reduction optimization")

    # --- execute -----------------------------------------------------
    sub = subparsers.add_parser("execute", help="Execute a plan")
    sub.add_argument("plan", help="Plan JSON file")
    sub.add_argument("--dry-run", action="store_true",
                     help="Show what would execute")
    sub.add_argument("--cache-dir", help="Cache directory")

    # --- run-recipe --------------------------------------------------
    sub = subparsers.add_parser("run-recipe", help="Full pipeline: analyze → plan → execute")
    sub.add_argument("recipe", help="Recipe YAML file")
    sub.add_argument("-i", "--input", action="append", required=True,
                     help="Input: name:hash[@path]")
    sub.add_argument("--features", help="Features to extract (comma-separated)")
    sub.add_argument("--cache-dir", help="Cache directory")
    sub.add_argument("--ipfs-primary", action="store_true",
                     help="Use IPFS-primary mode (everything on IPFS, no local cache)")

    args = parser.parse_args()

    if args.command == "analyze":
        cmd_analyze(args)
    elif args.command == "plan":
        cmd_plan(args)
    elif args.command == "execute":
        cmd_execute(args)
    elif args.command == "run-recipe":
        # run-recipe has two backends; --ipfs-primary selects the IPFS one.
        if getattr(args, 'ipfs_primary', False):
            cmd_run_recipe_ipfs(args)
        else:
            cmd_run_recipe(args)
    else:
        # No (or unknown) subcommand: show usage and signal failure.
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    main()
+""" +Client SDK for the primitive execution server. + +Provides a simple API for submitting DAGs and retrieving results. + +Usage: + client = PrimitiveClient("http://localhost:8080") + + # Build a DAG + builder = DAGBuilder() + source = builder.source("/path/to/video.mp4") + segment = builder.segment(source, duration=5.0) + builder.set_output(segment) + dag = builder.build() + + # Execute and wait for result + result = client.execute(dag) + print(f"Output: {result.output_path}") +""" + +import json +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional +from urllib.request import urlopen, Request +from urllib.error import HTTPError, URLError + +from .dag import DAG, DAGBuilder + + +@dataclass +class ExecutionResult: + """Result from server execution.""" + success: bool + output_path: Optional[Path] = None + error: Optional[str] = None + execution_time: float = 0.0 + nodes_executed: int = 0 + nodes_cached: int = 0 + + +@dataclass +class CacheStats: + """Cache statistics from server.""" + total_entries: int = 0 + total_size_bytes: int = 0 + hits: int = 0 + misses: int = 0 + hit_rate: float = 0.0 + + +class PrimitiveClient: + """ + Client for the primitive execution server. 
class PrimitiveClient:
    """
    Client for the primitive execution server.

    Args:
        base_url: Server URL (e.g., "http://localhost:8080")
        timeout: Request timeout in seconds
    """

    def __init__(self, base_url: str = "http://localhost:8080", timeout: float = 300):
        # Normalize so path concatenation never yields a double slash.
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout

    def _request(self, method: str, path: str, data: dict = None) -> dict:
        """Make HTTP request to server.

        Raises:
            RuntimeError: on an HTTP error response (server message if available).
            ConnectionError: when the server is unreachable.
        """
        url = f"{self.base_url}{path}"

        # JSON body only when a payload is given; GET/DELETE send none.
        if data is not None:
            body = json.dumps(data).encode()
            headers = {"Content-Type": "application/json"}
        else:
            body = None
            headers = {}

        req = Request(url, data=body, headers=headers, method=method)

        try:
            with urlopen(req, timeout=self.timeout) as response:
                return json.loads(response.read().decode())
        except HTTPError as e:
            # Prefer the server's structured {"error": ...} message.
            error_body = e.read().decode()
            try:
                error_data = json.loads(error_body)
                raise RuntimeError(error_data.get("error", str(e)))
            except json.JSONDecodeError:
                raise RuntimeError(f"HTTP {e.code}: {error_body}")
        except URLError as e:
            raise ConnectionError(f"Failed to connect to server: {e}")

    def health(self) -> bool:
        """Check if server is healthy (any failure counts as unhealthy)."""
        try:
            result = self._request("GET", "/health")
            return result.get("status") == "ok"
        except Exception:
            return False

    def submit(self, dag: DAG) -> str:
        """
        Submit a DAG for execution.

        Args:
            dag: The DAG to execute

        Returns:
            Job ID for tracking
        """
        result = self._request("POST", "/execute", dag.to_dict())
        return result["job_id"]

    def status(self, job_id: str) -> str:
        """
        Get job status.

        Args:
            job_id: Job ID from submit()

        Returns:
            Status: "pending", "running", "completed", or "failed"
        """
        result = self._request("GET", f"/status/{job_id}")
        return result["status"]

    def result(self, job_id: str) -> Optional[ExecutionResult]:
        """
        Get job result (non-blocking).

        Args:
            job_id: Job ID from submit()

        Returns:
            ExecutionResult if complete, None if still running
        """
        data = self._request("GET", f"/result/{job_id}")

        if not data.get("ready", False):
            return None

        return ExecutionResult(
            success=data.get("success", False),
            output_path=Path(data["output_path"]) if data.get("output_path") else None,
            error=data.get("error"),
            execution_time=data.get("execution_time", 0),
            nodes_executed=data.get("nodes_executed", 0),
            nodes_cached=data.get("nodes_cached", 0),
        )

    def wait(self, job_id: str, poll_interval: float = 0.5) -> ExecutionResult:
        """
        Wait for job completion and return result.

        Polls result() until it is non-None; no timeout of its own.

        Args:
            job_id: Job ID from submit()
            poll_interval: Seconds between status checks

        Returns:
            ExecutionResult
        """
        while True:
            result = self.result(job_id)
            if result is not None:
                return result
            time.sleep(poll_interval)

    def execute(self, dag: DAG, poll_interval: float = 0.5) -> ExecutionResult:
        """
        Submit DAG and wait for result.

        Convenience method combining submit() and wait().

        Args:
            dag: The DAG to execute
            poll_interval: Seconds between status checks

        Returns:
            ExecutionResult
        """
        job_id = self.submit(dag)
        return self.wait(job_id, poll_interval)

    def cache_stats(self) -> CacheStats:
        """Get cache statistics."""
        data = self._request("GET", "/cache/stats")
        return CacheStats(
            total_entries=data.get("total_entries", 0),
            total_size_bytes=data.get("total_size_bytes", 0),
            hits=data.get("hits", 0),
            misses=data.get("misses", 0),
            hit_rate=data.get("hit_rate", 0.0),
        )

    def clear_cache(self) -> None:
        """Clear the server cache."""
        self._request("DELETE", "/cache")


# Re-export DAGBuilder for convenience
__all__ = ["PrimitiveClient", "ExecutionResult", "CacheStats", "DAGBuilder"]


# primitive/dag.py
"""
Core DAG data structures.

Nodes are content-addressed: node_id = hash(type + config + input_ids)
This enables automatic caching and deduplication.
"""

import hashlib
import json
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Any, Dict, List, Optional


class NodeType(Enum):
    """Built-in node types."""
    # Source operations
    SOURCE = auto()  # Load file from path

    # Transform operations
    SEGMENT = auto()  # Extract time range
    RESIZE = auto()  # Scale/crop/pad
    TRANSFORM = auto()  # Visual effects (color, blur, etc.)

    # Compose operations
    SEQUENCE = auto()  # Concatenate in time
    LAYER = auto()  # Stack spatially (overlay)
    MUX = auto()  # Combine video + audio streams
    BLEND = auto()  # Blend two inputs
    AUDIO_MIX = auto()  # Mix multiple audio streams
    SWITCH = auto()  # Time-based input switching

    # Analysis operations
    ANALYZE = auto()  # Extract features (audio, motion, etc.)

    # Generation operations
    GENERATE = auto()  # Create content (text, graphics, etc.)
+ + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """ + Create stable hash from arbitrary data. + + Uses SHA-3 (Keccak) for quantum resistance. + Returns full hash - no truncation. + + Args: + data: Data to hash (will be JSON serialized) + algorithm: Hash algorithm (default: sha3_256) + + Returns: + Full hex digest + """ + # Convert to JSON with sorted keys for stability + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class Node: + """ + A node in the execution DAG. + + Attributes: + node_type: The operation type (NodeType enum or string for custom types) + config: Operation-specific configuration + inputs: List of input node IDs (resolved during execution) + node_id: Content-addressed ID (computed from type + config + inputs) + name: Optional human-readable name for debugging + """ + node_type: NodeType | str + config: Dict[str, Any] = field(default_factory=dict) + inputs: List[str] = field(default_factory=list) + node_id: Optional[str] = None + name: Optional[str] = None + + def __post_init__(self): + """Compute node_id if not provided.""" + if self.node_id is None: + self.node_id = self._compute_id() + + def _compute_id(self) -> str: + """Compute content-addressed ID from node contents.""" + type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type) + content = { + "type": type_str, + "config": self.config, + "inputs": sorted(self.inputs), # Sort for stability + } + return _stable_hash(content) + + def to_dict(self) -> Dict[str, Any]: + """Serialize node to dictionary.""" + type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type) + return { + "node_id": self.node_id, + "node_type": type_str, + "config": self.config, + "inputs": self.inputs, + "name": self.name, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 
"Node": + """Deserialize node from dictionary.""" + type_str = data["node_type"] + try: + node_type = NodeType[type_str] + except KeyError: + node_type = type_str # Custom type as string + + return cls( + node_type=node_type, + config=data.get("config", {}), + inputs=data.get("inputs", []), + node_id=data.get("node_id"), + name=data.get("name"), + ) + + +@dataclass +class DAG: + """ + A directed acyclic graph of nodes. + + Attributes: + nodes: Dictionary mapping node_id -> Node + output_id: The ID of the final output node + metadata: Optional metadata about the DAG (source, version, etc.) + """ + nodes: Dict[str, Node] = field(default_factory=dict) + output_id: Optional[str] = None + metadata: Dict[str, Any] = field(default_factory=dict) + + def add_node(self, node: Node) -> str: + """Add a node to the DAG, returning its ID.""" + if node.node_id in self.nodes: + # Node already exists (deduplication via content addressing) + return node.node_id + self.nodes[node.node_id] = node + return node.node_id + + def set_output(self, node_id: str) -> None: + """Set the output node.""" + if node_id not in self.nodes: + raise ValueError(f"Node {node_id} not in DAG") + self.output_id = node_id + + def get_node(self, node_id: str) -> Node: + """Get a node by ID.""" + if node_id not in self.nodes: + raise KeyError(f"Node {node_id} not found") + return self.nodes[node_id] + + def topological_order(self) -> List[str]: + """Return nodes in topological order (dependencies first).""" + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + visited.add(node_id) + node = self.nodes[node_id] + for input_id in node.inputs: + visit(input_id) + order.append(node_id) + + # Visit all nodes (not just output, in case of disconnected components) + for node_id in self.nodes: + visit(node_id) + + return order + + def validate(self) -> List[str]: + """Validate DAG structure. 
Returns list of errors (empty if valid).""" + errors = [] + + if self.output_id is None: + errors.append("No output node set") + elif self.output_id not in self.nodes: + errors.append(f"Output node {self.output_id} not in DAG") + + # Check all input references are valid + for node_id, node in self.nodes.items(): + for input_id in node.inputs: + if input_id not in self.nodes: + errors.append(f"Node {node_id} references missing input {input_id}") + + # Check for cycles (skip if we already found missing inputs) + if not any("missing" in e for e in errors): + try: + self.topological_order() + except (RecursionError, KeyError): + errors.append("DAG contains cycles or invalid references") + + return errors + + def to_dict(self) -> Dict[str, Any]: + """Serialize DAG to dictionary.""" + return { + "nodes": {nid: node.to_dict() for nid, node in self.nodes.items()}, + "output_id": self.output_id, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "DAG": + """Deserialize DAG from dictionary.""" + dag = cls(metadata=data.get("metadata", {})) + for node_data in data.get("nodes", {}).values(): + dag.add_node(Node.from_dict(node_data)) + dag.output_id = data.get("output_id") + return dag + + def to_json(self) -> str: + """Serialize DAG to JSON string.""" + return json.dumps(self.to_dict(), indent=2) + + @classmethod + def from_json(cls, json_str: str) -> "DAG": + """Deserialize DAG from JSON string.""" + return cls.from_dict(json.loads(json_str)) + + +class DAGBuilder: + """ + Fluent builder for constructing DAGs. 
class DAGBuilder:
    """
    Fluent builder for constructing DAGs.

    Example:
        builder = DAGBuilder()
        source = builder.source("/path/to/video.mp4")
        segment = builder.segment(source, duration=5.0)
        builder.set_output(segment)
        dag = builder.build()
    """

    def __init__(self):
        self.dag = DAG()

    def _add(self, node_type: NodeType | str, config: Dict[str, Any],
             inputs: List[str] = None, name: str = None) -> str:
        """Create a Node and register it, returning its content-addressed ID."""
        return self.dag.add_node(Node(
            node_type=node_type,
            config=config,
            inputs=inputs or [],
            name=name,
        ))

    # --- source operations -------------------------------------------

    def source(self, path: str, name: str = None) -> str:
        """Add a SOURCE node that loads a file from *path*."""
        return self._add(NodeType.SOURCE, {"path": path}, name=name)

    # --- transform operations ----------------------------------------

    def segment(self, input_id: str, duration: float = None,
                offset: float = 0, precise: bool = True, name: str = None) -> str:
        """Add a SEGMENT node extracting a time range from its input."""
        cfg = {"offset": offset, "precise": precise}
        # duration is optional; omit it entirely when not given.
        if duration is not None:
            cfg["duration"] = duration
        return self._add(NodeType.SEGMENT, cfg, [input_id], name=name)

    def resize(self, input_id: str, width: int, height: int,
               mode: str = "fit", name: str = None) -> str:
        """Add a RESIZE node (scale/crop/pad according to *mode*)."""
        cfg = {"width": width, "height": height, "mode": mode}
        return self._add(NodeType.RESIZE, cfg, [input_id], name=name)

    def transform(self, input_id: str, effects: Dict[str, Any],
                  name: str = None) -> str:
        """Add a TRANSFORM node applying visual effects."""
        return self._add(NodeType.TRANSFORM, {"effects": effects}, [input_id], name=name)

    # --- compose operations ------------------------------------------

    def sequence(self, input_ids: List[str], transition: Dict[str, Any] = None,
                 name: str = None) -> str:
        """Add a SEQUENCE node concatenating inputs in time (default cut)."""
        cfg = {"transition": transition or {"type": "cut"}}
        return self._add(NodeType.SEQUENCE, cfg, input_ids, name=name)

    def layer(self, input_ids: List[str], configs: List[Dict] = None,
              name: str = None) -> str:
        """Add a LAYER node stacking inputs spatially (overlay)."""
        # One config dict per input; default to empty configs.
        cfg = {"inputs": configs or [{}] * len(input_ids)}
        return self._add(NodeType.LAYER, cfg, input_ids, name=name)

    def mux(self, video_id: str, audio_id: str, shortest: bool = True,
            name: str = None) -> str:
        """Add a MUX node combining a video stream and an audio stream."""
        cfg = {"video_stream": 0, "audio_stream": 1, "shortest": shortest}
        return self._add(NodeType.MUX, cfg, [video_id, audio_id], name=name)

    def blend(self, input1_id: str, input2_id: str, mode: str = "overlay",
              opacity: float = 0.5, name: str = None) -> str:
        """Add a BLEND node mixing two inputs."""
        cfg = {"mode": mode, "opacity": opacity}
        return self._add(NodeType.BLEND, cfg, [input1_id, input2_id], name=name)

    def audio_mix(self, input_ids: List[str], gains: List[float] = None,
                  normalize: bool = True, name: str = None) -> str:
        """Add an AUDIO_MIX node to mix multiple audio streams."""
        cfg = {"normalize": normalize}
        # Per-input gains are optional.
        if gains is not None:
            cfg["gains"] = gains
        return self._add(NodeType.AUDIO_MIX, cfg, input_ids, name=name)

    # --- output ------------------------------------------------------

    def set_output(self, node_id: str) -> "DAGBuilder":
        """Set the output node; returns self for chaining."""
        self.dag.set_output(node_id)
        return self

    def build(self) -> DAG:
        """Validate and return the constructed DAG.

        Raises:
            ValueError: if validation reports any errors.
        """
        errors = self.dag.validate()
        if errors:
            raise ValueError(f"Invalid DAG: {errors}")
        return self.dag
+""" + +from .meta import EffectMeta, ParamSpec, ExecutionContext +from .loader import load_effect, load_effect_file, LoadedEffect, compute_cid +from .binding import ( + AnalysisData, + ResolvedBinding, + resolve_binding, + resolve_all_bindings, + bindings_to_lookup_table, + has_bindings, + extract_binding_sources, +) +from .sandbox import Sandbox, SandboxConfig, SandboxResult, is_bwrap_available, get_venv_path +from .runner import run_effect, run_effect_from_cache, EffectExecutor + +__all__ = [ + # Meta types + "EffectMeta", + "ParamSpec", + "ExecutionContext", + # Loader + "load_effect", + "load_effect_file", + "LoadedEffect", + "compute_cid", + # Binding + "AnalysisData", + "ResolvedBinding", + "resolve_binding", + "resolve_all_bindings", + "bindings_to_lookup_table", + "has_bindings", + "extract_binding_sources", + # Sandbox + "Sandbox", + "SandboxConfig", + "SandboxResult", + "is_bwrap_available", + "get_venv_path", + # Runner + "run_effect", + "run_effect_from_cache", + "EffectExecutor", +] diff --git a/artdag/effects/binding.py b/artdag/effects/binding.py new file mode 100644 index 0000000..9017185 --- /dev/null +++ b/artdag/effects/binding.py @@ -0,0 +1,311 @@ +""" +Parameter binding resolution. + +Resolves bind expressions to per-frame lookup tables at plan time. +Binding options: + - :range [lo hi] - map 0-1 to output range + - :smooth N - smoothing window in seconds + - :offset N - time offset in seconds + - :on-event V - value on discrete events + - :decay N - exponential decay after event + - :noise N - add deterministic noise (seeded) + - :seed N - explicit RNG seed +""" + +import hashlib +import math +import random +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class AnalysisData: + """ + Analysis data for binding resolution. 
+ + Attributes: + frame_rate: Video frame rate + total_frames: Total number of frames + features: Dict mapping feature name to per-frame values + events: Dict mapping event name to list of frame indices + """ + + frame_rate: float + total_frames: int + features: Dict[str, List[float]] # feature -> [value_per_frame] + events: Dict[str, List[int]] # event -> [frame_indices] + + def get_feature(self, name: str, frame: int) -> float: + """Get feature value at frame, interpolating if needed.""" + if name not in self.features: + return 0.0 + values = self.features[name] + if not values: + return 0.0 + if frame >= len(values): + return values[-1] + return values[frame] + + def get_events_in_range( + self, name: str, start_frame: int, end_frame: int + ) -> List[int]: + """Get event frames in range.""" + if name not in self.events: + return [] + return [f for f in self.events[name] if start_frame <= f < end_frame] + + +@dataclass +class ResolvedBinding: + """ + Resolved binding with per-frame values. + + Attributes: + param_name: Parameter this binding applies to + values: List of values, one per frame + """ + + param_name: str + values: List[float] + + def get(self, frame: int) -> float: + """Get value at frame.""" + if frame >= len(self.values): + return self.values[-1] if self.values else 0.0 + return self.values[frame] + + +def resolve_binding( + binding: Dict[str, Any], + analysis: AnalysisData, + param_name: str, + cache_id: str = None, +) -> ResolvedBinding: + """ + Resolve a binding specification to per-frame values. 
+ + Args: + binding: Binding spec with source, feature, and options + analysis: Analysis data with features and events + param_name: Name of the parameter being bound + cache_id: Cache ID for deterministic seeding + + Returns: + ResolvedBinding with values for each frame + """ + feature = binding.get("feature") + if not feature: + raise ValueError(f"Binding for {param_name} missing feature") + + # Get base values + values = [] + is_event = feature in analysis.events + + if is_event: + # Event-based binding + on_event = binding.get("on_event", 1.0) + decay = binding.get("decay", 0.0) + values = _resolve_event_binding( + analysis.events.get(feature, []), + analysis.total_frames, + analysis.frame_rate, + on_event, + decay, + ) + else: + # Continuous feature binding + feature_values = analysis.features.get(feature, []) + if not feature_values: + # No data, use zeros + values = [0.0] * analysis.total_frames + else: + # Extend to total frames if needed + values = list(feature_values) + while len(values) < analysis.total_frames: + values.append(values[-1] if values else 0.0) + + # Apply offset + offset = binding.get("offset") + if offset: + offset_frames = int(offset * analysis.frame_rate) + values = _apply_offset(values, offset_frames) + + # Apply smoothing + smooth = binding.get("smooth") + if smooth: + window_frames = int(smooth * analysis.frame_rate) + values = _apply_smoothing(values, window_frames) + + # Apply range mapping + range_spec = binding.get("range") + if range_spec: + lo, hi = range_spec + values = _apply_range(values, lo, hi) + + # Apply noise + noise = binding.get("noise") + if noise: + seed = binding.get("seed") + if seed is None and cache_id: + # Derive seed from cache_id for determinism + seed = int(hashlib.sha256(cache_id.encode()).hexdigest()[:8], 16) + values = _apply_noise(values, noise, seed or 0) + + return ResolvedBinding(param_name=param_name, values=values) + + +def _resolve_event_binding( + event_frames: List[int], + total_frames: int, + 
frame_rate: float, + on_event: float, + decay: float, +) -> List[float]: + """ + Resolve event-based binding with optional decay. + + Args: + event_frames: List of frame indices where events occur + total_frames: Total number of frames + frame_rate: Video frame rate + on_event: Value at event + decay: Decay time constant in seconds (0 = instant) + + Returns: + List of values per frame + """ + values = [0.0] * total_frames + + if not event_frames: + return values + + event_set = set(event_frames) + + if decay <= 0: + # No decay - just mark event frames + for f in event_frames: + if 0 <= f < total_frames: + values[f] = on_event + else: + # Apply exponential decay + decay_frames = decay * frame_rate + for f in event_frames: + if f < 0 or f >= total_frames: + continue + # Apply decay from this event forward + for i in range(f, total_frames): + elapsed = i - f + decayed = on_event * math.exp(-elapsed / decay_frames) + if decayed < 0.001: + break + values[i] = max(values[i], decayed) + + return values + + +def _apply_offset(values: List[float], offset_frames: int) -> List[float]: + """Shift values by offset frames (positive = delay).""" + if offset_frames == 0: + return values + + n = len(values) + result = [0.0] * n + + for i in range(n): + src = i - offset_frames + if 0 <= src < n: + result[i] = values[src] + + return result + + +def _apply_smoothing(values: List[float], window_frames: int) -> List[float]: + """Apply moving average smoothing.""" + if window_frames <= 1: + return values + + n = len(values) + result = [] + half = window_frames // 2 + + for i in range(n): + start = max(0, i - half) + end = min(n, i + half + 1) + avg = sum(values[start:end]) / (end - start) + result.append(avg) + + return result + + +def _apply_range(values: List[float], lo: float, hi: float) -> List[float]: + """Map values from 0-1 to lo-hi range.""" + return [lo + v * (hi - lo) for v in values] + + +def _apply_noise(values: List[float], amount: float, seed: int) -> List[float]: + """Add 
deterministic noise to values.""" + rng = random.Random(seed) + return [v + rng.uniform(-amount, amount) for v in values] + + +def resolve_all_bindings( + config: Dict[str, Any], + analysis: AnalysisData, + cache_id: str = None, +) -> Dict[str, ResolvedBinding]: + """ + Resolve all bindings in a config dict. + + Looks for values with _binding: True marker. + + Args: + config: Node config with potential bindings + analysis: Analysis data + cache_id: Cache ID for seeding + + Returns: + Dict mapping param name to resolved binding + """ + resolved = {} + + for key, value in config.items(): + if isinstance(value, dict) and value.get("_binding"): + resolved[key] = resolve_binding(value, analysis, key, cache_id) + + return resolved + + +def bindings_to_lookup_table( + bindings: Dict[str, ResolvedBinding], +) -> Dict[str, List[float]]: + """ + Convert resolved bindings to simple lookup tables. + + Returns dict mapping param name to list of per-frame values. + This format is JSON-serializable for inclusion in execution plans. + """ + return {name: binding.values for name, binding in bindings.items()} + + +def has_bindings(config: Dict[str, Any]) -> bool: + """Check if config contains any bindings.""" + for value in config.values(): + if isinstance(value, dict) and value.get("_binding"): + return True + return False + + +def extract_binding_sources(config: Dict[str, Any]) -> List[str]: + """ + Extract all analysis source references from bindings. + + Returns list of node IDs that provide analysis data. + """ + sources = [] + for value in config.values(): + if isinstance(value, dict) and value.get("_binding"): + source = value.get("source") + if source and source not in sources: + sources.append(source) + return sources diff --git a/artdag/effects/frame_processor.py b/artdag/effects/frame_processor.py new file mode 100644 index 0000000..c2a04d2 --- /dev/null +++ b/artdag/effects/frame_processor.py @@ -0,0 +1,347 @@ +""" +FFmpeg pipe-based frame processing. 
"""
FFmpeg pipe-based frame processing.

Processes video through Python frame-by-frame effects using FFmpeg pipes:
  FFmpeg decode -> Python process_frame -> FFmpeg encode

This avoids writing intermediate frames to disk.
"""

import logging
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class VideoInfo:
    """Basic video stream metadata as reported by ffprobe."""

    width: int
    height: int
    frame_rate: float
    total_frames: int
    duration: float
    pixel_format: str = "rgb24"


def probe_video(path: Path) -> VideoInfo:
    """
    Probe a video file with ffprobe.

    Args:
        path: Path to the video file.

    Returns:
        VideoInfo with dimensions, frame rate, frame count and duration.

    Raises:
        RuntimeError: If ffprobe fails or produces unparseable output.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-select_streams", "v:0",
        "-show_entries", "stream=width,height,r_frame_rate,nb_frames,duration",
        "-of", "csv=p=0",
        str(path),
    ]

    proc = subprocess.run(cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {proc.stderr}")

    fields = proc.stdout.strip().split(",")
    if len(fields) < 4:
        raise RuntimeError(f"Unexpected ffprobe output: {proc.stdout}")

    width, height = int(fields[0]), int(fields[1])

    # r_frame_rate may be a rational ("30/1") or a plain number ("30").
    num, _, den = fields[2].partition("/")
    frame_rate = float(num) / float(den) if den else float(num)

    # nb_frames and duration may be "N/A"; fall back gracefully.
    try:
        total_frames = int(fields[3])
    except (ValueError, IndexError):
        total_frames = 0
    try:
        duration = float(fields[4]) if len(fields) > 4 else 0.0
    except (ValueError, IndexError):
        duration = 0.0

    if total_frames == 0 and duration > 0:
        # Estimate the frame count from the duration when not reported.
        total_frames = int(duration * frame_rate)

    return VideoInfo(
        width=width,
        height=height,
        frame_rate=frame_rate,
        total_frames=total_frames,
        duration=duration,
    )
# Signature of a per-frame effect: (frame, params, state) -> (frame, state).
FrameProcessor = Callable[[np.ndarray, Dict[str, Any], Any], Tuple[np.ndarray, Any]]


def process_video(
    input_path: Path,
    output_path: Path,
    process_frame: FrameProcessor,
    params: Dict[str, Any],
    bindings: Dict[str, List[float]] = None,
    initial_state: Any = None,
    pixel_format: str = "rgb24",
    output_codec: str = "libx264",
    output_options: List[str] = None,
) -> Tuple[Path, Any]:
    """
    Process video through a frame-by-frame effect via FFmpeg pipes.

    Args:
        input_path: Input video path.
        output_path: Output video path.
        process_frame: Function (frame, params, state) -> (frame, state).
        params: Static parameter dict.
        bindings: Per-frame parameter lookup tables.
        initial_state: Initial state for process_frame.
        pixel_format: Pixel format for raw frame data ("rgb24" or "rgba").
        output_codec: Video codec for the output.
        output_options: Additional ffmpeg output options.

    Returns:
        Tuple of (output_path, final_state).

    Raises:
        ValueError: If process_frame changes the frame shape.
        RuntimeError: If the encoder exits with a non-zero status.
    """
    bindings = bindings or {}
    output_options = output_options or []

    info = probe_video(input_path)
    logger.info(f"Processing {info.width}x{info.height} @ {info.frame_rate}fps")

    # Raw frame size in bytes; unknown formats default to 3 bytes/pixel (RGB).
    bytes_per_pixel = 4 if pixel_format == "rgba" else 3
    frame_size = info.width * info.height * bytes_per_pixel

    decode_cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-",
    ]

    encode_cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-s", f"{info.width}x{info.height}",
        "-r", str(info.frame_rate),
        "-i", "-",
        "-i", str(input_path),  # For audio
        "-map", "0:v",
        "-map", "1:a?",
        "-c:v", output_codec,
        "-c:a", "aac",
        *output_options,
        str(output_path),
    ]

    logger.debug(f"Decoder: {' '.join(decode_cmd)}")
    logger.debug(f"Encoder: {' '.join(encode_cmd)}")

    decoder = subprocess.Popen(
        decode_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    encoder = subprocess.Popen(
        encode_cmd,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    state = initial_state
    frame_idx = 0
    encoder_stderr = b""

    try:
        while True:
            raw = decoder.stdout.read(frame_size)
            if len(raw) < frame_size:
                break  # end of stream (or partial trailing frame)

            frame = np.frombuffer(raw, dtype=np.uint8).reshape(
                (info.height, info.width, bytes_per_pixel)
            )

            # Overlay bound per-frame values onto the static params.
            frame_params = dict(params)
            for param_name, series in bindings.items():
                if frame_idx < len(series):
                    frame_params[param_name] = series[frame_idx]

            processed, state = process_frame(frame, frame_params, state)

            if processed.shape != frame.shape:
                raise ValueError(
                    f"Frame shape mismatch: {processed.shape} vs {frame.shape}"
                )
            encoder.stdin.write(processed.astype(np.uint8).tobytes())
            frame_idx += 1

            if frame_idx % 100 == 0:
                logger.debug(f"Processed frame {frame_idx}")

    except Exception as e:
        logger.error(f"Frame processing failed at frame {frame_idx}: {e}")
        raise
    finally:
        decoder.stdout.close()
        decoder.wait()
        # BUGFIX: the encoder was created with stderr=PIPE but stderr was
        # only read after wait(), which can deadlock once the OS pipe buffer
        # fills. communicate() closes stdin and drains stderr while waiting.
        _, encoder_stderr = encoder.communicate()

    if encoder.returncode != 0:
        stderr_text = encoder_stderr.decode(errors="replace") if encoder_stderr else ""
        raise RuntimeError(f"Encoder failed: {stderr_text}")

    logger.info(f"Processed {frame_idx} frames")
    return output_path, state
def process_video_batch(
    input_path: Path,
    output_path: Path,
    process_frames: Callable[[List[np.ndarray], Dict[str, Any]], List[np.ndarray]],
    params: Dict[str, Any],
    batch_size: int = 30,
    pixel_format: str = "rgb24",
    output_codec: str = "libx264",
) -> Path:
    """
    Process video in batches for effects that need temporal context.

    Args:
        input_path: Input video path.
        output_path: Output video path.
        process_frames: Function (frames_batch, params) -> processed_batch.
        params: Parameter dict.
        batch_size: Number of frames per batch.
        pixel_format: Pixel format ("rgb24" or "rgba").
        output_codec: Output codec.

    Returns:
        The output path.

    Raises:
        RuntimeError: If the encoder exits with a non-zero status.
    """
    info = probe_video(input_path)

    # Unknown formats default to 3 bytes/pixel (RGB).
    bytes_per_pixel = 4 if pixel_format == "rgba" else 3
    frame_size = info.width * info.height * bytes_per_pixel

    decode_cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-",
    ]

    encode_cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-s", f"{info.width}x{info.height}",
        "-r", str(info.frame_rate),
        "-i", "-",
        "-i", str(input_path),
        "-map", "0:v",
        "-map", "1:a?",
        "-c:v", output_codec,
        "-c:a", "aac",
        str(output_path),
    ]

    decoder = subprocess.Popen(
        decode_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    encoder = subprocess.Popen(
        encode_cmd,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    batch: List[np.ndarray] = []
    total_processed = 0
    encoder_stderr = b""

    def _flush() -> None:
        # Run the effect on the buffered frames and stream them to the encoder.
        nonlocal total_processed
        for out_frame in process_frames(batch, params):
            encoder.stdin.write(out_frame.astype(np.uint8).tobytes())
            total_processed += 1

    try:
        while True:
            raw = decoder.stdout.read(frame_size)
            if len(raw) < frame_size:
                # End of stream: flush any remaining partial batch.
                if batch:
                    _flush()
                break

            frame = np.frombuffer(raw, dtype=np.uint8).reshape(
                (info.height, info.width, bytes_per_pixel)
            )
            batch.append(frame)

            if len(batch) >= batch_size:
                _flush()
                batch = []

    finally:
        decoder.stdout.close()
        decoder.wait()
        # BUGFIX: stderr=PIPE was never drained before wait(), which can
        # deadlock when the encoder fills its stderr pipe buffer.
        # communicate() closes stdin and drains stderr while waiting.
        _, encoder_stderr = encoder.communicate()

    if encoder.returncode != 0:
        stderr_text = encoder_stderr.decode(errors="replace") if encoder_stderr else ""
        raise RuntimeError(f"Encoder failed: {stderr_text}")

    logger.info(f"Processed {total_processed} frames in batches of {batch_size}")
    return output_path
"""
Effect file loader.

Parses effect files with:
- PEP 723 inline script metadata for dependencies
- @-tag docstrings for effect metadata
- META object for programmatic access
"""

import ast
import hashlib
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from .meta import EffectMeta, ParamSpec


@dataclass
class LoadedEffect:
    """
    A fully parsed effect.

    Attributes:
        source: Original source code.
        cid: SHA3-256 hash of the source.
        meta: Extracted EffectMeta.
        dependencies: Pip dependency strings from the PEP 723 block.
        requires_python: Python version requirement.
        module: Compiled module, when one has been loaded.
    """

    source: str
    cid: str
    meta: EffectMeta
    dependencies: List[str] = field(default_factory=list)
    requires_python: str = ">=3.10"
    module: Any = None

    def has_frame_api(self) -> bool:
        """True when the effect exposes the frame-by-frame API."""
        return self.meta.api_type == "frame"

    def has_video_api(self) -> bool:
        """True when the effect exposes the whole-video API."""
        return self.meta.api_type == "video"


def compute_cid(source: str) -> str:
    """Return the SHA3-256 content id of an effect's source text."""
    return hashlib.sha3_256(source.encode("utf-8")).hexdigest()


def parse_pep723_metadata(source: str) -> Tuple[List[str], str]:
    """
    Parse PEP 723 inline script metadata.

    Looks for:
        # /// script
        # requires-python = ">=3.10"
        # dependencies = ["numpy", "opencv-python"]
        # ///

    Returns:
        Tuple of (dependencies list, requires-python string); defaults are
        ([], ">=3.10") when no script block is present.
    """
    deps: List[str] = []
    py_requirement = ">=3.10"

    block_match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL)
    if not block_match:
        return deps, py_requirement

    block = block_match.group(1)

    deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL)
    if deps_match:
        # Dependencies are the double-quoted strings inside the list.
        deps = re.findall(r'"([^"]+)"', deps_match.group(1))

    py_match = re.search(r'# requires-python = "([^"]+)"', block)
    if py_match:
        py_requirement = py_match.group(1)

    return deps, py_requirement
def parse_docstring_metadata(docstring: str) -> Dict[str, Any]:
    """
    Parse @-tag metadata from an effect docstring.

    Supports:
        @effect name
        @version 1.0.0
        @author @user@domain
        @temporal false
        @description
          Multi-line description text.

        @param name type
          @range lo hi
          @default value
          Description text.

        @example
          (fx effect :param value)

    Returns:
        Dictionary with keys name, version, author, temporal, description,
        params (list of dicts) and examples; empty dict for empty input.
    """
    if not docstring:
        return {}

    result: Dict[str, Any] = {
        "name": "",
        "version": "1.0.0",
        "author": "",
        "temporal": False,
        "description": "",
        "params": [],
        "examples": [],
    }

    lines = docstring.strip().split("\n")
    i = 0

    while i < len(lines):
        line = lines[i].strip()

        if line.startswith("@effect "):
            result["name"] = line[8:].strip()

        elif line.startswith("@version "):
            result["version"] = line[9:].strip()

        elif line.startswith("@author "):
            result["author"] = line[8:].strip()

        elif line.startswith("@temporal "):
            result["temporal"] = line[10:].strip().lower() in ("true", "yes", "1")

        elif line.startswith("@description"):
            # Collect a multi-line description until the next tag.
            desc_lines = []
            i += 1
            while i < len(lines):
                if lines[i].strip().startswith("@"):
                    i -= 1  # back up so the outer loop processes this tag
                    break
                desc_lines.append(lines[i])
                i += 1
            result["description"] = "\n".join(desc_lines).strip()

        elif line.startswith("@param "):
            # Parse: @param name type, then sub-tags and description lines.
            parts = line[7:].split()
            if len(parts) >= 2:
                current_param: Dict[str, Any] = {
                    "name": parts[0],
                    "type": parts[1],
                    "range": None,
                    "default": None,
                    "description": "",
                }
                desc_lines = []
                i += 1
                while i < len(lines):
                    stripped = lines[i].strip()

                    if stripped.startswith("@range "):
                        range_parts = stripped[7:].split()
                        if len(range_parts) >= 2:
                            try:
                                current_param["range"] = (
                                    float(range_parts[0]),
                                    float(range_parts[1]),
                                )
                            except ValueError:
                                pass  # malformed range: ignore, keep parsing

                    elif stripped.startswith("@default "):
                        current_param["default"] = stripped[9:].strip()

                    elif stripped.startswith("@"):
                        i -= 1  # any other tag ends this param block
                        break

                    elif stripped:
                        desc_lines.append(stripped)

                    i += 1

                current_param["description"] = " ".join(desc_lines)
                result["params"].append(current_param)

        elif line.startswith("@example"):
            # Collect example lines until the next tag.
            # BUGFIX: the previous version broke out of this loop WITHOUT
            # backing up `i` for tags other than @example, so a tag
            # immediately following an example block (e.g. @author) was
            # silently swallowed. Now every terminating tag is re-processed
            # by the outer loop. (The old duplicated @example checks were
            # also unreachable dead code.)
            example_lines = []
            i += 1
            while i < len(lines):
                if lines[i].strip().startswith("@"):
                    i -= 1
                    break
                example_lines.append(lines[i])
                i += 1
            example = "\n".join(example_lines).strip()
            if example:
                result["examples"].append(example)

        i += 1

    return result


def extract_meta_from_ast(source: str) -> Optional[Dict[str, Any]]:
    """
    Extract the module-level META object from source.

    Looks for an assignment of the form ``META = EffectMeta(...)`` and
    returns its keyword arguments as plain Python values.

    Returns:
        The keyword-argument dict, or None when no META call is found or
        the source does not parse.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return None

    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == "META":
                    if isinstance(node.value, ast.Call):
                        return _extract_call_kwargs(node.value)
    return None


def _extract_call_kwargs(call: ast.Call) -> Dict[str, Any]:
    """Convert an AST Call's keyword arguments to plain Python values,
    skipping **kwargs splats and unconvertible values."""
    result: Dict[str, Any] = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            continue  # a **splat has no usable static name
        value = _ast_to_value(keyword.value)
        if value is not None:
            result[keyword.arg] = value
    return result
tuple(_ast_to_value(elt) for elt in node.elts) + elif isinstance(node, ast.Dict): + return { + _ast_to_value(k): _ast_to_value(v) + for k, v in zip(node.keys, node.values) + if k is not None + } + elif isinstance(node, ast.Call): + # Handle ParamSpec(...) calls + if isinstance(node.func, ast.Name) and node.func.id == "ParamSpec": + return _extract_call_kwargs(node) + return None + + +def get_module_docstring(source: str) -> str: + """Extract the module-level docstring from source.""" + try: + tree = ast.parse(source) + except SyntaxError: + return "" + + if tree.body and isinstance(tree.body[0], ast.Expr): + if isinstance(tree.body[0].value, ast.Constant): + return tree.body[0].value.value + elif isinstance(tree.body[0].value, ast.Str): # Python 3.7 compat + return tree.body[0].value.s + return "" + + +def load_effect(source: str) -> LoadedEffect: + """ + Load an effect from source code. + + Parses: + 1. PEP 723 metadata for dependencies + 2. Module docstring for @-tag metadata + 3. META object for programmatic metadata + + Priority: META object > docstring > defaults + + Args: + source: Effect source code + + Returns: + LoadedEffect with all metadata + + Raises: + ValueError: If effect is invalid + """ + cid = compute_cid(source) + + # Parse PEP 723 metadata + dependencies, requires_python = parse_pep723_metadata(source) + + # Parse docstring metadata + docstring = get_module_docstring(source) + doc_meta = parse_docstring_metadata(docstring) + + # Try to extract META from AST + ast_meta = extract_meta_from_ast(source) + + # Build EffectMeta, preferring META object over docstring + name = "" + if ast_meta and "name" in ast_meta: + name = ast_meta["name"] + elif doc_meta.get("name"): + name = doc_meta["name"] + + if not name: + raise ValueError("Effect must have a name (@effect or META.name)") + + version = ast_meta.get("version") if ast_meta else doc_meta.get("version", "1.0.0") + temporal = ast_meta.get("temporal") if ast_meta else doc_meta.get("temporal", False) 
+ author = ast_meta.get("author") if ast_meta else doc_meta.get("author", "") + description = ast_meta.get("description") if ast_meta else doc_meta.get("description", "") + examples = ast_meta.get("examples") if ast_meta else doc_meta.get("examples", []) + + # Build params + params = [] + if ast_meta and "params" in ast_meta: + for p in ast_meta["params"]: + if isinstance(p, dict): + type_map = {"float": float, "int": int, "bool": bool, "str": str} + param_type = type_map.get(p.get("param_type", "float"), float) + if isinstance(p.get("param_type"), type): + param_type = p["param_type"] + params.append( + ParamSpec( + name=p.get("name", ""), + param_type=param_type, + default=p.get("default"), + range=p.get("range"), + description=p.get("description", ""), + ) + ) + elif doc_meta.get("params"): + for p in doc_meta["params"]: + type_map = {"float": float, "int": int, "bool": bool, "str": str} + param_type = type_map.get(p.get("type", "float"), float) + + default = p.get("default") + if default is not None: + try: + default = param_type(default) + except (ValueError, TypeError): + pass + + params.append( + ParamSpec( + name=p["name"], + param_type=param_type, + default=default, + range=p.get("range"), + description=p.get("description", ""), + ) + ) + + # Determine API type by checking for function definitions + api_type = "frame" # default + try: + tree = ast.parse(source) + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + if node.name == "process": + api_type = "video" + break + elif node.name == "process_frame": + api_type = "frame" + break + except SyntaxError: + pass + + meta = EffectMeta( + name=name, + version=version if isinstance(version, str) else "1.0.0", + temporal=bool(temporal), + params=params, + author=author if isinstance(author, str) else "", + description=description if isinstance(description, str) else "", + examples=examples if isinstance(examples, list) else [], + dependencies=dependencies, + requires_python=requires_python, + 
api_type=api_type, + ) + + return LoadedEffect( + source=source, + cid=cid, + meta=meta, + dependencies=dependencies, + requires_python=requires_python, + ) + + +def load_effect_file(path: Path) -> LoadedEffect: + """Load an effect from a file path.""" + source = path.read_text(encoding="utf-8") + return load_effect(source) + + +def compute_deps_hash(dependencies: List[str]) -> str: + """ + Compute hash of sorted dependencies. + + Used for venv caching - same deps = same hash = reuse venv. + """ + sorted_deps = sorted(dep.lower().strip() for dep in dependencies) + deps_str = "\n".join(sorted_deps) + return hashlib.sha3_256(deps_str.encode("utf-8")).hexdigest() diff --git a/artdag/effects/meta.py b/artdag/effects/meta.py new file mode 100644 index 0000000..810623a --- /dev/null +++ b/artdag/effects/meta.py @@ -0,0 +1,247 @@ +""" +Effect metadata types. + +Defines the core dataclasses for effect metadata: +- ParamSpec: Parameter specification with type, range, and default +- EffectMeta: Complete effect metadata including params and flags +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type, Union + + +@dataclass +class ParamSpec: + """ + Specification for an effect parameter. + + Attributes: + name: Parameter name (used in recipes as :name) + param_type: Python type (float, int, bool, str) + default: Default value if not specified + range: Optional (min, max) tuple for numeric types + description: Human-readable description + choices: Optional list of allowed values (for enums) + """ + + name: str + param_type: Type + default: Any = None + range: Optional[Tuple[float, float]] = None + description: str = "" + choices: Optional[List[Any]] = None + + def validate(self, value: Any) -> Any: + """ + Validate and coerce a parameter value. 
@dataclass
class ParamSpec:
    """
    Declarative specification for one effect parameter.

    Attributes:
        name: Parameter name (used in recipes as :name).
        param_type: Python type (float, int, bool, str).
        default: Default value when none is supplied.
        range: Optional (min, max) bounds for numeric types.
        description: Human-readable description.
        choices: Optional list of allowed values (for enums).
    """

    name: str
    param_type: Type
    default: Any = None
    range: Optional[Tuple[float, float]] = None
    description: str = ""
    choices: Optional[List[Any]] = None

    def validate(self, value: Any) -> Any:
        """
        Validate and coerce a parameter value.

        Args:
            value: Input value; None selects the default.

        Returns:
            The coerced value (or the default, returned as-is).

        Raises:
            ValueError: When no value/default exists, coercion fails, or the
                value violates the range/choices constraints.
        """
        if value is None:
            # The default bypasses coercion and constraint checks.
            if self.default is None:
                raise ValueError(f"Parameter '{self.name}' requires a value")
            return self.default

        try:
            if self.param_type == bool:
                # Accept common textual truthy spellings for booleans.
                if isinstance(value, str):
                    coerced = value.lower() in ("true", "1", "yes")
                else:
                    coerced = bool(value)
            else:
                coerced = self.param_type(value)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"Parameter '{self.name}' expects {self.param_type.__name__}, "
                f"got {type(value).__name__}: {e}"
            )

        if self.range is not None and self.param_type in (int, float):
            min_val, max_val = self.range
            if not (min_val <= coerced <= max_val):
                raise ValueError(
                    f"Parameter '{self.name}' must be in range "
                    f"[{min_val}, {max_val}], got {coerced}"
                )

        if self.choices is not None and coerced not in self.choices:
            raise ValueError(
                f"Parameter '{self.name}' must be one of {self.choices}, got {coerced}"
            )

        return coerced

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, omitting unset optional fields."""
        out: Dict[str, Any] = {
            "name": self.name,
            "type": self.param_type.__name__,
            "description": self.description,
        }
        if self.default is not None:
            out["default"] = self.default
        if self.range is not None:
            out["range"] = list(self.range)
        if self.choices is not None:
            out["choices"] = self.choices
        return out
@dataclass
class EffectMeta:
    """
    Complete metadata for an effect.

    Attributes:
        name: Effect name (used in recipes).
        version: Semantic version string.
        temporal: If True, the effect needs the complete input (can't be collapsed).
        params: Parameter specifications.
        author: Optional author identifier.
        description: Human-readable description.
        examples: Example S-expression usages.
        dependencies: Python package dependencies.
        requires_python: Minimum Python version.
        api_type: "frame" for frame-by-frame, "video" for whole-video.
    """

    name: str
    version: str = "1.0.0"
    temporal: bool = False
    # Forward reference: ParamSpec is defined alongside this class.
    params: List["ParamSpec"] = field(default_factory=list)
    author: str = ""
    description: str = ""
    examples: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    requires_python: str = ">=3.10"
    api_type: str = "frame"  # "frame" or "video"

    def get_param(self, name: str) -> Optional["ParamSpec"]:
        """Return the spec named `name`, or None when absent."""
        return next((spec for spec in self.params if spec.name == name), None)

    def validate_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate every declared parameter.

        Args:
            params: Raw parameter values; missing keys use each spec's default.

        Returns:
            Dict of validated/coerced values, including defaults.

        Raises:
            ValueError: If any parameter is invalid.
        """
        return {spec.name: spec.validate(params.get(spec.name))
                for spec in self.params}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (params via ParamSpec.to_dict)."""
        return {
            "name": self.name,
            "version": self.version,
            "temporal": self.temporal,
            "params": [spec.to_dict() for spec in self.params],
            "author": self.author,
            "description": self.description,
            "examples": self.examples,
            "dependencies": self.dependencies,
            "requires_python": self.requires_python,
            "api_type": self.api_type,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EffectMeta":
        """Reconstruct an EffectMeta from its to_dict() representation."""
        # Map serialized type names back to Python types (default: float).
        type_map = {"float": float, "int": int, "bool": bool, "str": str}
        specs = [
            ParamSpec(
                name=entry["name"],
                param_type=type_map.get(entry.get("type", "float"), float),
                default=entry.get("default"),
                range=tuple(entry["range"]) if entry.get("range") else None,
                description=entry.get("description", ""),
                choices=entry.get("choices"),
            )
            for entry in data.get("params", [])
        ]

        return cls(
            name=data["name"],
            version=data.get("version", "1.0.0"),
            temporal=data.get("temporal", False),
            params=specs,
            author=data.get("author", ""),
            description=data.get("description", ""),
            examples=data.get("examples", []),
            dependencies=data.get("dependencies", []),
            requires_python=data.get("requires_python", ">=3.10"),
            api_type=data.get("api_type", "frame"),
        )
+ """ + + input_paths: List[str] + output_path: str + params: Dict[str, Any] + seed: int # Deterministic seed for RNG + frame_rate: float = 30.0 + width: int = 1920 + height: int = 1080 + + # Resolved bindings (frame -> param value lookup) + bindings: Dict[str, List[float]] = field(default_factory=dict) + + def get_param_at_frame(self, param_name: str, frame: int) -> Any: + """ + Get parameter value at a specific frame. + + If parameter has a binding, looks up the bound value. + Otherwise returns the static parameter value. + """ + if param_name in self.bindings: + binding_values = self.bindings[param_name] + if frame < len(binding_values): + return binding_values[frame] + # Past end of binding data, use last value + return binding_values[-1] if binding_values else self.params.get(param_name) + return self.params.get(param_name) + + def get_rng(self) -> "random.Random": + """Get a seeded random number generator.""" + import random + + return random.Random(self.seed) diff --git a/artdag/effects/runner.py b/artdag/effects/runner.py new file mode 100644 index 0000000..2f58c12 --- /dev/null +++ b/artdag/effects/runner.py @@ -0,0 +1,259 @@ +""" +Effect runner. + +Main entry point for executing cached effects with sandboxing. +""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Optional + +from .binding import AnalysisData, bindings_to_lookup_table, resolve_all_bindings +from .loader import load_effect, LoadedEffect +from .meta import ExecutionContext +from .sandbox import Sandbox, SandboxConfig, SandboxResult, get_venv_path + +logger = logging.getLogger(__name__) + + +def run_effect( + effect_source: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + cache_id: str = None, + seed: int = 0, + trust_level: str = "untrusted", + timeout: int = 3600, +) -> SandboxResult: + """ + Run an effect with full sandboxing. + + This is the main entry point for effect execution. 
+ + Args: + effect_source: Effect source code + input_paths: List of input file paths + output_path: Output file path + params: Effect parameters (may contain bindings) + analysis: Optional analysis data for binding resolution + cache_id: Cache ID for deterministic seeding + seed: RNG seed (overrides cache_id-based seed) + trust_level: "untrusted" or "trusted" + timeout: Maximum execution time in seconds + + Returns: + SandboxResult with success status and output + """ + # Load and validate effect + loaded = load_effect(effect_source) + logger.info(f"Running effect '{loaded.meta.name}' v{loaded.meta.version}") + + # Resolve bindings if analysis data available + bindings = {} + if analysis: + resolved = resolve_all_bindings(params, analysis, cache_id) + bindings = bindings_to_lookup_table(resolved) + # Remove binding dicts from params, keeping only resolved values + params = { + k: v for k, v in params.items() + if not (isinstance(v, dict) and v.get("_binding")) + } + + # Validate parameters + validated_params = loaded.meta.validate_params(params) + + # Get or create venv for dependencies + venv_path = None + if loaded.dependencies: + venv_path = get_venv_path(loaded.dependencies) + + # Configure sandbox + config = SandboxConfig( + trust_level=trust_level, + venv_path=venv_path, + timeout=timeout, + ) + + # Write effect to temp file + import tempfile + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=False, + ) as f: + f.write(effect_source) + effect_path = Path(f.name) + + try: + with Sandbox(config) as sandbox: + result = sandbox.run_effect( + effect_path=effect_path, + input_paths=input_paths, + output_path=output_path, + params=validated_params, + bindings=bindings, + seed=seed, + ) + finally: + effect_path.unlink(missing_ok=True) + + return result + + +def run_effect_from_cache( + cache, + effect_hash: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + cache_id: str = 
None, + seed: int = 0, + trust_level: str = "untrusted", + timeout: int = 3600, +) -> SandboxResult: + """ + Run an effect from cache by content hash. + + Args: + cache: Cache instance + effect_hash: Content hash of effect + input_paths: Input file paths + output_path: Output file path + params: Effect parameters + analysis: Optional analysis data + cache_id: Cache ID for seeding + seed: RNG seed + trust_level: "untrusted" or "trusted" + timeout: Max execution time + + Returns: + SandboxResult + """ + effect_source = cache.get_effect(effect_hash) + if not effect_source: + return SandboxResult( + success=False, + error=f"Effect not found in cache: {effect_hash[:16]}...", + ) + + return run_effect( + effect_source=effect_source, + input_paths=input_paths, + output_path=output_path, + params=params, + analysis=analysis, + cache_id=cache_id, + seed=seed, + trust_level=trust_level, + timeout=timeout, + ) + + +def check_effect_temporal(cache, effect_hash: str) -> bool: + """ + Check if an effect is temporal (can't be collapsed). + + Args: + cache: Cache instance + effect_hash: Content hash of effect + + Returns: + True if effect is temporal + """ + metadata = cache.get_effect_metadata(effect_hash) + if not metadata: + return False + + meta = metadata.get("meta", {}) + return meta.get("temporal", False) + + +def get_effect_api_type(cache, effect_hash: str) -> str: + """ + Get the API type of an effect. + + Args: + cache: Cache instance + effect_hash: Content hash of effect + + Returns: + "frame" or "video" + """ + metadata = cache.get_effect_metadata(effect_hash) + if not metadata: + return "frame" + + meta = metadata.get("meta", {}) + return meta.get("api_type", "frame") + + +class EffectExecutor: + """ + Executor for cached effects. + + Provides a higher-level interface for effect execution. + """ + + def __init__(self, cache, trust_level: str = "untrusted"): + """ + Initialize executor. 
+ + Args: + cache: Cache instance + trust_level: Default trust level + """ + self.cache = cache + self.trust_level = trust_level + + def execute( + self, + effect_hash: str, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + analysis: Optional[AnalysisData] = None, + step_cache_id: str = None, + ) -> SandboxResult: + """ + Execute an effect. + + Args: + effect_hash: Content hash of effect + input_paths: Input file paths + output_path: Output path + params: Effect parameters + analysis: Analysis data for bindings + step_cache_id: Step cache ID for seeding + + Returns: + SandboxResult + """ + # Check effect metadata for trust level override + metadata = self.cache.get_effect_metadata(effect_hash) + trust_level = self.trust_level + if metadata: + # L1 owner can mark effect as trusted + if metadata.get("trust_level") == "trusted": + trust_level = "trusted" + + return run_effect_from_cache( + cache=self.cache, + effect_hash=effect_hash, + input_paths=input_paths, + output_path=output_path, + params=params, + analysis=analysis, + cache_id=step_cache_id, + trust_level=trust_level, + ) + + def is_temporal(self, effect_hash: str) -> bool: + """Check if effect is temporal.""" + return check_effect_temporal(self.cache, effect_hash) + + def get_api_type(self, effect_hash: str) -> str: + """Get effect API type.""" + return get_effect_api_type(self.cache, effect_hash) diff --git a/artdag/effects/sandbox.py b/artdag/effects/sandbox.py new file mode 100644 index 0000000..d0d545e --- /dev/null +++ b/artdag/effects/sandbox.py @@ -0,0 +1,431 @@ +""" +Sandbox for effect execution. + +Uses bubblewrap (bwrap) for Linux namespace isolation. 
+Provides controlled access to: + - Input files (read-only) + - Output file (write) + - stderr (logging) + - Seeded RNG +""" + +import hashlib +import json +import logging +import os +import shutil +import subprocess +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class SandboxConfig: + """ + Sandbox configuration. + + Attributes: + trust_level: "untrusted" (full isolation) or "trusted" (allows subprocess) + venv_path: Path to effect's virtual environment + wheel_cache: Shared wheel cache directory + timeout: Maximum execution time in seconds + memory_limit: Memory limit in bytes (0 = unlimited) + allow_network: Whether to allow network access + """ + + trust_level: str = "untrusted" + venv_path: Optional[Path] = None + wheel_cache: Path = field(default_factory=lambda: Path("/var/cache/artdag/wheels")) + timeout: int = 3600 # 1 hour default + memory_limit: int = 0 + allow_network: bool = False + + +def is_bwrap_available() -> bool: + """Check if bubblewrap is available.""" + try: + result = subprocess.run( + ["bwrap", "--version"], + capture_output=True, + text=True, + ) + return result.returncode == 0 + except FileNotFoundError: + return False + + +def get_venv_path(dependencies: List[str], cache_dir: Path = None) -> Path: + """ + Get or create venv for given dependencies. + + Uses hash of sorted dependencies for cache key. 
+ + Args: + dependencies: List of pip package specifiers + cache_dir: Base directory for venv cache + + Returns: + Path to venv directory + """ + cache_dir = cache_dir or Path("/var/cache/artdag/venvs") + cache_dir.mkdir(parents=True, exist_ok=True) + + # Compute deps hash + sorted_deps = sorted(dep.lower().strip() for dep in dependencies) + deps_str = "\n".join(sorted_deps) + deps_hash = hashlib.sha3_256(deps_str.encode()).hexdigest()[:16] + + venv_path = cache_dir / deps_hash + + if venv_path.exists(): + logger.debug(f"Reusing venv at {venv_path}") + return venv_path + + # Create new venv + logger.info(f"Creating venv for {len(dependencies)} deps at {venv_path}") + + subprocess.run( + ["python", "-m", "venv", str(venv_path)], + check=True, + ) + + # Install dependencies + pip_path = venv_path / "bin" / "pip" + wheel_cache = Path("/var/cache/artdag/wheels") + + if dependencies: + cmd = [ + str(pip_path), + "install", + "--cache-dir", str(wheel_cache), + *dependencies, + ] + subprocess.run(cmd, check=True) + + return venv_path + + +@dataclass +class SandboxResult: + """Result of sandboxed execution.""" + + success: bool + output_path: Optional[Path] = None + stderr: str = "" + exit_code: int = 0 + error: Optional[str] = None + + +class Sandbox: + """ + Sandboxed effect execution environment. + + Uses bubblewrap for namespace isolation when available, + falls back to subprocess with restricted permissions. 
+ """ + + def __init__(self, config: SandboxConfig = None): + self.config = config or SandboxConfig() + self._temp_dirs: List[Path] = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.cleanup() + + def cleanup(self): + """Clean up temporary directories.""" + for temp_dir in self._temp_dirs: + if temp_dir.exists(): + shutil.rmtree(temp_dir, ignore_errors=True) + self._temp_dirs = [] + + def _create_temp_dir(self) -> Path: + """Create a temporary directory for sandbox use.""" + temp_dir = Path(tempfile.mkdtemp(prefix="artdag_sandbox_")) + self._temp_dirs.append(temp_dir) + return temp_dir + + def run_effect( + self, + effect_path: Path, + input_paths: List[Path], + output_path: Path, + params: Dict[str, Any], + bindings: Dict[str, List[float]] = None, + seed: int = 0, + ) -> SandboxResult: + """ + Run an effect in the sandbox. + + Args: + effect_path: Path to effect.py + input_paths: List of input file paths + output_path: Output file path + params: Effect parameters + bindings: Per-frame parameter bindings + seed: RNG seed for determinism + + Returns: + SandboxResult with success status and output + """ + bindings = bindings or {} + + # Create work directory + work_dir = self._create_temp_dir() + config_path = work_dir / "config.json" + effect_copy = work_dir / "effect.py" + + # Copy effect to work dir + shutil.copy(effect_path, effect_copy) + + # Write config file + config_data = { + "input_paths": [str(p) for p in input_paths], + "output_path": str(output_path), + "params": params, + "bindings": bindings, + "seed": seed, + } + config_path.write_text(json.dumps(config_data)) + + if is_bwrap_available() and self.config.trust_level == "untrusted": + return self._run_with_bwrap( + effect_copy, config_path, input_paths, output_path, work_dir + ) + else: + return self._run_subprocess( + effect_copy, config_path, input_paths, output_path, work_dir + ) + + def _run_with_bwrap( + self, + effect_path: Path, + config_path: 
Path, + input_paths: List[Path], + output_path: Path, + work_dir: Path, + ) -> SandboxResult: + """Run effect with bubblewrap isolation.""" + logger.info("Running effect in bwrap sandbox") + + # Build bwrap command + cmd = [ + "bwrap", + # New PID namespace + "--unshare-pid", + # No network + "--unshare-net", + # Read-only root filesystem + "--ro-bind", "/", "/", + # Read-write work directory + "--bind", str(work_dir), str(work_dir), + # Read-only input files + ] + + for input_path in input_paths: + cmd.extend(["--ro-bind", str(input_path), str(input_path)]) + + # Bind output directory as writable + output_dir = output_path.parent + output_dir.mkdir(parents=True, exist_ok=True) + cmd.extend(["--bind", str(output_dir), str(output_dir)]) + + # Bind venv if available + if self.config.venv_path and self.config.venv_path.exists(): + cmd.extend(["--ro-bind", str(self.config.venv_path), str(self.config.venv_path)]) + python_path = self.config.venv_path / "bin" / "python" + else: + python_path = Path("/usr/bin/python3") + + # Add runner script + runner_script = self._get_runner_script() + runner_path = work_dir / "runner.py" + runner_path.write_text(runner_script) + + # Run the effect + cmd.extend([ + str(python_path), + str(runner_path), + str(effect_path), + str(config_path), + ]) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=self.config.timeout, + ) + + if result.returncode == 0 and output_path.exists(): + return SandboxResult( + success=True, + output_path=output_path, + stderr=result.stderr, + exit_code=0, + ) + else: + return SandboxResult( + success=False, + stderr=result.stderr, + exit_code=result.returncode, + error=result.stderr or "Effect execution failed", + ) + + except subprocess.TimeoutExpired: + return SandboxResult( + success=False, + error=f"Effect timed out after {self.config.timeout}s", + exit_code=-1, + ) + except Exception as e: + return SandboxResult( + success=False, + error=str(e), + exit_code=-1, + ) + + 
def _run_subprocess( + self, + effect_path: Path, + config_path: Path, + input_paths: List[Path], + output_path: Path, + work_dir: Path, + ) -> SandboxResult: + """Run effect in subprocess (fallback without bwrap).""" + logger.warning("Running effect without sandbox isolation") + + # Create runner script + runner_script = self._get_runner_script() + runner_path = work_dir / "runner.py" + runner_path.write_text(runner_script) + + # Determine Python path + if self.config.venv_path and self.config.venv_path.exists(): + python_path = self.config.venv_path / "bin" / "python" + else: + python_path = "python3" + + cmd = [ + str(python_path), + str(runner_path), + str(effect_path), + str(config_path), + ] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=self.config.timeout, + cwd=str(work_dir), + ) + + if result.returncode == 0 and output_path.exists(): + return SandboxResult( + success=True, + output_path=output_path, + stderr=result.stderr, + exit_code=0, + ) + else: + return SandboxResult( + success=False, + stderr=result.stderr, + exit_code=result.returncode, + error=result.stderr or "Effect execution failed", + ) + + except subprocess.TimeoutExpired: + return SandboxResult( + success=False, + error=f"Effect timed out after {self.config.timeout}s", + exit_code=-1, + ) + except Exception as e: + return SandboxResult( + success=False, + error=str(e), + exit_code=-1, + ) + + def _get_runner_script(self) -> str: + """Get the runner script that executes effects.""" + return '''#!/usr/bin/env python3 +"""Effect runner script - executed in sandbox.""" + +import importlib.util +import json +import sys +from pathlib import Path + +def load_effect(effect_path): + """Load effect module from path.""" + spec = importlib.util.spec_from_file_location("effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +def main(): + if len(sys.argv) < 3: + print("Usage: runner.py ", 
file=sys.stderr) + sys.exit(1) + + effect_path = Path(sys.argv[1]) + config_path = Path(sys.argv[2]) + + # Load config + config = json.loads(config_path.read_text()) + + input_paths = [Path(p) for p in config["input_paths"]] + output_path = Path(config["output_path"]) + params = config["params"] + bindings = config.get("bindings", {}) + seed = config.get("seed", 0) + + # Load effect + effect = load_effect(effect_path) + + # Check API type + if hasattr(effect, "process"): + # Whole-video API + from artdag.effects.meta import ExecutionContext + ctx = ExecutionContext( + input_paths=[str(p) for p in input_paths], + output_path=str(output_path), + params=params, + seed=seed, + bindings=bindings, + ) + effect.process(input_paths, output_path, params, ctx) + + elif hasattr(effect, "process_frame"): + # Frame-by-frame API + from artdag.effects.frame_processor import process_video + + result_path, _ = process_video( + input_path=input_paths[0], + output_path=output_path, + process_frame=effect.process_frame, + params=params, + bindings=bindings, + ) + + else: + print("Effect must have process() or process_frame()", file=sys.stderr) + sys.exit(1) + + print(f"Success: {output_path}", file=sys.stderr) + +if __name__ == "__main__": + main() +''' diff --git a/artdag/engine.py b/artdag/engine.py new file mode 100644 index 0000000..0e70154 --- /dev/null +++ b/artdag/engine.py @@ -0,0 +1,246 @@ +# primitive/engine.py +""" +DAG execution engine. + +Executes DAGs by: +1. Resolving nodes in topological order +2. Checking cache for each node +3. Running executors for cache misses +4. 
Storing results in cache +""" + +import logging +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +from .dag import DAG, Node, NodeType +from .cache import Cache +from .executor import Executor, get_executor + +logger = logging.getLogger(__name__) + + +@dataclass +class ExecutionResult: + """Result of executing a DAG.""" + success: bool + output_path: Optional[Path] = None + error: Optional[str] = None + execution_time: float = 0.0 + nodes_executed: int = 0 + nodes_cached: int = 0 + node_results: Dict[str, Path] = field(default_factory=dict) + + +@dataclass +class NodeProgress: + """Progress update for a node.""" + node_id: str + node_type: str + status: str # "pending", "running", "cached", "completed", "failed" + progress: float = 0.0 # 0.0 to 1.0 + message: str = "" + + +# Progress callback type +ProgressCallback = Callable[[NodeProgress], None] + + +class Engine: + """ + DAG execution engine. + + Manages cache, resolves dependencies, and runs executors. + """ + + def __init__(self, cache_dir: Path | str): + self.cache = Cache(cache_dir) + self._progress_callback: Optional[ProgressCallback] = None + + def set_progress_callback(self, callback: ProgressCallback): + """Set callback for progress updates.""" + self._progress_callback = callback + + def _report_progress(self, progress: NodeProgress): + """Report progress to callback if set.""" + if self._progress_callback: + try: + self._progress_callback(progress) + except Exception as e: + logger.warning(f"Progress callback error: {e}") + + def execute(self, dag: DAG) -> ExecutionResult: + """ + Execute a DAG and return the result. 
+ + Args: + dag: The DAG to execute + + Returns: + ExecutionResult with output path or error + """ + start_time = time.time() + node_results: Dict[str, Path] = {} + nodes_executed = 0 + nodes_cached = 0 + + # Validate DAG + errors = dag.validate() + if errors: + return ExecutionResult( + success=False, + error=f"Invalid DAG: {errors}", + execution_time=time.time() - start_time, + ) + + # Get topological order + try: + order = dag.topological_order() + except Exception as e: + return ExecutionResult( + success=False, + error=f"Failed to order DAG: {e}", + execution_time=time.time() - start_time, + ) + + # Execute each node + for node_id in order: + node = dag.get_node(node_id) + type_str = node.node_type.name if isinstance(node.node_type, NodeType) else str(node.node_type) + + # Report starting + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="pending", + message=f"Processing {type_str}", + )) + + # Check cache first + cached_path = self.cache.get(node_id) + if cached_path is not None: + node_results[node_id] = cached_path + nodes_cached += 1 + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="cached", + progress=1.0, + message="Using cached result", + )) + continue + + # Get executor + executor = get_executor(node.node_type) + if executor is None: + return ExecutionResult( + success=False, + error=f"No executor for node type: {node.node_type}", + execution_time=time.time() - start_time, + node_results=node_results, + ) + + # Resolve input paths + input_paths = [] + for input_id in node.inputs: + if input_id not in node_results: + return ExecutionResult( + success=False, + error=f"Missing input {input_id} for node {node_id}", + execution_time=time.time() - start_time, + node_results=node_results, + ) + input_paths.append(node_results[input_id]) + + # Determine output path + output_path = self.cache.get_output_path(node_id, ".mkv") + + # Execute + self._report_progress(NodeProgress( + 
node_id=node_id, + node_type=type_str, + status="running", + progress=0.5, + message=f"Executing {type_str}", + )) + + node_start = time.time() + try: + result_path = executor.execute( + config=node.config, + inputs=input_paths, + output_path=output_path, + ) + node_time = time.time() - node_start + + # Store in cache (file is already at output_path) + self.cache.put( + node_id=node_id, + source_path=result_path, + node_type=type_str, + execution_time=node_time, + move=False, # Already in place + ) + + node_results[node_id] = result_path + nodes_executed += 1 + + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="completed", + progress=1.0, + message=f"Completed in {node_time:.2f}s", + )) + + except Exception as e: + logger.error(f"Node {node_id} failed: {e}") + self._report_progress(NodeProgress( + node_id=node_id, + node_type=type_str, + status="failed", + message=str(e), + )) + return ExecutionResult( + success=False, + error=f"Node {node_id} ({type_str}) failed: {e}", + execution_time=time.time() - start_time, + node_results=node_results, + nodes_executed=nodes_executed, + nodes_cached=nodes_cached, + ) + + # Get final output + output_path = node_results.get(dag.output_id) + + return ExecutionResult( + success=True, + output_path=output_path, + execution_time=time.time() - start_time, + nodes_executed=nodes_executed, + nodes_cached=nodes_cached, + node_results=node_results, + ) + + def execute_node(self, node: Node, inputs: List[Path]) -> Path: + """ + Execute a single node (bypassing DAG structure). + + Useful for testing individual executors. 
+ """ + executor = get_executor(node.node_type) + if executor is None: + raise ValueError(f"No executor for node type: {node.node_type}") + + output_path = self.cache.get_output_path(node.node_id, ".mkv") + return executor.execute(node.config, inputs, output_path) + + def get_cache_stats(self): + """Get cache statistics.""" + return self.cache.get_stats() + + def clear_cache(self): + """Clear the cache.""" + self.cache.clear() diff --git a/artdag/executor.py b/artdag/executor.py new file mode 100644 index 0000000..e2deba8 --- /dev/null +++ b/artdag/executor.py @@ -0,0 +1,106 @@ +# primitive/executor.py +""" +Executor base class and registry. + +Executors implement the actual operations for each node type. +They are registered by node type and looked up during execution. +""" + +import logging +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Type + +from .dag import NodeType + +logger = logging.getLogger(__name__) + +# Global executor registry +_EXECUTORS: Dict[NodeType | str, Type["Executor"]] = {} + + +class Executor(ABC): + """ + Base class for node executors. + + Subclasses implement execute() to perform the actual operation. + """ + + @abstractmethod + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + """ + Execute the node operation. + + Args: + config: Node configuration + inputs: Paths to input files (from resolved input nodes) + output_path: Where to write the output + + Returns: + Path to the output file + """ + pass + + def validate_config(self, config: Dict[str, Any]) -> List[str]: + """ + Validate node configuration. + + Returns list of error messages (empty if valid). + Override in subclasses for specific validation. + """ + return [] + + def estimate_output_size(self, config: Dict[str, Any], input_sizes: List[int]) -> int: + """ + Estimate output size in bytes. + + Override for better estimates. Default returns sum of inputs. 
+ """ + return sum(input_sizes) if input_sizes else 0 + + +def register_executor(node_type: NodeType | str) -> Callable: + """ + Decorator to register an executor for a node type. + + Usage: + @register_executor(NodeType.SOURCE) + class SourceExecutor(Executor): + ... + """ + def decorator(cls: Type[Executor]) -> Type[Executor]: + if node_type in _EXECUTORS: + logger.warning(f"Overwriting executor for {node_type}") + _EXECUTORS[node_type] = cls + return cls + return decorator + + +def get_executor(node_type: NodeType | str) -> Optional[Executor]: + """ + Get an executor instance for a node type. + + Returns None if no executor is registered. + """ + executor_cls = _EXECUTORS.get(node_type) + if executor_cls is None: + return None + return executor_cls() + + +def list_executors() -> Dict[str, Type[Executor]]: + """List all registered executors.""" + return { + (k.name if isinstance(k, NodeType) else k): v + for k, v in _EXECUTORS.items() + } + + +def clear_executors(): + """Clear all registered executors (for testing).""" + _EXECUTORS.clear() diff --git a/artdag/nodes/__init__.py b/artdag/nodes/__init__.py new file mode 100644 index 0000000..e821b54 --- /dev/null +++ b/artdag/nodes/__init__.py @@ -0,0 +1,11 @@ +# primitive/nodes/__init__.py +""" +Built-in node executors. + +Import this module to register all built-in executors. +""" + +from . import source +from . import transform +from . import compose +from . import effect diff --git a/artdag/nodes/compose.py b/artdag/nodes/compose.py new file mode 100644 index 0000000..a7121c6 --- /dev/null +++ b/artdag/nodes/compose.py @@ -0,0 +1,548 @@ +# primitive/nodes/compose.py +""" +Compose executors: Combine multiple media inputs. 
+ +Primitives: SEQUENCE, LAYER, MUX, BLEND +""" + +import logging +import shutil +import subprocess +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor +from .encoding import WEB_ENCODING_ARGS_STR, get_web_encoding_args + +logger = logging.getLogger(__name__) + + +def _get_duration(path: Path) -> float: + """Get media duration in seconds.""" + cmd = [ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + return float(result.stdout.strip()) + + +def _get_video_info(path: Path) -> dict: + """Get video width, height, frame rate, and sample rate.""" + cmd = [ + "ffprobe", "-v", "error", + "-select_streams", "v:0", + "-show_entries", "stream=width,height,r_frame_rate", + "-of", "csv=p=0", + str(path) + ] + result = subprocess.run(cmd, capture_output=True, text=True) + parts = result.stdout.strip().split(",") + width = int(parts[0]) if len(parts) > 0 and parts[0] else 1920 + height = int(parts[1]) if len(parts) > 1 and parts[1] else 1080 + fps_str = parts[2] if len(parts) > 2 else "30/1" + # Parse frame rate (e.g., "30/1" or "30000/1001") + if "/" in fps_str: + num, den = fps_str.split("/") + fps = float(num) / float(den) if float(den) != 0 else 30 + else: + fps = float(fps_str) if fps_str else 30 + + # Get audio sample rate + cmd_audio = [ + "ffprobe", "-v", "error", + "-select_streams", "a:0", + "-show_entries", "stream=sample_rate", + "-of", "csv=p=0", + str(path) + ] + result_audio = subprocess.run(cmd_audio, capture_output=True, text=True) + sample_rate = int(result_audio.stdout.strip()) if result_audio.stdout.strip() else 44100 + + return {"width": width, "height": height, "fps": fps, "sample_rate": sample_rate} + + +@register_executor(NodeType.SEQUENCE) +class SequenceExecutor(Executor): + """ + Concatenate inputs in time order. 
+ + Config: + transition: Transition config + type: "cut" | "crossfade" | "fade" + duration: Transition duration in seconds + target_size: How to determine output dimensions when inputs differ + "first": Use first input's dimensions (default) + "last": Use last input's dimensions + "largest": Use largest width and height from all inputs + "explicit": Use width/height config values + width: Target width (when target_size="explicit") + height: Target height (when target_size="explicit") + background: Padding color for letterbox/pillarbox (default: "black") + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 1: + raise ValueError("SEQUENCE requires at least one input") + + if len(inputs) == 1: + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(inputs[0], output_path) + return output_path + + transition = config.get("transition", {"type": "cut"}) + transition_type = transition.get("type", "cut") + transition_duration = transition.get("duration", 0.5) + + # Size handling config + target_size = config.get("target_size", "first") + width = config.get("width") + height = config.get("height") + background = config.get("background", "black") + + if transition_type == "cut": + return self._concat_cut(inputs, output_path, target_size, width, height, background) + elif transition_type == "crossfade": + return self._concat_crossfade(inputs, output_path, transition_duration) + elif transition_type == "fade": + return self._concat_fade(inputs, output_path, transition_duration) + else: + raise ValueError(f"Unknown transition type: {transition_type}") + + def _concat_cut( + self, + inputs: List[Path], + output_path: Path, + target_size: str = "first", + width: int = None, + height: int = None, + background: str = "black", + ) -> Path: + """ + Concatenate with scaling/padding to handle different resolutions. 
+ + Args: + inputs: Input video paths + output_path: Output path + target_size: How to determine output size: + - "first": Use first input's dimensions (default) + - "last": Use last input's dimensions + - "largest": Use largest dimensions from all inputs + - "explicit": Use width/height params + width: Explicit width (when target_size="explicit") + height: Explicit height (when target_size="explicit") + background: Padding color (default: black) + """ + output_path.parent.mkdir(parents=True, exist_ok=True) + + n = len(inputs) + input_args = [] + for p in inputs: + input_args.extend(["-i", str(p)]) + + # Get video info for all inputs + infos = [_get_video_info(p) for p in inputs] + + # Determine target dimensions + if target_size == "explicit" and width and height: + target_w, target_h = width, height + elif target_size == "last": + target_w, target_h = infos[-1]["width"], infos[-1]["height"] + elif target_size == "largest": + target_w = max(i["width"] for i in infos) + target_h = max(i["height"] for i in infos) + else: # "first" or default + target_w, target_h = infos[0]["width"], infos[0]["height"] + + # Use common frame rate (from first input) and sample rate + target_fps = infos[0]["fps"] + target_sr = max(i["sample_rate"] for i in infos) + + # Build filter for each input: scale to fit + pad to target size + filter_parts = [] + for i in range(n): + # Scale to fit within target, maintaining aspect ratio, then pad + vf = ( + f"[{i}:v]scale={target_w}:{target_h}:force_original_aspect_ratio=decrease," + f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2:color={background}," + f"setsar=1,fps={target_fps:.6f}[v{i}]" + ) + # Resample audio to common rate + af = f"[{i}:a]aresample={target_sr}[a{i}]" + filter_parts.append(vf) + filter_parts.append(af) + + # Build concat filter + stream_labels = "".join(f"[v{i}][a{i}]" for i in range(n)) + filter_parts.append(f"{stream_labels}concat=n={n}:v=1:a=1[outv][outa]") + + filter_complex = ";".join(filter_parts) + + cmd = [ + 
"ffmpeg", "-y", + *input_args, + "-filter_complex", filter_complex, + "-map", "[outv]", + "-map", "[outa]", + *get_web_encoding_args(), + str(output_path) + ] + + logger.debug(f"SEQUENCE cut: {n} clips -> {target_w}x{target_h} (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Concat failed: {result.stderr}") + + return output_path + + def _concat_crossfade( + self, + inputs: List[Path], + output_path: Path, + duration: float, + ) -> Path: + """Concatenate with crossfade transitions.""" + output_path.parent.mkdir(parents=True, exist_ok=True) + + durations = [_get_duration(p) for p in inputs] + n = len(inputs) + input_args = " ".join(f"-i {p}" for p in inputs) + + # Build xfade filter chain + filter_parts = [] + current = "[0:v]" + + for i in range(1, n): + offset = sum(durations[:i]) - duration * i + next_input = f"[{i}:v]" + output_label = f"[v{i}]" if i < n - 1 else "[outv]" + filter_parts.append( + f"{current}{next_input}xfade=transition=fade:duration={duration}:offset={offset}{output_label}" + ) + current = output_label + + # Audio crossfade chain + audio_current = "[0:a]" + for i in range(1, n): + next_input = f"[{i}:a]" + output_label = f"[a{i}]" if i < n - 1 else "[outa]" + filter_parts.append( + f"{audio_current}{next_input}acrossfade=d={duration}{output_label}" + ) + audio_current = output_label + + filter_complex = ";".join(filter_parts) + + cmd = f'ffmpeg -y {input_args} -filter_complex "{filter_complex}" -map [outv] -map [outa] {WEB_ENCODING_ARGS_STR} {output_path}' + + logger.debug(f"SEQUENCE crossfade: {n} clips (web-optimized)") + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + if result.returncode != 0: + logger.warning(f"Crossfade failed, falling back to cut: {result.stderr[:200]}") + return self._concat_cut(inputs, output_path) + + return output_path + + def _concat_fade( + self, + inputs: List[Path], + output_path: Path, + duration: 
@register_executor(NodeType.LAYER)
class LayerExecutor(Executor):
    """
    Layer inputs spatially (overlay/composite).

    The first input is the base; each subsequent input is overlaid on top.
    Audio (if any) is taken from the base input only.

    Config:
        inputs: List of per-input configs
            position: [x, y] offset
            opacity: 0.0-1.0
            scale: Scale factor
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Composite inputs[1:] onto inputs[0] and write the result.

        Raises:
            ValueError: If no inputs are given.
            RuntimeError: If the ffmpeg invocation fails.
        """
        if not inputs:
            raise ValueError("LAYER requires at least one input")

        # Single input: nothing to composite, copy through.
        if len(inputs) == 1:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(inputs[0], output_path)
            return output_path

        input_configs = config.get("inputs", [{}] * len(inputs))
        output_path.parent.mkdir(parents=True, exist_ok=True)

        n = len(inputs)
        filter_parts = []
        current = "[0:v]"

        for i in range(1, n):
            # Missing per-input configs fall back to defaults.
            cfg = input_configs[i] if i < len(input_configs) else {}
            x, y = cfg.get("position", [0, 0])
            opacity = cfg.get("opacity", 1.0)
            scale = cfg.get("scale", 1.0)

            # Optional pre-scale of the overlay layer.
            if scale != 1.0:
                filter_parts.append(f"[{i}:v]scale=iw*{scale}:ih*{scale}[s{i}]")
                overlay_input = f"[s{i}]"
            else:
                overlay_input = f"[{i}:v]"

            output_label = f"[v{i}]" if i < n - 1 else "[outv]"

            # Optional constant alpha via the alpha-channel mixer.
            if opacity < 1.0:
                filter_parts.append(
                    f"{overlay_input}format=rgba,colorchannelmixer=aa={opacity}[a{i}]"
                )
                overlay_input = f"[a{i}]"

            filter_parts.append(
                f"{current}{overlay_input}overlay=x={x}:y={y}:format=auto{output_label}"
            )
            current = output_label

        filter_complex = ";".join(filter_parts)

        # List-form argv (no shell): robust against spaces/metacharacters
        # in paths and not shell-injectable.
        cmd = ["ffmpeg", "-y"]
        for p in inputs:
            cmd.extend(["-i", str(p)])
        cmd.extend([
            "-filter_complex", filter_complex,
            "-map", "[outv]",
            "-map", "0:a?",  # base input's audio, if present
            *get_web_encoding_args(),
            str(output_path),
        ])

        logger.debug(f"LAYER: {n} inputs (web-optimized)")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Layer failed: {result.stderr}")

        return output_path


@register_executor(NodeType.MUX)
class MuxExecutor(Executor):
    """
    Combine video and audio streams.

    Config:
        video_stream: Index of video input (default: 0)
        audio_stream: Index of audio input (default: 1)
        shortest: End when shortest stream ends (default: True)
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Mux the selected video and audio inputs into one container.

        Video is stream-copied (no re-encode); audio is re-encoded to AAC.

        Raises:
            ValueError: If fewer than 2 inputs are given.
            RuntimeError: If the ffmpeg invocation fails.
        """
        if len(inputs) < 2:
            raise ValueError("MUX requires at least 2 inputs (video + audio)")

        video_idx = config.get("video_stream", 0)
        audio_idx = config.get("audio_stream", 1)
        shortest = config.get("shortest", True)

        video_path = inputs[video_idx]
        audio_path = inputs[audio_idx]

        output_path.parent.mkdir(parents=True, exist_ok=True)

        cmd = [
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-i", str(audio_path),
            "-c:v", "copy",
            "-c:a", "aac",
            "-map", "0:v:0",
            "-map", "1:a:0",
        ]

        # Trim to the shorter stream so audio/video stay in sync.
        if shortest:
            cmd.append("-shortest")

        cmd.append(str(output_path))

        logger.debug(f"MUX: video={video_path.name} + audio={audio_path.name}")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Mux failed: {result.stderr}")

        return output_path
BlendExecutor(Executor): + """ + Blend two inputs using a blend mode. + + Config: + mode: Blend mode (multiply, screen, overlay, add, etc.) + opacity: 0.0-1.0 for second input + """ + + BLEND_MODES = { + "multiply": "multiply", + "screen": "screen", + "overlay": "overlay", + "add": "addition", + "subtract": "subtract", + "average": "average", + "difference": "difference", + "lighten": "lighten", + "darken": "darken", + } + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 2: + raise ValueError("BLEND requires exactly 2 inputs") + + mode = config.get("mode", "overlay") + opacity = config.get("opacity", 0.5) + + if mode not in self.BLEND_MODES: + raise ValueError(f"Unknown blend mode: {mode}") + + output_path.parent.mkdir(parents=True, exist_ok=True) + blend_mode = self.BLEND_MODES[mode] + + if opacity < 1.0: + filter_complex = ( + f"[1:v]format=rgba,colorchannelmixer=aa={opacity}[b];" + f"[0:v][b]blend=all_mode={blend_mode}" + ) + else: + filter_complex = f"[0:v][1:v]blend=all_mode={blend_mode}" + + cmd = [ + "ffmpeg", "-y", + "-i", str(inputs[0]), + "-i", str(inputs[1]), + "-filter_complex", filter_complex, + "-map", "0:a?", + *get_web_encoding_args(), + str(output_path) + ] + + logger.debug(f"BLEND: {mode} (opacity={opacity}) (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Blend failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.AUDIO_MIX) +class AudioMixExecutor(Executor): + """ + Mix multiple audio streams. 
+ + Config: + gains: List of gain values per input (0.0-2.0, default 1.0) + normalize: Normalize output to prevent clipping (default True) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) < 2: + raise ValueError("AUDIO_MIX requires at least 2 inputs") + + gains = config.get("gains", [1.0] * len(inputs)) + normalize = config.get("normalize", True) + + # Pad gains list if too short + while len(gains) < len(inputs): + gains.append(1.0) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Build filter: apply volume to each input, then mix + filter_parts = [] + mix_inputs = [] + + for i, gain in enumerate(gains[:len(inputs)]): + if gain != 1.0: + filter_parts.append(f"[{i}:a]volume={gain}[a{i}]") + mix_inputs.append(f"[a{i}]") + else: + mix_inputs.append(f"[{i}:a]") + + # amix filter + normalize_flag = 1 if normalize else 0 + mix_filter = f"{''.join(mix_inputs)}amix=inputs={len(inputs)}:normalize={normalize_flag}[aout]" + filter_parts.append(mix_filter) + + filter_complex = ";".join(filter_parts) + + cmd = [ + "ffmpeg", "-y", + ] + for p in inputs: + cmd.extend(["-i", str(p)]) + + cmd.extend([ + "-filter_complex", filter_complex, + "-map", "[aout]", + "-c:a", "aac", + str(output_path) + ]) + + logger.debug(f"AUDIO_MIX: {len(inputs)} inputs, gains={gains[:len(inputs)]}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Audio mix failed: {result.stderr}") + + return output_path diff --git a/artdag/nodes/effect.py b/artdag/nodes/effect.py new file mode 100644 index 0000000..7b36a3d --- /dev/null +++ b/artdag/nodes/effect.py @@ -0,0 +1,520 @@ +# artdag/nodes/effect.py +""" +Effect executor: Apply effects from the registry or IPFS. + +Primitives: EFFECT + +Effects can be: +1. Built-in (registered with @register_effect) +2. 
Stored in IPFS (referenced by CID) +""" + +import importlib.util +import logging +import os +import re +import shutil +import tempfile +from pathlib import Path +from types import ModuleType +from typing import Any, Callable, Dict, List, Optional, TypeVar + +import requests + +from ..executor import Executor, register_executor + +logger = logging.getLogger(__name__) + +# Type alias for effect functions: (input_path, output_path, config) -> output_path +EffectFn = Callable[[Path, Path, Dict[str, Any]], Path] + +# Type variable for decorator +F = TypeVar("F", bound=Callable[..., Any]) + +# IPFS API multiaddr - same as ipfs_client.py for consistency +# Docker uses /dns/ipfs/tcp/5001, local dev uses /ip4/127.0.0.1/tcp/5001 +IPFS_API = os.environ.get("IPFS_API", "/ip4/127.0.0.1/tcp/5001") + +# Connection timeout in seconds +IPFS_TIMEOUT = int(os.environ.get("IPFS_TIMEOUT", "30")) + + +def _get_ipfs_base_url() -> str: + """ + Convert IPFS multiaddr to HTTP URL. + + Matches the conversion logic in ipfs_client.py for consistency. 
+ """ + multiaddr = IPFS_API + + # Handle /dns/hostname/tcp/port format (Docker) + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format (local) + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +def _get_effects_cache_dir() -> Optional[Path]: + """Get the effects cache directory from environment or default.""" + # Check both env var names (CACHE_DIR used by art-celery, ARTDAG_CACHE_DIR for standalone) + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = os.environ.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + if effects_dir.exists(): + return effects_dir + + # Try default locations + for base in [Path.home() / ".artdag" / "cache", Path("/var/cache/artdag")]: + effects_dir = base / "_effects" + if effects_dir.exists(): + return effects_dir + + return None + + +def _fetch_effect_from_ipfs(cid: str, effect_path: Path) -> bool: + """ + Fetch an effect from IPFS and cache locally. + + Uses the IPFS API endpoint (/api/v0/cat) for consistency with ipfs_client.py. + This works reliably in Docker where IPFS_API=/dns/ipfs/tcp/5001. + + Returns True on success, False on failure. 
+ """ + try: + # Use IPFS API (same as ipfs_client.py) + base_url = _get_ipfs_base_url() + url = f"{base_url}/api/v0/cat" + params = {"arg": cid} + + response = requests.post(url, params=params, timeout=IPFS_TIMEOUT) + response.raise_for_status() + + # Cache locally + effect_path.parent.mkdir(parents=True, exist_ok=True) + effect_path.write_bytes(response.content) + logger.info(f"Fetched effect from IPFS: {cid[:16]}...") + return True + + except Exception as e: + logger.error(f"Failed to fetch effect from IPFS {cid[:16]}...: {e}") + return False + + +def _parse_pep723_dependencies(source: str) -> List[str]: + """ + Parse PEP 723 dependencies from effect source code. + + Returns list of package names (e.g., ["numpy", "opencv-python"]). + """ + match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL) + if not match: + return [] + + block = match.group(1) + deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL) + if not deps_match: + return [] + + return re.findall(r'"([^"]+)"', deps_match.group(1)) + + +def _ensure_dependencies(dependencies: List[str], effect_cid: str) -> bool: + """ + Ensure effect dependencies are installed. + + Installs missing packages using pip. Returns True on success. 
+ """ + if not dependencies: + return True + + missing = [] + for dep in dependencies: + # Extract package name (strip version specifiers) + pkg_name = re.split(r'[<>=!]', dep)[0].strip() + # Normalize name (pip uses underscores, imports use underscores or hyphens) + pkg_name_normalized = pkg_name.replace('-', '_').lower() + + try: + __import__(pkg_name_normalized) + except ImportError: + # Some packages have different import names + try: + # Try original name with hyphens replaced + __import__(pkg_name.replace('-', '_')) + except ImportError: + missing.append(dep) + + if not missing: + return True + + logger.info(f"Installing effect dependencies for {effect_cid[:16]}...: {missing}") + + try: + import subprocess + import sys + + result = subprocess.run( + [sys.executable, "-m", "pip", "install", "--quiet"] + missing, + capture_output=True, + text=True, + timeout=120, + ) + + if result.returncode != 0: + logger.error(f"Failed to install dependencies: {result.stderr}") + return False + + logger.info(f"Installed dependencies: {missing}") + return True + + except Exception as e: + logger.error(f"Error installing dependencies: {e}") + return False + + +def _load_cached_effect(effect_cid: str) -> Optional[EffectFn]: + """ + Load an effect by CID, fetching from IPFS if not cached locally. + + Returns the effect function or None if not found. 
+ """ + effects_dir = _get_effects_cache_dir() + + # Create cache dir if needed + if not effects_dir: + # Try to create default cache dir + for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]: + cache_dir = os.environ.get(env_var) + if cache_dir: + effects_dir = Path(cache_dir) / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + break + + if not effects_dir: + effects_dir = Path.home() / ".artdag" / "cache" / "_effects" + effects_dir.mkdir(parents=True, exist_ok=True) + + effect_path = effects_dir / effect_cid / "effect.py" + + # If not cached locally, fetch from IPFS + if not effect_path.exists(): + if not _fetch_effect_from_ipfs(effect_cid, effect_path): + logger.warning(f"Effect not found: {effect_cid[:16]}...") + return None + + # Parse and install dependencies before loading + try: + source = effect_path.read_text() + dependencies = _parse_pep723_dependencies(source) + if dependencies: + logger.info(f"Effect {effect_cid[:16]}... requires: {dependencies}") + if not _ensure_dependencies(dependencies, effect_cid): + logger.error(f"Failed to install dependencies for effect {effect_cid[:16]}...") + return None + except Exception as e: + logger.error(f"Error parsing effect dependencies: {e}") + # Continue anyway - the effect might work without the deps check + + # Load the effect module + try: + spec = importlib.util.spec_from_file_location("cached_effect", effect_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Check for frame-by-frame API + if hasattr(module, "process_frame"): + return _wrap_frame_effect(module, effect_path) + + # Check for whole-video API + if hasattr(module, "process"): + return _wrap_video_effect(module) + + # Check for old-style effect function + if hasattr(module, "effect"): + return module.effect + + logger.warning(f"Effect has no recognized API: {effect_cid[:16]}...") + return None + + except Exception as e: + logger.error(f"Failed to load effect {effect_cid[:16]}...: {e}") + return 
None + + +def _wrap_frame_effect(module: ModuleType, effect_path: Path) -> EffectFn: + """Wrap a frame-by-frame effect to work with the executor API.""" + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run frame-by-frame effect through FFmpeg pipes.""" + try: + from ..effects.frame_processor import process_video + except ImportError: + logger.error("Frame processor not available - falling back to copy") + shutil.copy2(input_path, output_path) + return output_path + + # Extract params from config (excluding internal keys) + params = {k: v for k, v in config.items() + if k not in ("effect", "hash", "_binding")} + + # Get bindings if present + bindings = {} + for key, value in config.items(): + if isinstance(value, dict) and value.get("_resolved_values"): + bindings[key] = value["_resolved_values"] + + output_path.parent.mkdir(parents=True, exist_ok=True) + actual_output = output_path.with_suffix(".mp4") + + process_video( + input_path=input_path, + output_path=actual_output, + process_frame=module.process_frame, + params=params, + bindings=bindings, + ) + + return actual_output + + return wrapped_effect + + +def _wrap_video_effect(module: ModuleType) -> EffectFn: + """Wrap a whole-video effect to work with the executor API.""" + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run whole-video effect.""" + from ..effects.meta import ExecutionContext + + params = {k: v for k, v in config.items() + if k not in ("effect", "hash", "_binding")} + + output_path.parent.mkdir(parents=True, exist_ok=True) + + ctx = ExecutionContext( + input_paths=[str(input_path)], + output_path=str(output_path), + params=params, + seed=hash(str(input_path)) & 0xFFFFFFFF, + ) + + module.process([input_path], output_path, params, ctx) + return output_path + + return wrapped_effect + + +# Effect registry - maps effect names to implementations +_EFFECTS: Dict[str, EffectFn] = {} + + +def 
register_effect(name: str) -> Callable[[F], F]: + """Decorator to register an effect implementation.""" + def decorator(func: F) -> F: + _EFFECTS[name] = func # type: ignore[assignment] + return func + return decorator + + +def get_effect(name: str) -> Optional[EffectFn]: + """Get an effect implementation by name.""" + return _EFFECTS.get(name) + + +# Built-in effects + +@register_effect("identity") +def effect_identity(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """ + Identity effect - returns input unchanged. + + This is the foundational effect: identity(x) = x + """ + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Remove existing output if any + if output_path.exists() or output_path.is_symlink(): + output_path.unlink() + + # Preserve extension from input + actual_output = output_path.with_suffix(input_path.suffix) + if actual_output.exists() or actual_output.is_symlink(): + actual_output.unlink() + + # Symlink to input (zero-copy identity) + os.symlink(input_path.resolve(), actual_output) + logger.debug(f"EFFECT identity: {input_path.name} -> {actual_output}") + + return actual_output + + +def _get_sexp_effect(effect_path: str, recipe_dir: Path = None) -> Optional[EffectFn]: + """ + Load a sexp effect from a .sexp file. 
+ + Args: + effect_path: Relative path to the .sexp effect file + recipe_dir: Base directory for resolving paths + + Returns: + Effect function or None if not a sexp effect + """ + if not effect_path or not effect_path.endswith(".sexp"): + return None + + try: + from ..sexp.effect_loader import SexpEffectLoader + except ImportError: + logger.warning("Sexp effect loader not available") + return None + + try: + loader = SexpEffectLoader(recipe_dir or Path.cwd()) + return loader.load_effect_path(effect_path) + except Exception as e: + logger.error(f"Failed to load sexp effect from {effect_path}: {e}") + return None + + +def _get_python_primitive_effect(effect_name: str) -> Optional[EffectFn]: + """ + Get a Python primitive frame processor effect. + + Checks if the effect has a python_primitive in FFmpegCompiler.EFFECT_MAPPINGS + and wraps it for the executor API. + """ + try: + from ..sexp.ffmpeg_compiler import FFmpegCompiler + from ..sexp.primitives import get_primitive + from ..effects.frame_processor import process_video + except ImportError: + return None + + compiler = FFmpegCompiler() + primitive_name = compiler.has_python_primitive(effect_name) + if not primitive_name: + return None + + primitive_fn = get_primitive(primitive_name) + if not primitive_fn: + logger.warning(f"Python primitive '{primitive_name}' not found for effect '{effect_name}'") + return None + + def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path: + """Run Python primitive effect via frame processor.""" + # Extract params (excluding internal keys) + params = {k: v for k, v in config.items() + if k not in ("effect", "cid", "hash", "effect_path", "_binding")} + + # Get bindings if present + bindings = {} + for key, value in config.items(): + if isinstance(value, dict) and value.get("_resolved_values"): + bindings[key] = value["_resolved_values"] + + output_path.parent.mkdir(parents=True, exist_ok=True) + actual_output = output_path.with_suffix(".mp4") + + # 
@register_executor("EFFECT")
class EffectExecutor(Executor):
    """
    Apply an effect from the registry or IPFS.

    Config:
        effect: Name of the effect to apply
        cid: IPFS CID for the effect (fetched from IPFS if not cached)
        hash: Legacy alias for cid (backwards compatibility)
        params: Optional parameters for the effect

    Inputs:
        Single input file to transform

    Resolution order: IPFS CID -> .sexp definition (effect_path) ->
    Python primitive -> built-in registry. The first source that yields
    a callable wins.
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        # Name is mandatory; the CID is optional and may arrive under the
        # legacy "hash" key.
        effect_name = config.get("effect")
        # Support both "cid" (new) and "hash" (legacy)
        effect_cid = config.get("cid") or config.get("hash")

        if not effect_name:
            raise ValueError("EFFECT requires 'effect' config")

        if len(inputs) != 1:
            raise ValueError(f"EFFECT expects 1 input, got {len(inputs)}")

        # Try IPFS effect first if CID provided
        effect_fn: Optional[EffectFn] = None
        if effect_cid:
            effect_fn = _load_cached_effect(effect_cid)
            if effect_fn:
                logger.info(f"Running effect '{effect_name}' (cid={effect_cid[:16]}...)")

        # Try sexp effect from effect_path (.sexp file)
        if effect_fn is None:
            effect_path = config.get("effect_path")
            if effect_path and effect_path.endswith(".sexp"):
                effect_fn = _get_sexp_effect(effect_path)
                if effect_fn:
                    logger.info(f"Running effect '{effect_name}' via sexp definition")

        # Try Python primitive (from FFmpegCompiler.EFFECT_MAPPINGS)
        if effect_fn is None:
            effect_fn = _get_python_primitive_effect(effect_name)
            if effect_fn:
                logger.info(f"Running effect '{effect_name}' via Python primitive")

        # Fall back to built-in effect
        if effect_fn is None:
            effect_fn = get_effect(effect_name)

        if effect_fn is None:
            raise ValueError(f"Unknown effect: {effect_name}")

        # Pass full config (effect can extract what it needs)
        return effect_fn(inputs[0], output_path, config)

    def validate_config(self, config: Dict[str, Any]) -> List[str]:
        """Return a list of config error messages (empty when valid)."""
        errors = []
        if "effect" not in config:
            errors.append("EFFECT requires 'effect' config")
        else:
            # If CID provided, we'll load from IPFS - skip built-in check
            has_cid = config.get("cid") or config.get("hash")
            if not has_cid and get_effect(config["effect"]) is None:
                errors.append(f"Unknown effect: {config['effect']}")
        return errors


# artdag/nodes/encoding.py
"""
Web-optimized video encoding settings.
+ +Provides common FFmpeg arguments for producing videos that: +- Stream efficiently (faststart) +- Play on all browsers (H.264 High profile) +- Support seeking (regular keyframes) +""" + +from typing import List + +# Standard web-optimized video encoding arguments +WEB_VIDEO_ARGS: List[str] = [ + "-c:v", "libx264", + "-preset", "fast", + "-crf", "18", + "-profile:v", "high", + "-level", "4.1", + "-pix_fmt", "yuv420p", # Ensure broad compatibility + "-movflags", "+faststart", # Enable streaming before full download + "-g", "48", # Keyframe every ~2 seconds at 24fps (for seeking) +] + +# Standard audio encoding arguments +WEB_AUDIO_ARGS: List[str] = [ + "-c:a", "aac", + "-b:a", "192k", +] + + +def get_web_encoding_args() -> List[str]: + """Get FFmpeg args for web-optimized video+audio encoding.""" + return WEB_VIDEO_ARGS + WEB_AUDIO_ARGS + + +def get_web_video_args() -> List[str]: + """Get FFmpeg args for web-optimized video encoding only.""" + return WEB_VIDEO_ARGS.copy() + + +def get_web_audio_args() -> List[str]: + """Get FFmpeg args for web-optimized audio encoding only.""" + return WEB_AUDIO_ARGS.copy() + + +# For shell commands (string format) +WEB_VIDEO_ARGS_STR = " ".join(WEB_VIDEO_ARGS) +WEB_AUDIO_ARGS_STR = " ".join(WEB_AUDIO_ARGS) +WEB_ENCODING_ARGS_STR = f"{WEB_VIDEO_ARGS_STR} {WEB_AUDIO_ARGS_STR}" diff --git a/artdag/nodes/source.py b/artdag/nodes/source.py new file mode 100644 index 0000000..1fc7ef1 --- /dev/null +++ b/artdag/nodes/source.py @@ -0,0 +1,62 @@ +# primitive/nodes/source.py +""" +Source executors: Load media from paths. + +Primitives: SOURCE +""" + +import logging +import os +import shutil +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor + +logger = logging.getLogger(__name__) + + +@register_executor(NodeType.SOURCE) +class SourceExecutor(Executor): + """ + Load source media from a path. 
+ + Config: + path: Path to source file + + Creates a symlink to the source file for zero-copy loading. + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + source_path = Path(config["path"]) + + if not source_path.exists(): + raise FileNotFoundError(f"Source file not found: {source_path}") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Use symlink for zero-copy + if output_path.exists() or output_path.is_symlink(): + output_path.unlink() + + # Preserve extension from source + actual_output = output_path.with_suffix(source_path.suffix) + if actual_output.exists() or actual_output.is_symlink(): + actual_output.unlink() + + os.symlink(source_path.resolve(), actual_output) + logger.debug(f"SOURCE: {source_path.name} -> {actual_output}") + + return actual_output + + def validate_config(self, config: Dict[str, Any]) -> List[str]: + errors = [] + if "path" not in config: + errors.append("SOURCE requires 'path' config") + return errors diff --git a/artdag/nodes/transform.py b/artdag/nodes/transform.py new file mode 100644 index 0000000..e91ba6f --- /dev/null +++ b/artdag/nodes/transform.py @@ -0,0 +1,224 @@ +# primitive/nodes/transform.py +""" +Transform executors: Modify single media inputs. + +Primitives: SEGMENT, RESIZE, TRANSFORM +""" + +import logging +import subprocess +from pathlib import Path +from typing import Any, Dict, List + +from ..dag import NodeType +from ..executor import Executor, register_executor +from .encoding import get_web_encoding_args, get_web_video_args + +logger = logging.getLogger(__name__) + + +@register_executor(NodeType.SEGMENT) +class SegmentExecutor(Executor): + """ + Extract a time segment from media. 
+ + Config: + offset: Start time in seconds (default: 0) + duration: Duration in seconds (optional, default: to end) + precise: Use frame-accurate seeking (default: True) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("SEGMENT requires exactly one input") + + input_path = inputs[0] + offset = config.get("offset", 0) + duration = config.get("duration") + precise = config.get("precise", True) + + output_path.parent.mkdir(parents=True, exist_ok=True) + + if precise: + # Frame-accurate: decode-seek (slower but precise) + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + if offset > 0: + cmd.extend(["-ss", str(offset)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend([*get_web_encoding_args(), str(output_path)]) + else: + # Fast: input-seek at keyframes (may be slightly off) + cmd = ["ffmpeg", "-y"] + if offset > 0: + cmd.extend(["-ss", str(offset)]) + cmd.extend(["-i", str(input_path)]) + if duration: + cmd.extend(["-t", str(duration)]) + cmd.extend(["-c", "copy", str(output_path)]) + + logger.debug(f"SEGMENT: offset={offset}, duration={duration}, precise={precise}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Segment failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.RESIZE) +class ResizeExecutor(Executor): + """ + Resize media to target dimensions. 
+ + Config: + width: Target width + height: Target height + mode: "fit" (letterbox), "fill" (crop), "stretch", "pad" + background: Background color for pad mode (default: black) + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("RESIZE requires exactly one input") + + input_path = inputs[0] + width = config["width"] + height = config["height"] + mode = config.get("mode", "fit") + background = config.get("background", "black") + + output_path.parent.mkdir(parents=True, exist_ok=True) + + if mode == "fit": + # Scale to fit, add letterboxing + vf = f"scale={width}:{height}:force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}" + elif mode == "fill": + # Scale to fill, crop excess + vf = f"scale={width}:{height}:force_original_aspect_ratio=increase,crop={width}:{height}" + elif mode == "stretch": + # Stretch to exact size + vf = f"scale={width}:{height}" + elif mode == "pad": + # Scale down only if larger, then pad + vf = f"scale='min({width},iw)':'min({height},ih)':force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}" + else: + raise ValueError(f"Unknown resize mode: {mode}") + + cmd = [ + "ffmpeg", "-y", + "-i", str(input_path), + "-vf", vf, + *get_web_video_args(), + "-c:a", "copy", + str(output_path) + ] + + logger.debug(f"RESIZE: {width}x{height} ({mode}) (web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Resize failed: {result.stderr}") + + return output_path + + +@register_executor(NodeType.TRANSFORM) +class TransformExecutor(Executor): + """ + Apply visual effects to media. 
+ + Config: + effects: Dict of effect -> value + saturation: 0.0-2.0 (1.0 = normal) + contrast: 0.0-2.0 (1.0 = normal) + brightness: -1.0 to 1.0 (0.0 = normal) + gamma: 0.1-10.0 (1.0 = normal) + hue: degrees shift + blur: blur radius + sharpen: sharpen amount + speed: playback speed multiplier + """ + + def execute( + self, + config: Dict[str, Any], + inputs: List[Path], + output_path: Path, + ) -> Path: + if len(inputs) != 1: + raise ValueError("TRANSFORM requires exactly one input") + + input_path = inputs[0] + effects = config.get("effects", {}) + + if not effects: + # No effects - just copy + import shutil + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(input_path, output_path) + return output_path + + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Build filter chain + vf_parts = [] + af_parts = [] + + # Color adjustments via eq filter + eq_parts = [] + if "saturation" in effects: + eq_parts.append(f"saturation={effects['saturation']}") + if "contrast" in effects: + eq_parts.append(f"contrast={effects['contrast']}") + if "brightness" in effects: + eq_parts.append(f"brightness={effects['brightness']}") + if "gamma" in effects: + eq_parts.append(f"gamma={effects['gamma']}") + if eq_parts: + vf_parts.append(f"eq={':'.join(eq_parts)}") + + # Hue adjustment + if "hue" in effects: + vf_parts.append(f"hue=h={effects['hue']}") + + # Blur + if "blur" in effects: + vf_parts.append(f"boxblur={effects['blur']}") + + # Sharpen + if "sharpen" in effects: + vf_parts.append(f"unsharp=5:5:{effects['sharpen']}:5:5:0") + + # Speed change + if "speed" in effects: + speed = effects["speed"] + vf_parts.append(f"setpts={1/speed}*PTS") + af_parts.append(f"atempo={speed}") + + cmd = ["ffmpeg", "-y", "-i", str(input_path)] + + if vf_parts: + cmd.extend(["-vf", ",".join(vf_parts)]) + if af_parts: + cmd.extend(["-af", ",".join(af_parts)]) + + cmd.extend([*get_web_encoding_args(), str(output_path)]) + + logger.debug(f"TRANSFORM: {list(effects.keys())} 
(web-optimized)") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + raise RuntimeError(f"Transform failed: {result.stderr}") + + return output_path diff --git a/artdag/planning/__init__.py b/artdag/planning/__init__.py new file mode 100644 index 0000000..1d5c89f --- /dev/null +++ b/artdag/planning/__init__.py @@ -0,0 +1,29 @@ +# artdag/planning - Execution plan generation +# +# Provides the Planning phase of the 3-phase execution model: +# 1. ANALYZE - Extract features from inputs +# 2. PLAN - Generate execution plan with cache IDs +# 3. EXECUTE - Run steps with caching + +from .schema import ( + ExecutionStep, + ExecutionPlan, + StepStatus, + StepOutput, + StepInput, + PlanInput, +) +from .planner import RecipePlanner, Recipe +from .tree_reduction import TreeReducer + +__all__ = [ + "ExecutionStep", + "ExecutionPlan", + "StepStatus", + "StepOutput", + "StepInput", + "PlanInput", + "RecipePlanner", + "Recipe", + "TreeReducer", +] diff --git a/artdag/planning/planner.py b/artdag/planning/planner.py new file mode 100644 index 0000000..18f30d8 --- /dev/null +++ b/artdag/planning/planner.py @@ -0,0 +1,756 @@ +# artdag/planning/planner.py +""" +Recipe planner - converts recipes into execution plans. + +The planner is the second phase of the 3-phase execution model. +It takes a recipe and analysis results and generates a complete +execution plan with pre-computed cache IDs. 
+""" + +import hashlib +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml + +from .schema import ExecutionPlan, ExecutionStep, StepOutput, StepInput, PlanInput +from .tree_reduction import TreeReducer, reduce_sequence +from ..analysis import AnalysisResult + + +def _infer_media_type(node_type: str, config: Dict[str, Any] = None) -> str: + """Infer media type from node type and config.""" + config = config or {} + + # Audio operations + if node_type in ("AUDIO", "MIX_AUDIO", "EXTRACT_AUDIO"): + return "audio/wav" + if "audio" in node_type.lower(): + return "audio/wav" + + # Image operations + if node_type in ("FRAME", "THUMBNAIL", "IMAGE"): + return "image/png" + + # Default to video + return "video/mp4" + +logger = logging.getLogger(__name__) + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """Create stable hash from arbitrary data.""" + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +@dataclass +class RecipeNode: + """A node in the recipe DAG.""" + id: str + type: str + config: Dict[str, Any] + inputs: List[str] + + +@dataclass +class Recipe: + """Parsed recipe structure.""" + name: str + version: str + description: str + nodes: List[RecipeNode] + output: str + registry: Dict[str, Any] + owner: str + raw_yaml: str + + @property + def recipe_hash(self) -> str: + """Compute hash of recipe content.""" + return _stable_hash({"yaml": self.raw_yaml}) + + @classmethod + def from_yaml(cls, yaml_content: str) -> "Recipe": + """Parse recipe from YAML string.""" + data = yaml.safe_load(yaml_content) + + nodes = [] + for node_data in data.get("dag", {}).get("nodes", []): + # Handle both 'inputs' as list and 'inputs' as dict + inputs = node_data.get("inputs", []) + if isinstance(inputs, dict): + # Extract input references from 
dict structure + input_list = [] + for key, value in inputs.items(): + if isinstance(value, str): + input_list.append(value) + elif isinstance(value, list): + input_list.extend(value) + inputs = input_list + elif isinstance(inputs, str): + inputs = [inputs] + + nodes.append(RecipeNode( + id=node_data["id"], + type=node_data["type"], + config=node_data.get("config", {}), + inputs=inputs, + )) + + return cls( + name=data.get("name", "unnamed"), + version=data.get("version", "1.0"), + description=data.get("description", ""), + nodes=nodes, + output=data.get("dag", {}).get("output", ""), + registry=data.get("registry", {}), + owner=data.get("owner", ""), + raw_yaml=yaml_content, + ) + + @classmethod + def from_file(cls, path: Path) -> "Recipe": + """Load recipe from YAML file.""" + with open(path, "r") as f: + return cls.from_yaml(f.read()) + + +class RecipePlanner: + """ + Generates execution plans from recipes. + + The planner: + 1. Parses the recipe + 2. Resolves fixed inputs from registry + 3. Maps variable inputs to provided hashes + 4. Expands MAP/iteration nodes + 5. Applies tree reduction for SEQUENCE nodes + 6. Computes cache IDs for all steps + """ + + def __init__(self, use_tree_reduction: bool = True): + """ + Initialize the planner. + + Args: + use_tree_reduction: Whether to use tree reduction for SEQUENCE + """ + self.use_tree_reduction = use_tree_reduction + + def plan( + self, + recipe: Recipe, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + seed: Optional[int] = None, + ) -> ExecutionPlan: + """ + Generate an execution plan from a recipe. 
+ + Args: + recipe: The parsed recipe + input_hashes: Mapping from input name to content hash + analysis: Analysis results for inputs (keyed by hash) + seed: Random seed for deterministic planning + + Returns: + ExecutionPlan with pre-computed cache IDs + """ + logger.info(f"Planning recipe: {recipe.name}") + + # Build node lookup + nodes_by_id = {n.id: n for n in recipe.nodes} + + # Topologically sort nodes + sorted_ids = self._topological_sort(recipe.nodes) + + # Resolve registry references + registry_hashes = self._resolve_registry(recipe.registry) + + # Build PlanInput objects from input_hashes + plan_inputs = [] + for name, cid in input_hashes.items(): + # Try to find matching SOURCE node for media type + media_type = "application/octet-stream" + for node in recipe.nodes: + if node.id == name and node.type == "SOURCE": + media_type = _infer_media_type("SOURCE", node.config) + break + + plan_inputs.append(PlanInput( + name=name, + cache_id=cid, + cid=cid, + media_type=media_type, + )) + + # Generate steps + steps = [] + step_id_map = {} # Maps recipe node ID to step ID(s) + step_name_map = {} # Maps recipe node ID to human-readable name + analysis_cache_ids = {} + + for node_id in sorted_ids: + node = nodes_by_id[node_id] + logger.debug(f"Processing node: {node.id} ({node.type})") + + new_steps, output_step_id = self._process_node( + node=node, + step_id_map=step_id_map, + step_name_map=step_name_map, + input_hashes=input_hashes, + registry_hashes=registry_hashes, + analysis=analysis or {}, + recipe_name=recipe.name, + ) + + steps.extend(new_steps) + step_id_map[node_id] = output_step_id + # Track human-readable name for this node + if new_steps: + step_name_map[node_id] = new_steps[-1].name + + # Find output step + output_step = step_id_map.get(recipe.output) + if not output_step: + raise ValueError(f"Output node '{recipe.output}' not found") + + # Determine output name + output_name = f"{recipe.name}.output" + output_step_obj = next((s for s in steps if 
s.step_id == output_step), None) + if output_step_obj and output_step_obj.outputs: + output_name = output_step_obj.outputs[0].name + + # Build analysis cache IDs + if analysis: + analysis_cache_ids = { + h: a.cache_id for h, a in analysis.items() + if a.cache_id + } + + # Create plan + plan = ExecutionPlan( + plan_id=None, # Computed in __post_init__ + name=f"{recipe.name}_plan", + recipe_id=recipe.name, + recipe_name=recipe.name, + recipe_hash=recipe.recipe_hash, + seed=seed, + inputs=plan_inputs, + steps=steps, + output_step=output_step, + output_name=output_name, + analysis_cache_ids=analysis_cache_ids, + input_hashes=input_hashes, + metadata={ + "recipe_version": recipe.version, + "recipe_description": recipe.description, + "owner": recipe.owner, + }, + ) + + # Compute all cache IDs and then generate outputs + plan.compute_all_cache_ids() + plan.compute_levels() + + # Now add outputs to each step (needs cache_id to be computed first) + self._add_step_outputs(plan, recipe.name) + + # Recompute plan_id after outputs are added + plan.plan_id = plan._compute_plan_id() + + logger.info(f"Generated plan with {len(steps)} steps") + return plan + + def _add_step_outputs(self, plan: ExecutionPlan, recipe_name: str) -> None: + """Add output definitions to each step after cache_ids are computed.""" + for step in plan.steps: + if step.outputs: + continue # Already has outputs + + # Generate output name from step name + base_name = step.name or step.step_id + output_name = f"{recipe_name}.{base_name}.out" + + media_type = _infer_media_type(step.node_type, step.config) + + step.add_output( + name=output_name, + media_type=media_type, + index=0, + metadata={}, + ) + + def plan_from_yaml( + self, + yaml_content: str, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + ) -> ExecutionPlan: + """ + Generate plan from YAML string. 
+ + Args: + yaml_content: Recipe YAML content + input_hashes: Mapping from input name to content hash + analysis: Analysis results + + Returns: + ExecutionPlan + """ + recipe = Recipe.from_yaml(yaml_content) + return self.plan(recipe, input_hashes, analysis) + + def plan_from_file( + self, + recipe_path: Path, + input_hashes: Dict[str, str], + analysis: Optional[Dict[str, AnalysisResult]] = None, + ) -> ExecutionPlan: + """ + Generate plan from recipe file. + + Args: + recipe_path: Path to recipe YAML file + input_hashes: Mapping from input name to content hash + analysis: Analysis results + + Returns: + ExecutionPlan + """ + recipe = Recipe.from_file(recipe_path) + return self.plan(recipe, input_hashes, analysis) + + def _topological_sort(self, nodes: List[RecipeNode]) -> List[str]: + """Topologically sort recipe nodes.""" + nodes_by_id = {n.id: n for n in nodes} + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + if node_id not in nodes_by_id: + return # External input + visited.add(node_id) + node = nodes_by_id[node_id] + for input_id in node.inputs: + visit(input_id) + order.append(node_id) + + for node in nodes: + visit(node.id) + + return order + + def _resolve_registry(self, registry: Dict[str, Any]) -> Dict[str, str]: + """ + Resolve registry references to content hashes. 
+ + Args: + registry: Registry section from recipe + + Returns: + Mapping from name to content hash + """ + hashes = {} + + # Assets + for name, asset_data in registry.get("assets", {}).items(): + if isinstance(asset_data, dict) and "hash" in asset_data: + hashes[name] = asset_data["hash"] + elif isinstance(asset_data, str): + hashes[name] = asset_data + + # Effects + for name, effect_data in registry.get("effects", {}).items(): + if isinstance(effect_data, dict) and "hash" in effect_data: + hashes[f"effect:{name}"] = effect_data["hash"] + elif isinstance(effect_data, str): + hashes[f"effect:{name}"] = effect_data + + return hashes + + def _process_node( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + step_name_map: Dict[str, str], + input_hashes: Dict[str, str], + registry_hashes: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process a recipe node into execution steps. + + Args: + node: Recipe node to process + step_id_map: Mapping from processed node IDs to step IDs + step_name_map: Mapping from node IDs to human-readable names + input_hashes: User-provided input hashes + registry_hashes: Registry-resolved hashes + analysis: Analysis results + recipe_name: Name of the recipe (for generating readable names) + + Returns: + Tuple of (new steps, output step ID) + """ + # SOURCE nodes + if node.type == "SOURCE": + return self._process_source(node, input_hashes, registry_hashes, recipe_name) + + # SOURCE_LIST nodes + if node.type == "SOURCE_LIST": + return self._process_source_list(node, input_hashes, recipe_name) + + # ANALYZE nodes + if node.type == "ANALYZE": + return self._process_analyze(node, step_id_map, analysis, recipe_name) + + # MAP nodes + if node.type == "MAP": + return self._process_map(node, step_id_map, input_hashes, analysis, recipe_name) + + # SEQUENCE nodes (may use tree reduction) + if node.type == "SEQUENCE": + return self._process_sequence(node, 
step_id_map, recipe_name) + + # SEGMENT_AT nodes + if node.type == "SEGMENT_AT": + return self._process_segment_at(node, step_id_map, analysis, recipe_name) + + # Standard nodes (SEGMENT, RESIZE, TRANSFORM, LAYER, MUX, BLEND, etc.) + return self._process_standard(node, step_id_map, recipe_name) + + def _process_source( + self, + node: RecipeNode, + input_hashes: Dict[str, str], + registry_hashes: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """Process SOURCE node.""" + config = dict(node.config) + + # Variable input? + if config.get("input"): + # Look up in user-provided inputs + if node.id not in input_hashes: + raise ValueError(f"Missing input for SOURCE node '{node.id}'") + cid = input_hashes[node.id] + # Fixed asset from registry? + elif config.get("asset"): + asset_name = config["asset"] + if asset_name not in registry_hashes: + raise ValueError(f"Asset '{asset_name}' not found in registry") + cid = registry_hashes[asset_name] + else: + raise ValueError(f"SOURCE node '{node.id}' has no input or asset") + + # Human-readable name + display_name = config.get("name", node.id) + step_name = f"{recipe_name}.inputs.{display_name}" if recipe_name else display_name + + step = ExecutionStep( + step_id=node.id, + node_type="SOURCE", + config={"input_ref": node.id, "cid": cid}, + input_steps=[], + cache_id=cid, # SOURCE cache_id is just the content hash + name=step_name, + ) + + return [step], step.step_id + + def _process_source_list( + self, + node: RecipeNode, + input_hashes: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SOURCE_LIST node. + + Creates individual SOURCE steps for each item in the list. 
+ """ + # Look for list input + if node.id not in input_hashes: + raise ValueError(f"Missing input for SOURCE_LIST node '{node.id}'") + + input_value = input_hashes[node.id] + + # Parse as comma-separated list if string + if isinstance(input_value, str): + items = [h.strip() for h in input_value.split(",")] + else: + items = list(input_value) + + display_name = node.config.get("name", node.id) + base_name = f"{recipe_name}.{display_name}" if recipe_name else display_name + + steps = [] + for i, cid in enumerate(items): + step = ExecutionStep( + step_id=f"{node.id}_{i}", + node_type="SOURCE", + config={"input_ref": f"{node.id}[{i}]", "cid": cid}, + input_steps=[], + cache_id=cid, + name=f"{base_name}[{i}]", + ) + steps.append(step) + + # Return list marker as output + list_step = ExecutionStep( + step_id=node.id, + node_type="_LIST", + config={"items": [s.step_id for s in steps]}, + input_steps=[s.step_id for s in steps], + name=f"{base_name}.list", + ) + steps.append(list_step) + + return steps, list_step.step_id + + def _process_analyze( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process ANALYZE node. + + ANALYZE nodes reference pre-computed analysis results. 
+ """ + input_step = step_id_map.get(node.inputs[0]) if node.inputs else None + if not input_step: + raise ValueError(f"ANALYZE node '{node.id}' has no input") + + feature = node.config.get("feature", "all") + step_name = f"{recipe_name}.analysis.{feature}" if recipe_name else f"analysis.{feature}" + + step = ExecutionStep( + step_id=node.id, + node_type="ANALYZE", + config={ + "feature": feature, + **node.config, + }, + input_steps=[input_step], + name=step_name, + ) + + return [step], step.step_id + + def _process_map( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + input_hashes: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process MAP node - expand iteration over list. + + MAP applies an operation to each item in a list. + """ + operation = node.config.get("operation", "TRANSFORM") + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + + # Get items input + items_ref = node.config.get("items") or ( + node.inputs[0] if isinstance(node.inputs, list) else + node.inputs.get("items") if isinstance(node.inputs, dict) else None + ) + + if not items_ref: + raise ValueError(f"MAP node '{node.id}' has no items input") + + # Resolve items to list of step IDs + if items_ref in step_id_map: + # Reference to SOURCE_LIST output + items_step = step_id_map[items_ref] + # TODO: expand list items + logger.warning(f"MAP node '{node.id}' references list step, expansion TBD") + item_steps = [items_step] + else: + item_steps = [items_ref] + + # Generate step for each item + steps = [] + output_steps = [] + + for i, item_step in enumerate(item_steps): + step_id = f"{node.id}_{i}" + + if operation == "RANDOM_SLICE": + step = ExecutionStep( + step_id=step_id, + node_type="SEGMENT", + config={ + "random": True, + "seed_from": node.config.get("seed_from"), + "index": i, + }, + input_steps=[item_step], + name=f"{base_name}.slice[{i}]", + ) + elif operation == "TRANSFORM": + step = 
ExecutionStep( + step_id=step_id, + node_type="TRANSFORM", + config=node.config.get("effects", {}), + input_steps=[item_step], + name=f"{base_name}.transform[{i}]", + ) + elif operation == "ANALYZE": + step = ExecutionStep( + step_id=step_id, + node_type="ANALYZE", + config={"feature": node.config.get("feature", "all")}, + input_steps=[item_step], + name=f"{base_name}.analyze[{i}]", + ) + else: + step = ExecutionStep( + step_id=step_id, + node_type=operation, + config=node.config, + input_steps=[item_step], + name=f"{base_name}.{operation.lower()}[{i}]", + ) + + steps.append(step) + output_steps.append(step_id) + + # Create list output + list_step = ExecutionStep( + step_id=node.id, + node_type="_LIST", + config={"items": output_steps}, + input_steps=output_steps, + name=f"{base_name}.results", + ) + steps.append(list_step) + + return steps, list_step.step_id + + def _process_sequence( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SEQUENCE node. + + Uses tree reduction for parallel composition if enabled. 
+ """ + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + + # Resolve input steps + input_steps = [] + for input_id in node.inputs: + if input_id in step_id_map: + input_steps.append(step_id_map[input_id]) + else: + input_steps.append(input_id) + + if len(input_steps) == 0: + raise ValueError(f"SEQUENCE node '{node.id}' has no inputs") + + if len(input_steps) == 1: + # Single input, no sequence needed + return [], input_steps[0] + + transition_config = node.config.get("transition", {"type": "cut"}) + config = {"transition": transition_config} + + if self.use_tree_reduction and len(input_steps) > 2: + # Use tree reduction + reduction_steps, output_id = reduce_sequence( + input_steps, + transition_config=config, + id_prefix=node.id, + ) + + steps = [] + for i, (step_id, inputs, step_config) in enumerate(reduction_steps): + step = ExecutionStep( + step_id=step_id, + node_type="SEQUENCE", + config=step_config, + input_steps=inputs, + name=f"{base_name}.reduce[{i}]", + ) + steps.append(step) + + return steps, output_id + else: + # Direct sequence + step = ExecutionStep( + step_id=node.id, + node_type="SEQUENCE", + config=config, + input_steps=input_steps, + name=f"{base_name}.concat", + ) + return [step], step.step_id + + def _process_segment_at( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + analysis: Dict[str, AnalysisResult], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """ + Process SEGMENT_AT node - cut at specific times. + + Creates SEGMENT steps for each time range. 
+ """ + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + times_from = node.config.get("times_from") + distribute = node.config.get("distribute", "round_robin") + + # TODO: Resolve times from analysis + # For now, create a placeholder + step = ExecutionStep( + step_id=node.id, + node_type="SEGMENT_AT", + config=node.config, + input_steps=[step_id_map.get(i, i) for i in node.inputs], + name=f"{base_name}.segment", + ) + + return [step], step.step_id + + def _process_standard( + self, + node: RecipeNode, + step_id_map: Dict[str, str], + recipe_name: str = "", + ) -> Tuple[List[ExecutionStep], str]: + """Process standard transformation/composition node.""" + base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id + input_steps = [step_id_map.get(i, i) for i in node.inputs] + + step = ExecutionStep( + step_id=node.id, + node_type=node.type, + config=node.config, + input_steps=input_steps, + name=f"{base_name}.{node.type.lower()}", + ) + + return [step], step.step_id diff --git a/artdag/planning/schema.py b/artdag/planning/schema.py new file mode 100644 index 0000000..9831d16 --- /dev/null +++ b/artdag/planning/schema.py @@ -0,0 +1,594 @@ +# artdag/planning/schema.py +""" +Data structures for execution plans. + +An ExecutionPlan contains all steps needed to execute a recipe, +with pre-computed cache IDs for each step. 
+""" + +import hashlib +import json +import os +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional + + +# Cluster key for trust domains +# Systems with the same key produce the same cache_ids and can share work +# Systems with different keys have isolated cache namespaces +CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY") + + +def set_cluster_key(key: Optional[str]) -> None: + """Set the cluster key programmatically.""" + global CLUSTER_KEY + CLUSTER_KEY = key + + +def get_cluster_key() -> Optional[str]: + """Get the current cluster key.""" + return CLUSTER_KEY + + +def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str: + """ + Create stable hash from arbitrary data. + + If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create + isolated trust domains. Systems with the same key can share work; + systems with different keys have separate cache namespaces. + """ + # Mix in cluster key if set + if CLUSTER_KEY: + data = {"_cluster_key": CLUSTER_KEY, "_data": data} + + json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) + hasher = hashlib.new(algorithm) + hasher.update(json_str.encode()) + return hasher.hexdigest() + + +class StepStatus(Enum): + """Status of an execution step.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + CACHED = "cached" + FAILED = "failed" + SKIPPED = "skipped" + + +@dataclass +class StepOutput: + """ + A single output from an execution step. + + Nodes may produce multiple outputs (e.g., split_on_beats produces N segments). + Each output has a human-readable name and a cache_id for storage. 
+
+    Attributes:
+        name: Human-readable name (e.g., "beats.split.segment[0]")
+        cache_id: Content-addressed hash for caching
+        media_type: MIME type of the output (e.g., "video/mp4", "audio/wav")
+        index: Output index for multi-output nodes
+        metadata: Optional additional metadata (time_range, etc.)
+    """
+    name: str
+    cache_id: str
+    media_type: str = "application/octet-stream"
+    index: int = 0
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to a plain dict (inverse of from_dict)."""
+        return {
+            "name": self.name,
+            "cache_id": self.cache_id,
+            "media_type": self.media_type,
+            "index": self.index,
+            "metadata": self.metadata,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "StepOutput":
+        """Deserialize from a dict; optional keys fall back to their defaults."""
+        return cls(
+            name=data["name"],
+            cache_id=data["cache_id"],
+            media_type=data.get("media_type", "application/octet-stream"),
+            index=data.get("index", 0),
+            metadata=data.get("metadata", {}),
+        )
+
+
+@dataclass
+class StepInput:
+    """
+    Reference to an input for a step.
+
+    Inputs can reference outputs from other steps by name.
+
+    Attributes:
+        name: Input slot name (e.g., "video", "audio", "segments")
+        source: Source output name (e.g., "beats.split.segment[0]")
+        cache_id: Resolved cache_id of the source (populated during planning)
+    """
+    name: str
+    source: str
+    cache_id: Optional[str] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to a plain dict (inverse of from_dict)."""
+        return {
+            "name": self.name,
+            "source": self.source,
+            "cache_id": self.cache_id,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "StepInput":
+        """Deserialize from a dict; cache_id is optional."""
+        return cls(
+            name=data["name"],
+            source=data["source"],
+            cache_id=data.get("cache_id"),
+        )
+
+
+@dataclass
+class ExecutionStep:
+    """
+    A single step in the execution plan.
+
+    Each step has a pre-computed cache_id that uniquely identifies
+    its output based on its configuration and input cache_ids.
+
+    Steps can produce multiple outputs (e.g., split_on_beats produces N segments).
+    Each output has its own cache_id derived from the step's cache_id + index.
+
+    Attributes:
+        name: Human-readable name relating to recipe (e.g., "beats.split")
+        step_id: Unique identifier (hash) for this step
+        node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.)
+        config: Configuration for the primitive
+        input_steps: IDs of steps this depends on (legacy, use inputs for new code)
+        inputs: Structured input references with names and sources
+        cache_id: Pre-computed cache ID (hash of config + input cache_ids)
+        outputs: List of outputs this step produces
+        estimated_duration: Optional estimated execution time
+        level: Dependency level (0 = no dependencies, higher = more deps)
+    """
+    step_id: str
+    node_type: str
+    config: Dict[str, Any]
+    input_steps: List[str] = field(default_factory=list)
+    inputs: List[StepInput] = field(default_factory=list)
+    cache_id: Optional[str] = None
+    outputs: List[StepOutput] = field(default_factory=list)
+    name: Optional[str] = None
+    estimated_duration: Optional[float] = None
+    level: int = 0
+
+    def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str:
+        """
+        Compute cache ID from configuration and input cache IDs.
+
+        cache_id = SHA3-256(node_type + config + sorted(input_cache_ids))
+
+        Args:
+            input_cache_ids: Mapping from input step_id/name to their cache_id
+
+        Returns:
+            The computed cache_id (also stored on self.cache_id)
+        """
+        # Use structured inputs if available, otherwise fall back to input_steps.
+        # Inputs are sorted so the hash does not depend on declaration order.
+        if self.inputs:
+            # NOTE(review): an unresolved reference falls back to the raw
+            # source string, so a typo'd reference hashes as a literal
+            # instead of failing — confirm this is intended.
+            resolved_inputs = [
+                inp.cache_id or input_cache_ids.get(inp.source, inp.source)
+                for inp in sorted(self.inputs, key=lambda x: x.name)
+            ]
+        else:
+            resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)]
+
+        content = {
+            "node_type": self.node_type,
+            "config": self.config,
+            "inputs": resolved_inputs,
+        }
+        self.cache_id = _stable_hash(content)
+        return self.cache_id
+
+    def compute_output_cache_id(self, index: int) -> str:
+        """
+        Compute cache ID for a specific output index.
+
+        output_cache_id = SHA3-256(step_cache_id + index)
+
+        Args:
+            index: The output index
+
+        Returns:
+            Cache ID for that output
+
+        Raises:
+            ValueError: If the step's own cache_id has not been computed yet.
+        """
+        if not self.cache_id:
+            raise ValueError("Step cache_id must be computed first")
+        content = {"step_cache_id": self.cache_id, "output_index": index}
+        return _stable_hash(content)
+
+    def add_output(
+        self,
+        name: str,
+        media_type: str = "application/octet-stream",
+        index: Optional[int] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+    ) -> StepOutput:
+        """
+        Add an output to this step.
+
+        The output's cache_id is derived from the step's cache_id plus the
+        index, so the step's cache_id must already be computed.
+
+        Args:
+            name: Human-readable output name
+            media_type: MIME type of the output
+            index: Output index (defaults to next available)
+            metadata: Optional metadata
+
+        Returns:
+            The created StepOutput
+        """
+        if index is None:
+            index = len(self.outputs)
+
+        cache_id = self.compute_output_cache_id(index)
+        output = StepOutput(
+            name=name,
+            cache_id=cache_id,
+            media_type=media_type,
+            index=index,
+            metadata=metadata or {},
+        )
+        self.outputs.append(output)
+        return output
+
+    def get_output(self, index: int = 0) -> Optional[StepOutput]:
+        """Get output by index (None when index is past the end)."""
+        if index < len(self.outputs):
+            return self.outputs[index]
+        return None
+
+    def get_output_by_name(self, name: str) -> Optional[StepOutput]:
+        """Get output by name (None when no output matches)."""
+        for output in self.outputs:
+            if output.name == name:
+                return output
+        return None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize the step, including nested inputs/outputs, to a dict."""
+        return {
+            "step_id": self.step_id,
+            "name": self.name,
+            "node_type": self.node_type,
+            "config": self.config,
+            "input_steps": self.input_steps,
+            "inputs": [inp.to_dict() for inp in self.inputs],
+            "cache_id": self.cache_id,
+            "outputs": [out.to_dict() for out in self.outputs],
+            "estimated_duration": self.estimated_duration,
+            "level": self.level,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
+        """Deserialize a step from a dict produced by to_dict()."""
+        inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])]
+        outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])]
+        return cls(
+            step_id=data["step_id"],
+            node_type=data["node_type"],
+            config=data.get("config", {}),
+            input_steps=data.get("input_steps", []),
+            inputs=inputs,
+            cache_id=data.get("cache_id"),
+            outputs=outputs,
+            name=data.get("name"),
+            estimated_duration=data.get("estimated_duration"),
+            level=data.get("level", 0),
+        )
+
+    def to_json(self) -> str:
+        """Serialize to a JSON string."""
+        return json.dumps(self.to_dict())
+
+    @classmethod
+    def from_json(cls, json_str: str) -> "ExecutionStep":
+        """Deserialize from a JSON string."""
+        return cls.from_dict(json.loads(json_str))
+
+
+@dataclass
+class PlanInput:
+    """
+    An input to the execution plan.
+
+    Attributes:
+        name: Human-readable name from recipe (e.g., "source_video")
+        cache_id: Content hash of the input file
+        cid: Same as cache_id (for clarity)
+        media_type: MIME type of the input
+    """
+    name: str
+    cache_id: str
+    cid: str
+    media_type: str = "application/octet-stream"
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Serialize to a plain dict (inverse of from_dict)."""
+        return {
+            "name": self.name,
+            "cache_id": self.cache_id,
+            "cid": self.cid,
+            "media_type": self.media_type,
+        }
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "PlanInput":
+        """Deserialize from a dict; cid falls back to cache_id when absent."""
+        return cls(
+            name=data["name"],
+            cache_id=data["cache_id"],
+            cid=data.get("cid", data["cache_id"]),
+            media_type=data.get("media_type", "application/octet-stream"),
+        )
+
+
+@dataclass
+class ExecutionPlan:
+    """
+    Complete execution plan for a recipe.
+
+    Contains all steps in topological order with pre-computed cache IDs.
+    The plan is deterministic: same recipe + same inputs = same plan.
+
+    Attributes:
+        name: Human-readable plan name from recipe
+        plan_id: Hash of the entire plan (for deduplication)
+        recipe_id: Source recipe identifier
+        recipe_name: Human-readable recipe name
+        recipe_hash: Hash of the recipe content
+        seed: Random seed used for planning
+        steps: List of steps in execution order
+        output_step: ID of the final output step
+        output_name: Human-readable name of the final output
+        inputs: Structured input definitions
+        analysis_cache_ids: Cache IDs of analysis results used
+        input_hashes: Content hashes of input files (legacy, use inputs)
+        created_at: When the plan was generated
+        metadata: Optional additional metadata
+    """
+    plan_id: Optional[str]
+    recipe_id: str
+    recipe_hash: str
+    steps: List[ExecutionStep]
+    output_step: str
+    name: Optional[str] = None
+    recipe_name: Optional[str] = None
+    seed: Optional[int] = None
+    output_name: Optional[str] = None
+    inputs: List[PlanInput] = field(default_factory=list)
+    analysis_cache_ids: Dict[str, str] = field(default_factory=dict)
+    input_hashes: Dict[str, str] = field(default_factory=dict)
+    created_at: Optional[str] = None
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        # Only fill in derived fields when absent, so deserialized plans
+        # keep their original timestamp and plan_id.
+        if self.created_at is None:
+            self.created_at = datetime.now(timezone.utc).isoformat()
+        if self.plan_id is None:
+            self.plan_id = self._compute_plan_id()
+
+    def _compute_plan_id(self) -> str:
+        """Compute plan ID from contents.
+
+        created_at is not part of the hashed content, so regenerating the
+        same plan at a different time yields the same plan_id.
+        """
+        content = {
+            "recipe_hash": self.recipe_hash,
+            "steps": [s.to_dict() for s in self.steps],
+            "input_hashes": self.input_hashes,
+            "analysis_cache_ids": self.analysis_cache_ids,
+        }
+        return _stable_hash(content)
+
+    def compute_all_cache_ids(self) -> None:
+        """
+        Compute cache IDs for all steps in dependency order.
+
+        Must be called after all steps are added to ensure
+        cache IDs propagate correctly through dependencies.
+        """
+        # Build step lookup
+        step_by_id = {s.step_id: s for s in self.steps}
+
+        # Cache IDs start with input hashes
+        cache_ids = dict(self.input_hashes)
+
+        # Process in order (assumes steps are already topologically sorted)
+        for step in self.steps:
+            # For SOURCE steps referencing user inputs, the cache_id IS the
+            # input content hash (no recomputation).
+            if step.node_type == "SOURCE" and step.config.get("input_ref"):
+                ref = step.config["input_ref"]
+                if ref in self.input_hashes:
+                    step.cache_id = self.input_hashes[ref]
+                    cache_ids[step.step_id] = step.cache_id
+                    continue
+
+            # For other steps, compute from the resolved input cache IDs
+            input_cache_ids = {}
+            for input_step_id in step.input_steps:
+                if input_step_id in cache_ids:
+                    input_cache_ids[input_step_id] = cache_ids[input_step_id]
+                elif input_step_id in step_by_id:
+                    # Step should have been processed already
+                    input_cache_ids[input_step_id] = step_by_id[input_step_id].cache_id
+                else:
+                    raise ValueError(f"Input step {input_step_id} not found for {step.step_id}")
+
+            step.compute_cache_id(input_cache_ids)
+            cache_ids[step.step_id] = step.cache_id
+
+        # Recompute plan_id with final cache IDs
+        self.plan_id = self._compute_plan_id()
+
+    def compute_levels(self) -> int:
+        """
+        Compute dependency levels for all steps.
+ + Level 0 = no dependencies (can start immediately) + Level N = depends on steps at level N-1 + + Returns: + Maximum level (number of sequential dependency levels) + """ + step_by_id = {s.step_id: s for s in self.steps} + levels = {} + + def compute_level(step_id: str) -> int: + if step_id in levels: + return levels[step_id] + + step = step_by_id.get(step_id) + if step is None: + return 0 # Input from outside the plan + + if not step.input_steps: + levels[step_id] = 0 + step.level = 0 + return 0 + + max_input_level = max(compute_level(s) for s in step.input_steps) + level = max_input_level + 1 + levels[step_id] = level + step.level = level + return level + + for step in self.steps: + compute_level(step.step_id) + + return max(levels.values()) if levels else 0 + + def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]: + """ + Group steps by dependency level. + + Steps at the same level can execute in parallel. + + Returns: + Dict mapping level -> list of steps at that level + """ + by_level: Dict[int, List[ExecutionStep]] = {} + for step in self.steps: + by_level.setdefault(step.level, []).append(step) + return by_level + + def get_step(self, step_id: str) -> Optional[ExecutionStep]: + """Get step by ID.""" + for step in self.steps: + if step.step_id == step_id: + return step + return None + + def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]: + """Get step by cache ID.""" + for step in self.steps: + if step.cache_id == cache_id: + return step + return None + + def get_step_by_name(self, name: str) -> Optional[ExecutionStep]: + """Get step by human-readable name.""" + for step in self.steps: + if step.name == name: + return step + return None + + def get_all_outputs(self) -> Dict[str, StepOutput]: + """ + Get all outputs from all steps, keyed by output name. 
+ + Returns: + Dict mapping output name -> StepOutput + """ + outputs = {} + for step in self.steps: + for output in step.outputs: + outputs[output.name] = output + return outputs + + def get_output_cache_ids(self) -> Dict[str, str]: + """ + Get mapping of output names to cache IDs. + + Returns: + Dict mapping output name -> cache_id + """ + return { + output.name: output.cache_id + for step in self.steps + for output in step.outputs + } + + def to_dict(self) -> Dict[str, Any]: + return { + "plan_id": self.plan_id, + "name": self.name, + "recipe_id": self.recipe_id, + "recipe_name": self.recipe_name, + "recipe_hash": self.recipe_hash, + "seed": self.seed, + "inputs": [i.to_dict() for i in self.inputs], + "steps": [s.to_dict() for s in self.steps], + "output_step": self.output_step, + "output_name": self.output_name, + "analysis_cache_ids": self.analysis_cache_ids, + "input_hashes": self.input_hashes, + "created_at": self.created_at, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan": + inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])] + return cls( + plan_id=data.get("plan_id"), + name=data.get("name"), + recipe_id=data["recipe_id"], + recipe_name=data.get("recipe_name"), + recipe_hash=data["recipe_hash"], + seed=data.get("seed"), + inputs=inputs, + steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])], + output_step=data["output_step"], + output_name=data.get("output_name"), + analysis_cache_ids=data.get("analysis_cache_ids", {}), + input_hashes=data.get("input_hashes", {}), + created_at=data.get("created_at"), + metadata=data.get("metadata", {}), + ) + + def to_json(self, indent: int = 2) -> str: + return json.dumps(self.to_dict(), indent=indent) + + @classmethod + def from_json(cls, json_str: str) -> "ExecutionPlan": + return cls.from_dict(json.loads(json_str)) + + def summary(self) -> str: + """Get a human-readable summary of the plan.""" + by_level = 
self.get_steps_by_level() + max_level = max(by_level.keys()) if by_level else 0 + + lines = [ + f"Execution Plan: {self.plan_id[:16]}...", + f"Recipe: {self.recipe_id}", + f"Steps: {len(self.steps)}", + f"Levels: {max_level + 1}", + "", + ] + + for level in sorted(by_level.keys()): + steps = by_level[level] + lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)") + for step in steps: + cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]" + lines.append(f" - {step.step_id}: {step.node_type} {cache_status}") + + return "\n".join(lines) diff --git a/artdag/planning/tree_reduction.py b/artdag/planning/tree_reduction.py new file mode 100644 index 0000000..3ab4147 --- /dev/null +++ b/artdag/planning/tree_reduction.py @@ -0,0 +1,231 @@ +# artdag/planning/tree_reduction.py +""" +Tree reduction for parallel composition. + +Instead of sequential pairwise composition: + A → AB → ABC → ABCD (3 sequential steps) + +Use parallel tree reduction: + A ─┬─ AB ─┬─ ABCD + B ─┘ │ + C ─┬─ CD ─┘ + D ─┘ + +This reduces O(N) to O(log N) levels of sequential dependency. +""" + +import math +from dataclasses import dataclass +from typing import List, Tuple, Any, Dict + + +@dataclass +class ReductionNode: + """A node in the reduction tree.""" + node_id: str + input_ids: List[str] + level: int + position: int # Position within level + + +class TreeReducer: + """ + Generates tree reduction plans for parallel composition. + + Used to convert N inputs into optimal parallel SEQUENCE operations. + """ + + def __init__(self, node_type: str = "SEQUENCE"): + """ + Initialize the reducer. + + Args: + node_type: The composition node type (SEQUENCE, AUDIO_MIX, etc.) + """ + self.node_type = node_type + + def reduce( + self, + input_ids: List[str], + id_prefix: str = "reduce", + ) -> Tuple[List[ReductionNode], str]: + """ + Generate a tree reduction plan for the given inputs. 
+ + Args: + input_ids: List of input step IDs to reduce + id_prefix: Prefix for generated node IDs + + Returns: + Tuple of (list of reduction nodes, final output node ID) + """ + if len(input_ids) == 0: + raise ValueError("Cannot reduce empty input list") + + if len(input_ids) == 1: + # Single input, no reduction needed + return [], input_ids[0] + + if len(input_ids) == 2: + # Two inputs, single reduction + node_id = f"{id_prefix}_final" + node = ReductionNode( + node_id=node_id, + input_ids=input_ids, + level=0, + position=0, + ) + return [node], node_id + + # Build tree levels + nodes = [] + current_level = list(input_ids) + level_num = 0 + + while len(current_level) > 1: + next_level = [] + position = 0 + + # Pair up nodes at current level + i = 0 + while i < len(current_level): + if i + 1 < len(current_level): + # Pair available + left = current_level[i] + right = current_level[i + 1] + node_id = f"{id_prefix}_L{level_num}_P{position}" + node = ReductionNode( + node_id=node_id, + input_ids=[left, right], + level=level_num, + position=position, + ) + nodes.append(node) + next_level.append(node_id) + i += 2 + else: + # Odd one out, promote to next level + next_level.append(current_level[i]) + i += 1 + + position += 1 + + current_level = next_level + level_num += 1 + + # The last remaining node is the output + output_id = current_level[0] + + # Rename final node for clarity + if nodes and nodes[-1].node_id == output_id: + nodes[-1].node_id = f"{id_prefix}_final" + output_id = f"{id_prefix}_final" + + return nodes, output_id + + def get_reduction_depth(self, n: int) -> int: + """ + Calculate the number of reduction levels needed. + + Args: + n: Number of inputs + + Returns: + Number of sequential reduction levels (log2(n) ceiling) + """ + if n <= 1: + return 0 + return math.ceil(math.log2(n)) + + def get_total_operations(self, n: int) -> int: + """ + Calculate total number of reduction operations. 
+ + Args: + n: Number of inputs + + Returns: + Total composition operations (always n-1) + """ + return max(0, n - 1) + + def reduce_with_config( + self, + input_ids: List[str], + base_config: Dict[str, Any], + id_prefix: str = "reduce", + ) -> Tuple[List[Tuple[ReductionNode, Dict[str, Any]]], str]: + """ + Generate reduction plan with configuration for each node. + + Args: + input_ids: List of input step IDs + base_config: Base configuration to use for each reduction + id_prefix: Prefix for generated node IDs + + Returns: + Tuple of (list of (node, config) pairs, final output ID) + """ + nodes, output_id = self.reduce(input_ids, id_prefix) + result = [(node, dict(base_config)) for node in nodes] + return result, output_id + + +def reduce_sequence( + input_ids: List[str], + transition_config: Dict[str, Any] = None, + id_prefix: str = "seq", +) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]: + """ + Convenience function for SEQUENCE reduction. + + Args: + input_ids: Input step IDs to sequence + transition_config: Transition configuration (default: cut) + id_prefix: Prefix for generated step IDs + + Returns: + Tuple of (list of (step_id, inputs, config), final step ID) + """ + if transition_config is None: + transition_config = {"transition": {"type": "cut"}} + + reducer = TreeReducer("SEQUENCE") + nodes, output_id = reducer.reduce(input_ids, id_prefix) + + result = [ + (node.node_id, node.input_ids, dict(transition_config)) + for node in nodes + ] + + return result, output_id + + +def reduce_audio_mix( + input_ids: List[str], + mix_config: Dict[str, Any] = None, + id_prefix: str = "mix", +) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]: + """ + Convenience function for AUDIO_MIX reduction. 
+ + Args: + input_ids: Input step IDs to mix + mix_config: Mix configuration + id_prefix: Prefix for generated step IDs + + Returns: + Tuple of (list of (step_id, inputs, config), final step ID) + """ + if mix_config is None: + mix_config = {"normalize": True} + + reducer = TreeReducer("AUDIO_MIX") + nodes, output_id = reducer.reduce(input_ids, id_prefix) + + result = [ + (node.node_id, node.input_ids, dict(mix_config)) + for node in nodes + ] + + return result, output_id diff --git a/artdag/registry/__init__.py b/artdag/registry/__init__.py new file mode 100644 index 0000000..3163387 --- /dev/null +++ b/artdag/registry/__init__.py @@ -0,0 +1,20 @@ +# primitive/registry/__init__.py +""" +Art DAG Registry. + +The registry is the foundational data structure that maps named assets +to their source paths or content-addressed IDs. Assets in the registry +can be referenced by DAGs. + +Example: + registry = Registry("/path/to/registry") + registry.add("cat", "/path/to/cat.jpg", tags=["animal", "photo"]) + + # Later, in a DAG: + builder = DAGBuilder() + cat = builder.source(registry.get("cat").path) +""" + +from .registry import Registry, Asset + +__all__ = ["Registry", "Asset"] diff --git a/artdag/registry/registry.py b/artdag/registry/registry.py new file mode 100644 index 0000000..3290411 --- /dev/null +++ b/artdag/registry/registry.py @@ -0,0 +1,294 @@ +# primitive/registry/registry.py +""" +Asset registry for the Art DAG. + +The registry stores named assets with metadata, enabling: +- Named references to source files +- Tagging and categorization +- Content-addressed deduplication +- Asset discovery and search +""" + +import hashlib +import json +import shutil +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + + +def _file_hash(path: Path, algorithm: str = "sha3_256") -> str: + """ + Compute content hash of a file. + + Uses SHA-3 (Keccak) by default for quantum resistance. 
+ SHA-3-256 provides 128-bit security against quantum attacks (Grover's algorithm). + + Args: + path: File to hash + algorithm: Hash algorithm (sha3_256, sha3_512, sha256, blake2b) + + Returns: + Full hex digest (no truncation) + """ + hasher = hashlib.new(algorithm) + with open(path, "rb") as f: + for chunk in iter(lambda: f.read(65536), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +@dataclass +class Asset: + """ + A registered asset in the Art DAG. + + The cid is the true identifier. URL and local_path are + locations where the content can be fetched. + + Attributes: + name: Unique name for the asset + cid: SHA-3-256 hash - the canonical identifier + url: Public URL (canonical location) + local_path: Optional local path (for local execution) + asset_type: Type of asset (image, video, audio, etc.) + tags: List of tags for categorization + metadata: Additional metadata (dimensions, duration, etc.) + created_at: Timestamp when added to registry + """ + name: str + cid: str + url: Optional[str] = None + local_path: Optional[Path] = None + asset_type: str = "unknown" + tags: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + created_at: float = field(default_factory=time.time) + + @property + def path(self) -> Optional[Path]: + """Backwards compatible path property.""" + return self.local_path + + def to_dict(self) -> Dict[str, Any]: + data = { + "name": self.name, + "cid": self.cid, + "asset_type": self.asset_type, + "tags": self.tags, + "metadata": self.metadata, + "created_at": self.created_at, + } + if self.url: + data["url"] = self.url + if self.local_path: + data["local_path"] = str(self.local_path) + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Asset": + local_path = data.get("local_path") or data.get("path") # backwards compat + return cls( + name=data["name"], + cid=data["cid"], + url=data.get("url"), + local_path=Path(local_path) if local_path else None, + 
asset_type=data.get("asset_type", "unknown"), + tags=data.get("tags", []), + metadata=data.get("metadata", {}), + created_at=data.get("created_at", time.time()), + ) + + +class Registry: + """ + The Art DAG registry. + + Stores named assets that can be referenced by DAGs. + + Structure: + registry_dir/ + registry.json # Index of all assets + assets/ # Optional: copied asset files + / + + """ + + def __init__(self, registry_dir: Path | str, copy_assets: bool = False): + """ + Initialize the registry. + + Args: + registry_dir: Directory to store registry data + copy_assets: If True, copy assets into registry (content-addressed) + """ + self.registry_dir = Path(registry_dir) + self.registry_dir.mkdir(parents=True, exist_ok=True) + self.copy_assets = copy_assets + self._assets: Dict[str, Asset] = {} + self._load() + + def _index_path(self) -> Path: + return self.registry_dir / "registry.json" + + def _assets_dir(self) -> Path: + return self.registry_dir / "assets" + + def _load(self): + """Load registry from disk.""" + index_path = self._index_path() + if index_path.exists(): + with open(index_path) as f: + data = json.load(f) + self._assets = { + name: Asset.from_dict(asset_data) + for name, asset_data in data.get("assets", {}).items() + } + + def _save(self): + """Save registry to disk.""" + data = { + "version": "1.0", + "assets": {name: asset.to_dict() for name, asset in self._assets.items()}, + } + with open(self._index_path(), "w") as f: + json.dump(data, f, indent=2) + + def add( + self, + name: str, + cid: str, + url: str = None, + local_path: Path | str = None, + asset_type: str = None, + tags: List[str] = None, + metadata: Dict[str, Any] = None, + ) -> Asset: + """ + Add an asset to the registry. 
+ + Args: + name: Unique name for the asset + cid: SHA-3-256 hash of the content (the canonical identifier) + url: Public URL where the asset can be fetched + local_path: Optional local path (for local execution) + asset_type: Type of asset (image, video, audio, etc.) + tags: List of tags for categorization + metadata: Additional metadata + + Returns: + The created Asset + """ + # Auto-detect asset type from URL or path extension + if asset_type is None: + ext = None + if url: + ext = Path(url.split("?")[0]).suffix.lower() + elif local_path: + ext = Path(local_path).suffix.lower() + if ext: + type_map = { + ".jpg": "image", ".jpeg": "image", ".png": "image", + ".gif": "image", ".webp": "image", ".bmp": "image", + ".mp4": "video", ".mkv": "video", ".avi": "video", + ".mov": "video", ".webm": "video", + ".mp3": "audio", ".wav": "audio", ".flac": "audio", + ".ogg": "audio", ".aac": "audio", + } + asset_type = type_map.get(ext, "unknown") + else: + asset_type = "unknown" + + asset = Asset( + name=name, + cid=cid, + url=url, + local_path=Path(local_path).resolve() if local_path else None, + asset_type=asset_type, + tags=tags or [], + metadata=metadata or {}, + ) + + self._assets[name] = asset + self._save() + return asset + + def add_from_file( + self, + name: str, + path: Path | str, + url: str = None, + asset_type: str = None, + tags: List[str] = None, + metadata: Dict[str, Any] = None, + ) -> Asset: + """ + Add an asset from a local file (computes hash automatically). 
+ + Args: + name: Unique name for the asset + path: Path to the source file + url: Optional public URL + asset_type: Type of asset (auto-detected if not provided) + tags: List of tags for categorization + metadata: Additional metadata + + Returns: + The created Asset + """ + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Asset file not found: {path}") + + cid = _file_hash(path) + + return self.add( + name=name, + cid=cid, + url=url, + local_path=path, + asset_type=asset_type, + tags=tags, + metadata=metadata, + ) + + def get(self, name: str) -> Optional[Asset]: + """Get an asset by name.""" + return self._assets.get(name) + + def remove(self, name: str) -> bool: + """Remove an asset from the registry.""" + if name not in self._assets: + return False + del self._assets[name] + self._save() + return True + + def list(self) -> List[Asset]: + """List all assets.""" + return list(self._assets.values()) + + def find_by_tag(self, tag: str) -> List[Asset]: + """Find assets with a specific tag.""" + return [a for a in self._assets.values() if tag in a.tags] + + def find_by_type(self, asset_type: str) -> List[Asset]: + """Find assets of a specific type.""" + return [a for a in self._assets.values() if a.asset_type == asset_type] + + def find_by_hash(self, cid: str) -> Optional[Asset]: + """Find an asset by content hash.""" + for asset in self._assets.values(): + if asset.cid == cid: + return asset + return None + + def __contains__(self, name: str) -> bool: + return name in self._assets + + def __len__(self) -> int: + return len(self._assets) + + def __iter__(self): + return iter(self._assets.values()) diff --git a/artdag/server.py b/artdag/server.py new file mode 100644 index 0000000..f10374c --- /dev/null +++ b/artdag/server.py @@ -0,0 +1,253 @@ +# primitive/server.py +""" +HTTP server for primitive execution engine. + +Provides a REST API for submitting DAGs and retrieving results. 
+ +Endpoints: + POST /execute - Submit DAG for execution + GET /status/:id - Get execution status + GET /result/:id - Get execution result + GET /cache/stats - Get cache statistics + DELETE /cache - Clear cache +""" + +import json +import logging +import threading +import uuid +from dataclasses import dataclass, field +from http.server import HTTPServer, BaseHTTPRequestHandler +from pathlib import Path +from typing import Any, Dict, Optional +from urllib.parse import urlparse + +from .dag import DAG +from .engine import Engine, ExecutionResult +from . import nodes # Register built-in executors + +logger = logging.getLogger(__name__) + + +@dataclass +class Job: + """A pending or completed execution job.""" + job_id: str + dag: DAG + status: str = "pending" # pending, running, completed, failed + result: Optional[ExecutionResult] = None + error: Optional[str] = None + + +class PrimitiveServer: + """ + HTTP server for the primitive engine. + + Usage: + server = PrimitiveServer(cache_dir="/tmp/primitive_cache", port=8080) + server.start() # Blocking + """ + + def __init__(self, cache_dir: Path | str, host: str = "127.0.0.1", port: int = 8080): + self.cache_dir = Path(cache_dir) + self.host = host + self.port = port + self.engine = Engine(self.cache_dir) + self.jobs: Dict[str, Job] = {} + self._lock = threading.Lock() + + def submit_job(self, dag: DAG) -> str: + """Submit a DAG for execution, return job ID.""" + job_id = str(uuid.uuid4())[:8] + job = Job(job_id=job_id, dag=dag) + + with self._lock: + self.jobs[job_id] = job + + # Execute in background thread + thread = threading.Thread(target=self._execute_job, args=(job_id,)) + thread.daemon = True + thread.start() + + return job_id + + def _execute_job(self, job_id: str): + """Execute a job in background.""" + with self._lock: + job = self.jobs.get(job_id) + if not job: + return + job.status = "running" + + try: + result = self.engine.execute(job.dag) + with self._lock: + job.result = result + job.status = "completed" 
if result.success else "failed" + if not result.success: + job.error = result.error + except Exception as e: + logger.exception(f"Job {job_id} failed") + with self._lock: + job.status = "failed" + job.error = str(e) + + def get_job(self, job_id: str) -> Optional[Job]: + """Get job by ID.""" + with self._lock: + return self.jobs.get(job_id) + + def _create_handler(server_instance): + """Create request handler with access to server instance.""" + + class RequestHandler(BaseHTTPRequestHandler): + server_ref = server_instance + + def log_message(self, format, *args): + logger.debug(format % args) + + def _send_json(self, data: Any, status: int = 200): + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(data).encode()) + + def _send_error(self, message: str, status: int = 400): + self._send_json({"error": message}, status) + + def do_GET(self): + parsed = urlparse(self.path) + path = parsed.path + + if path.startswith("/status/"): + job_id = path[8:] + job = self.server_ref.get_job(job_id) + if not job: + self._send_error("Job not found", 404) + return + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "error": job.error, + }) + + elif path.startswith("/result/"): + job_id = path[8:] + job = self.server_ref.get_job(job_id) + if not job: + self._send_error("Job not found", 404) + return + if job.status == "pending" or job.status == "running": + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "ready": False, + }) + return + + result = job.result + self._send_json({ + "job_id": job.job_id, + "status": job.status, + "ready": True, + "success": result.success if result else False, + "output_path": str(result.output_path) if result and result.output_path else None, + "error": job.error, + "execution_time": result.execution_time if result else 0, + "nodes_executed": result.nodes_executed if result else 0, + "nodes_cached": result.nodes_cached if result else 
0, + }) + + elif path == "/cache/stats": + stats = self.server_ref.engine.get_cache_stats() + self._send_json({ + "total_entries": stats.total_entries, + "total_size_bytes": stats.total_size_bytes, + "hits": stats.hits, + "misses": stats.misses, + "hit_rate": stats.hit_rate, + }) + + elif path == "/health": + self._send_json({"status": "ok"}) + + else: + self._send_error("Not found", 404) + + def do_POST(self): + if self.path == "/execute": + try: + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length).decode() + data = json.loads(body) + + dag = DAG.from_dict(data) + job_id = self.server_ref.submit_job(dag) + + self._send_json({ + "job_id": job_id, + "status": "pending", + }) + except json.JSONDecodeError as e: + self._send_error(f"Invalid JSON: {e}") + except Exception as e: + self._send_error(str(e), 500) + else: + self._send_error("Not found", 404) + + def do_DELETE(self): + if self.path == "/cache": + self.server_ref.engine.clear_cache() + self._send_json({"status": "cleared"}) + else: + self._send_error("Not found", 404) + + return RequestHandler + + def start(self): + """Start the HTTP server (blocking).""" + handler = self._create_handler() + server = HTTPServer((self.host, self.port), handler) + logger.info(f"Primitive server starting on {self.host}:{self.port}") + print(f"Primitive server running on http://{self.host}:{self.port}") + try: + server.serve_forever() + except KeyboardInterrupt: + print("\nShutting down...") + server.shutdown() + + def start_background(self) -> threading.Thread: + """Start the server in a background thread.""" + thread = threading.Thread(target=self.start) + thread.daemon = True + thread.start() + return thread + + +def main(): + """CLI entry point.""" + import argparse + + parser = argparse.ArgumentParser(description="Primitive execution server") + parser.add_argument("--host", default="127.0.0.1", help="Host to bind to") + parser.add_argument("--port", type=int, default=8080, 
help="Port to bind to") + parser.add_argument("--cache-dir", default="/tmp/primitive_cache", help="Cache directory") + parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging") + + args = parser.parse_args() + + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + ) + + server = PrimitiveServer( + cache_dir=args.cache_dir, + host=args.host, + port=args.port, + ) + server.start() + + +if __name__ == "__main__": + main() diff --git a/artdag/sexp/__init__.py b/artdag/sexp/__init__.py new file mode 100644 index 0000000..08b646f --- /dev/null +++ b/artdag/sexp/__init__.py @@ -0,0 +1,75 @@ +""" +S-expression parsing, compilation, and planning for ArtDAG. + +This module provides: +- parser: Parse S-expression text into Python data structures +- compiler: Compile recipe S-expressions into DAG format +- planner: Generate execution plans from recipes +""" + +from .parser import ( + parse, + parse_all, + serialize, + Symbol, + Keyword, + ParseError, +) + +from .compiler import ( + compile_recipe, + compile_string, + CompiledRecipe, + CompileError, + ParamDef, + _parse_params, +) + +from .planner import ( + create_plan, + ExecutionPlanSexp, + PlanStep, + step_to_task_sexp, + task_cache_id, +) + +from .scheduler import ( + PlanScheduler, + PlanResult, + StepResult, + schedule_plan, + step_to_sexp, + step_sexp_to_string, + verify_step_cache_id, +) + +__all__ = [ + # Parser + 'parse', + 'parse_all', + 'serialize', + 'Symbol', + 'Keyword', + 'ParseError', + # Compiler + 'compile_recipe', + 'compile_string', + 'CompiledRecipe', + 'CompileError', + 'ParamDef', + '_parse_params', + # Planner + 'create_plan', + 'ExecutionPlanSexp', + 'PlanStep', + 'step_to_task_sexp', + 'task_cache_id', + # Scheduler + 'PlanScheduler', + 'PlanResult', + 'StepResult', + 'schedule_plan', + 'step_to_sexp', + 'step_sexp_to_string', + 'verify_step_cache_id', +] diff --git 
a/artdag/sexp/compiler.py b/artdag/sexp/compiler.py new file mode 100644 index 0000000..9729312 --- /dev/null +++ b/artdag/sexp/compiler.py @@ -0,0 +1,2463 @@ +""" +Compiler for S-expression recipes. + +Transforms S-expression recipes into internal DAG format. +Handles: +- Threading macro expansion (->) +- def bindings for named nodes +- Registry resolution (assets, effects) +- Node ID generation (content-addressed) +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple +import hashlib +import json + +from .parser import Symbol, Keyword, Lambda, parse, serialize +from pathlib import Path + + +def compute_content_cid(content: str) -> str: + """Compute content-addressed ID (SHA256 hash) for content. + + This is used for effects, recipes, and other text content that + will be stored in the cache. The cid can be used to fetch the + content from cache or IPFS. + """ + return hashlib.sha256(content.encode()).hexdigest() + + +def compute_file_cid(file_path: Path) -> str: + """Compute content-addressed ID for a file. 

    Args:
        file_path: Path to the file

    Returns:
        SHA-256 hex digest of the file's text contents (delegates to
        compute_content_cid, which uses hashlib.sha256 — not SHA3)

    Raises:
        FileNotFoundError: If the file does not exist
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")
    content = file_path.read_text()
    return compute_content_cid(content)


def _serialize_for_hash(obj) -> str:
    """Serialize any value to canonical S-expression string for hashing."""
    if obj is None:
        return "nil"
    if isinstance(obj, bool):
        return "true" if obj else "false"
    if isinstance(obj, (int, float)):
        return str(obj)
    if isinstance(obj, str):
        escaped = obj.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    if isinstance(obj, Symbol):
        return obj.name
    if isinstance(obj, Keyword):
        return f":{obj.name}"
    if isinstance(obj, Lambda):
        params = " ".join(obj.params)
        body = _serialize_for_hash(obj.body)
        return f"(fn [{params}] {body})"
    if isinstance(obj, dict):
        # Keys are sorted so serialization (and hence hashes) are
        # independent of dict insertion order.
        items = []
        for k, v in sorted(obj.items()):
            items.append(f":{k} {_serialize_for_hash(v)}")
        return "{" + " ".join(items) + "}"
    if isinstance(obj, list):
        items = [_serialize_for_hash(x) for x in obj]
        return "(" + " ".join(items) + ")"
    return str(obj)


class CompileError(Exception):
    """Error during recipe compilation."""
    pass


@dataclass
class ParamDef:
    """Definition of a recipe parameter."""
    name: str
    param_type: str  # "string", "int", "float", "bool"
    default: Any
    description: str = ""
    range_min: Optional[float] = None
    range_max: Optional[float] = None
    choices: Optional[List[str]] = None  # For enum-like params


@dataclass
class CompiledStage:
    """A compiled stage with dependencies and outputs."""
    name: str
    requires: List[str]  # Names of required stages
    inputs: List[str]  # Names of bindings consumed from required stages
    outputs: List[str]  # Names of bindings produced by this stage
    node_ids: List[str]  # Node IDs created in this stage
    output_bindings: Dict[str, str]  # output_name -> node_id mapping


@dataclass
class 
@dataclass
class CompilerContext:
    """Compilation context tracking bindings and nodes.

    Accumulates state while a recipe is compiled: the registry of external
    resources the recipe declared, name -> node_id bindings, the
    content-addressed node table, and stage-scoping bookkeeping for
    recipes that use (stage ...) forms.
    """
    # Registry of declared external resources, keyed by category.
    registry: Dict[str, Dict[str, Any]] = field(default_factory=lambda: {"assets": {}, "effects": {}, "analyzers": {}, "constructs": {}, "templates": {}, "includes": {}})
    # Counter used when expanding template calls (incremented elsewhere).
    template_call_count: int = 0
    bindings: Dict[str, str] = field(default_factory=dict)  # name -> node_id
    nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict)  # node_id -> node

    # Recipe directory for resolving relative paths
    recipe_dir: Optional[Path] = None

    # Stage tracking
    current_stage: Optional[str] = None  # Name of stage currently being compiled
    defined_stages: Dict[str, 'CompiledStage'] = field(default_factory=dict)  # stage_name -> CompiledStage
    stage_bindings: Dict[str, Dict[str, str]] = field(default_factory=dict)  # stage_name -> {binding_name -> node_id}
    pre_stage_bindings: Dict[str, Any] = field(default_factory=dict)  # bindings defined before any stage
    stage_node_ids: List[str] = field(default_factory=list)  # node IDs created in current stage

    def add_node(self, node_type: str, config: Dict[str, Any],
                 inputs: Optional[List[str]] = None, name: Optional[str] = None) -> str:
        """
        Add a node and return its code-addressed ID.

        The node_id is a hash of the S-expression subtree (type, config, inputs),
        creating a Merkle-tree like a blockchain - each node's hash includes all
        upstream hashes. This is computed purely from the plan, before execution.

        The node_id is a pre-computed "bucket" where the computation result will
        be stored. Same plan = same buckets = automatic cache reuse.

        Args:
            node_type: Node kind (e.g. "CONSTRUCT").
            config: Node configuration; participates in the hash.
            inputs: Upstream node IDs (already content-addressed hashes).
            name: Optional human-readable binding name; NOT part of the hash.

        Returns:
            64-char hex SHA3-256 digest identifying the node.
        """
        # Build canonical S-expression for hashing
        # Inputs are already code-addressed node IDs (hashes)
        canonical = {
            "type": node_type,
            "config": config,
            "inputs": inputs or [],
        }
        # Hash the canonical S-expression form using SHA3-256
        canonical_sexp = _serialize_for_hash(canonical)
        node_id = hashlib.sha3_256(canonical_sexp.encode()).hexdigest()

        # Check for collision (same hash = same computation, reuse)
        if node_id in self.nodes:
            return node_id

        self.nodes[node_id] = {
            "id": node_id,
            "type": node_type,
            "config": config,
            "inputs": inputs or [],
            "name": name,
        }

        # Track node in current stage
        if self.current_stage is not None:
            self.stage_node_ids.append(node_id)

        return node_id

    def get_accessible_bindings(self, stage_inputs: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Get bindings accessible to the current stage.

        If inside a stage with declared inputs, only those inputs plus pre-stage
        bindings are accessible. If outside a stage, all bindings are accessible.

        Raises:
            CompileError: if a declared input is not produced by any required
                stage and is not a pre-stage binding.
        """
        if self.current_stage is None:
            return dict(self.bindings)

        # Start with pre-stage bindings (sources, etc.)
        accessible = dict(self.pre_stage_bindings)

        # Add declared inputs from required stages
        if stage_inputs:
            for input_name in stage_inputs:
                # Look for the binding in required stages
                for stage_name, stage in self.defined_stages.items():
                    if input_name in stage.output_bindings:
                        accessible[input_name] = stage.output_bindings[input_name]
                        break
                else:
                    # for/else: no stage produced it — fall back to pre-stage bindings
                    # (might be a source); otherwise it's an error.
                    if input_name not in accessible:
                        raise CompileError(
                            f"Stage '{self.current_stage}' declares input '{input_name}' "
                            f"but it's not produced by any required stage"
                        )

        return accessible
+ """ + if self.current_stage is None: + return dict(self.bindings) + + # Start with pre-stage bindings (sources, etc.) + accessible = dict(self.pre_stage_bindings) + + # Add declared inputs from required stages + if stage_inputs: + for input_name in stage_inputs: + # Look for the binding in required stages + for stage_name, stage in self.defined_stages.items(): + if input_name in stage.output_bindings: + accessible[input_name] = stage.output_bindings[input_name] + break + else: + # Check if it's in pre-stage bindings (might be a source) + if input_name not in accessible: + raise CompileError( + f"Stage '{self.current_stage}' declares input '{input_name}' " + f"but it's not produced by any required stage" + ) + + return accessible + + +def _topological_sort_stages(stages: Dict[str, 'CompiledStage']) -> List[str]: + """ + Topologically sort stages by their dependencies. + + Returns list of stage names in execution order (dependencies first). + """ + if not stages: + return [] + + # Build dependency graph + in_degree = {name: 0 for name in stages} + dependents = {name: [] for name in stages} + + for name, stage in stages.items(): + for req in stage.requires: + if req in stages: + dependents[req].append(name) + in_degree[name] += 1 + + # Kahn's algorithm + queue = [name for name, degree in in_degree.items() if degree == 0] + result = [] + + while queue: + # Sort for deterministic ordering + queue.sort() + current = queue.pop(0) + result.append(current) + + for dependent in dependents[current]: + in_degree[dependent] -= 1 + if in_degree[dependent] == 0: + queue.append(dependent) + + if len(result) != len(stages): + # This shouldn't happen if we validated cycles earlier + missing = set(stages.keys()) - set(result) + raise CompileError(f"Circular stage dependency detected: {missing}") + + return result + + +def _parse_encoding(value: Any) -> Dict[str, Any]: + """ + Parse encoding settings from S-expression. 
+ + Expects a list like: (:codec "libx264" :crf 18 :preset "fast" :audio-codec "aac") + Returns: {"codec": "libx264", "crf": 18, "preset": "fast", "audio_codec": "aac"} + """ + if not isinstance(value, list): + raise CompileError(f"Encoding must be a list, got {type(value).__name__}") + + result = {} + i = 0 + while i < len(value): + item = value[i] + if isinstance(item, Keyword): + if i + 1 >= len(value): + raise CompileError(f"Encoding keyword {item.name} missing value") + # Convert kebab-case to snake_case for Python + key = item.name.replace("-", "_") + result[key] = value[i + 1] + i += 2 + else: + raise CompileError(f"Expected keyword in encoding, got {type(item).__name__}") + return result + + +def _parse_params(value: Any) -> List[ParamDef]: + """ + Parse parameter definitions from S-expression. + + Syntax: + :params ( + (param_name :type string :default "value" :desc "Description") + (param_name :type float :default 1.0 :range [0 10] :desc "Description") + (param_name :type string :default "a" :choices ["a" "b" "c"] :desc "Description") + ) + + Supported types: string, int, float, bool + Optional: :range [min max], :choices [...], :desc "..." 
def _parse_params(value: Any) -> List[ParamDef]:
    """
    Parse parameter definitions from S-expression.

    Syntax:
        :params (
            (param_name :type string :default "value" :desc "Description")
            (param_name :type float :default 1.0 :range [0 10] :desc "Description")
            (param_name :type string :default "a" :choices ["a" "b" "c"] :desc "Description")
        )

    Supported types: string, int, float, bool
    Optional: :range [min max], :choices [...], :desc "..."

    Returns:
        List of ParamDef, one per definition, with defaults coerced to the
        declared type.

    Raises:
        CompileError: on malformed definitions, bad :range/:choices values,
            or unknown keywords.
    """
    if not isinstance(value, list):
        raise CompileError(f"Params must be a list, got {type(value).__name__}")

    params = []
    for param_def in value:
        if not isinstance(param_def, list) or len(param_def) < 1:
            raise CompileError(f"Invalid param definition: {param_def}")

        # First element is the parameter name (symbol or string accepted)
        first = param_def[0]
        if isinstance(first, Symbol):
            param_name = first.name
        elif isinstance(first, str):
            param_name = first
        else:
            raise CompileError(f"Param name must be symbol or string, got {type(first).__name__}")

        # Parse keyword arguments; everything defaults to a plain string param
        param_type = "string"
        default = None
        desc = ""
        range_min = None
        range_max = None
        choices = None

        i = 1
        while i < len(param_def):
            item = param_def[i]
            if isinstance(item, Keyword):
                if i + 1 >= len(param_def):
                    raise CompileError(f"Param keyword {item.name} missing value")
                kw_value = param_def[i + 1]

                if item.name == "type":
                    if isinstance(kw_value, Symbol):
                        param_type = kw_value.name
                    else:
                        param_type = str(kw_value)
                elif item.name == "default":
                    # Convert nil symbol to Python None
                    if isinstance(kw_value, Symbol) and kw_value.name == "nil":
                        default = None
                    else:
                        default = kw_value
                elif item.name == "desc" or item.name == "description":
                    desc = str(kw_value)
                elif item.name == "range":
                    if isinstance(kw_value, list) and len(kw_value) >= 2:
                        range_min = float(kw_value[0])
                        range_max = float(kw_value[1])
                    else:
                        raise CompileError(f"Param range must be [min max], got {kw_value}")
                elif item.name == "choices":
                    # Choices are normalized to strings; symbols keep their name.
                    if isinstance(kw_value, list):
                        choices = [str(c) if not isinstance(c, Symbol) else c.name for c in kw_value]
                    else:
                        raise CompileError(f"Param choices must be a list, got {kw_value}")
                else:
                    raise CompileError(f"Unknown param keyword :{item.name}")
                i += 2
            else:
                # Non-keyword extras are silently skipped.
                i += 1

        # Convert default to appropriate type (a None default stays None,
        # so a declared nil default is never coerced).
        if default is not None:
            if param_type == "int":
                default = int(default)
            elif param_type == "float":
                default = float(default)
            elif param_type == "bool":
                if isinstance(default, (int, float)):
                    default = bool(default)
                elif isinstance(default, str):
                    default = default.lower() in ("true", "1", "yes")
            elif param_type == "string":
                default = str(default)

        params.append(ParamDef(
            name=param_name,
            param_type=param_type,
            default=default,
            description=desc,
            range_min=range_min,
            range_max=range_max,
            choices=choices,
        ))

    return params
def compile_recipe(sexp: Any, initial_bindings: Optional[Dict[str, Any]] = None, recipe_dir: Optional[Path] = None, source_text: str = "") -> CompiledRecipe:
    """
    Compile an S-expression recipe into internal format.

    Args:
        sexp: Parsed S-expression (list starting with 'recipe' symbol)
        initial_bindings: Optional dict of name -> value bindings to inject before compilation.
            These can be referenced as variables in the recipe and take
            precedence over declared parameter defaults.
        recipe_dir: Directory containing the recipe file, for resolving relative paths.
        source_text: Original source text for stable hashing.

    Returns:
        CompiledRecipe with nodes and registry

    Raises:
        CompileError: on malformed recipe structure, unknown keywords, or
            a recipe body that produces no output node.

    Example:
        >>> sexp = parse('(recipe "test" :version "1.0" (-> (source cat) (effect identity)))')
        >>> result = compile_recipe(sexp)
        >>> # With parameters:
        >>> result = compile_recipe(sexp, {"effect_num": 5})
    """
    if not isinstance(sexp, list) or len(sexp) < 2:
        raise CompileError("Recipe must be a list starting with 'recipe'")

    head = sexp[0]
    if not (isinstance(head, Symbol) and head.name == "recipe"):
        raise CompileError(f"Expected 'recipe', got {head}")

    # Extract recipe name
    if len(sexp) < 2 or not isinstance(sexp[1], str):
        raise CompileError("Recipe name must be a string")
    name = sexp[1]

    # Parse keyword arguments and body
    ctx = CompilerContext(recipe_dir=recipe_dir)

    version = "1.0"
    description = ""
    owner = None
    encoding = {}
    params = []
    body_exprs = []
    minimal_primitives = False

    # Walk the top level: keywords consume a value, anything else is body.
    i = 2
    while i < len(sexp):
        item = sexp[i]

        if isinstance(item, Keyword):
            if i + 1 >= len(sexp):
                raise CompileError(f"Keyword {item.name} missing value")
            value = sexp[i + 1]

            if item.name == "version":
                version = str(value)
            elif item.name == "description":
                description = str(value)
            elif item.name == "owner":
                owner = str(value)
            elif item.name == "encoding":
                encoding = _parse_encoding(value)
            elif item.name == "params":
                params = _parse_params(value)
            elif item.name == "minimal-primitives":
                # Handle boolean value (could be Symbol('true') or Python bool)
                if isinstance(value, Symbol):
                    minimal_primitives = value.name.lower() == "true"
                else:
                    minimal_primitives = bool(value)
            else:
                raise CompileError(f"Unknown keyword :{item.name}")
            i += 2
        else:
            # Body expression
            body_exprs.append(item)
            i += 1

    # Create bindings from params with their default values
    # Initial bindings override param defaults
    for param in params:
        if initial_bindings and param.name in initial_bindings:
            ctx.bindings[param.name] = initial_bindings[param.name]
        else:
            ctx.bindings[param.name] = param.default

    # Inject any additional initial bindings not covered by params
    if initial_bindings:
        for k, v in initial_bindings.items():
            if k not in ctx.bindings:
                ctx.bindings[k] = v

    # Compile body expressions
    # Track when we encounter the first stage to capture pre-stage bindings
    output_node_id = None
    first_stage_seen = False

    for expr in body_exprs:
        # Check if this is a stage form
        is_stage_form = (
            isinstance(expr, list) and
            len(expr) > 0 and
            isinstance(expr[0], Symbol) and
            expr[0].name == "stage"
        )

        # Before the first stage, capture bindings as pre-stage bindings
        # (snapshot, so later stage-local defs don't leak back).
        if is_stage_form and not first_stage_seen:
            first_stage_seen = True
            ctx.pre_stage_bindings = dict(ctx.bindings)

        # The last expression that yields a node wins as the recipe output.
        result = _compile_expr(expr, ctx)
        if result is not None:
            output_node_id = result

    if output_node_id is None:
        raise CompileError("Recipe has no output (no DAG expression)")

    # Build stage order (topological sort)
    stage_order = _topological_sort_stages(ctx.defined_stages)

    # Collect stages in order
    stages = [ctx.defined_stages[name] for name in stage_order]

    return CompiledRecipe(
        name=name,
        version=version,
        description=description,
        owner=owner,
        registry=ctx.registry,
        nodes=list(ctx.nodes.values()),
        output_node_id=output_node_id,
        encoding=encoding,
        params=params,
        stages=stages,
        stage_order=stage_order,
        minimal_primitives=minimal_primitives,
        source_text=source_text,
        resolved_params=initial_bindings or {},
    )
def _compile_expr(expr: Any, ctx: CompilerContext) -> Optional[str]:
    """
    Compile an expression, returning node_id if it produces a node.

    Central dispatch for all recipe forms. Declaration forms return None;
    node-producing forms return the content-addressed node_id. Dispatch
    order is significant: registered constructs are checked before the
    built-in slice-on so user constructs can shadow it.

    Handles:
    - (asset name :hash "..." :url "...")
    - (effect name :hash "..." :url "...")
    - (def name expr)
    - (-> expr expr ...)
    - (source ...), (effect ...), (sequence ...), etc.
    """
    if not isinstance(expr, list) or len(expr) == 0:
        # Atom - could be a reference
        if isinstance(expr, Symbol):
            # Look up binding
            if expr.name in ctx.bindings:
                return ctx.bindings[expr.name]
            raise CompileError(f"Undefined symbol: {expr.name}")
        # Non-symbol atoms (and empty lists) produce no node.
        return None

    head = expr[0]
    if not isinstance(head, Symbol):
        raise CompileError(f"Expected symbol at head of expression, got {head}")

    name = head.name

    # Registry declarations
    if name == "asset":
        return _compile_asset(expr, ctx)
    if name == "effect":
        return _compile_effect_decl(expr, ctx)
    if name == "analyzer":
        return _compile_analyzer_decl(expr, ctx)
    if name == "construct":
        return _compile_construct_decl(expr, ctx)

    # Template definition
    if name == "deftemplate":
        return _compile_deftemplate(expr, ctx)

    # Include - load and evaluate external sexp file
    if name == "include":
        return _compile_include(expr, ctx)

    # Binding
    if name == "def":
        return _compile_def(expr, ctx)

    # Stage form
    if name == "stage":
        return _compile_stage(expr, ctx)

    # Threading macro
    if name == "->":
        return _compile_threading(expr, ctx)

    # Node types
    if name == "source":
        return _compile_source(expr, ctx)
    if name in ("effect", "fx"):
        return _compile_effect_node(expr, ctx)
    if name == "segment":
        return _compile_segment(expr, ctx)
    if name == "resize":
        return _compile_resize(expr, ctx)
    if name == "sequence":
        return _compile_sequence(expr, ctx)
    # Note: layer and blend are now regular effects, not special forms
    # Use: (effect layer bg fg :x 0 :y 0) or (effect blend a b :mode "overlay")
    if name == "mux":
        return _compile_mux(expr, ctx)
    if name == "analyze":
        return _compile_analyze(expr, ctx)
    if name == "scan":
        return _compile_scan(expr, ctx)
    if name == "blend-multi":
        return _compile_blend_multi(expr, ctx)
    if name == "make-rng":
        return _compile_make_rng(expr, ctx)
    if name == "next-seed":
        return _compile_next_seed(expr, ctx)

    # Check if it's a registered construct call BEFORE built-in slice-on
    # This allows user-defined constructs to override built-ins
    if name in ctx.registry.get("constructs", {}):
        return _compile_construct_call(expr, ctx)

    if name == "slice-on":
        return _compile_slice_on(expr, ctx)

    # Binding expression for parameter linking
    if name == "bind":
        return _compile_bind(expr, ctx)

    # Pure functions that can be evaluated at compile time
    PURE_FUNCTIONS = {
        "max", "min", "floor", "ceil", "round", "abs",
        "+", "-", "*", "/", "mod", "sqrt", "pow",
        "len", "get", "first", "last", "nth",
        "=", "<", ">", "<=", ">=", "not=",
        "and", "or", "not",
        "inc", "dec",
        "chunk-every",
        "list", "dict",
        "assert",
    }
    if name in PURE_FUNCTIONS:
        # Evaluate using the evaluator
        from .evaluator import evaluate
        # Build env from ctx.bindings
        env = dict(ctx.bindings)
        try:
            result = evaluate(expr, env)
            return result
        except Exception as e:
            # NOTE(review): consider `raise ... from e` to keep the chained
            # traceback — current code drops the original cause context.
            raise CompileError(f"Error evaluating {name}: {e}")

    # Template invocation
    if name in ctx.registry.get("templates", {}):
        return _compile_template_call(expr, ctx)

    raise CompileError(f"Unknown expression type: {name}")
it's a registered construct call BEFORE built-in slice-on + # This allows user-defined constructs to override built-ins + if name in ctx.registry.get("constructs", {}): + return _compile_construct_call(expr, ctx) + + if name == "slice-on": + return _compile_slice_on(expr, ctx) + + # Binding expression for parameter linking + if name == "bind": + return _compile_bind(expr, ctx) + + # Pure functions that can be evaluated at compile time + PURE_FUNCTIONS = { + "max", "min", "floor", "ceil", "round", "abs", + "+", "-", "*", "/", "mod", "sqrt", "pow", + "len", "get", "first", "last", "nth", + "=", "<", ">", "<=", ">=", "not=", + "and", "or", "not", + "inc", "dec", + "chunk-every", + "list", "dict", + "assert", + } + if name in PURE_FUNCTIONS: + # Evaluate using the evaluator + from .evaluator import evaluate + # Build env from ctx.bindings + env = dict(ctx.bindings) + try: + result = evaluate(expr, env) + return result + except Exception as e: + raise CompileError(f"Error evaluating {name}: {e}") + + # Template invocation + if name in ctx.registry.get("templates", {}): + return _compile_template_call(expr, ctx) + + raise CompileError(f"Unknown expression type: {name}") + + +def _parse_kwargs(expr: List, start: int = 1) -> Tuple[List[Any], Dict[str, Any]]: + """ + Parse positional args and keyword args from expression. + + Returns (positional_args, keyword_dict) + """ + positional = [] + kwargs = {} + + i = start + while i < len(expr): + item = expr[i] + if isinstance(item, Keyword): + if i + 1 >= len(expr): + raise CompileError(f"Keyword :{item.name} missing value") + kwargs[item.name] = expr[i + 1] + i += 2 + else: + positional.append(item) + i += 1 + + return positional, kwargs + + +def _compile_asset(expr: List, ctx: CompilerContext) -> None: + """Compile (asset name :cid "..." 
:url "...") or legacy (asset name :hash "...")""" + if len(expr) < 2: + raise CompileError("asset requires a name") + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + _, kwargs = _parse_kwargs(expr, 2) + + # Support both :cid (new IPFS) and :hash (legacy SHA3-256) + asset_cid = kwargs.get("cid") or kwargs.get("hash") + if not asset_cid: + raise CompileError(f"asset {name} requires :cid or :hash") + + ctx.registry["assets"][name] = { + "cid": asset_cid, + "url": kwargs.get("url"), + } + return None + + +def _resolve_effect_path(path: str, ctx: CompilerContext) -> Optional[Path]: + """Resolve an effect path relative to recipe directory. + + Args: + path: Relative or absolute path to effect file + ctx: Compiler context with recipe_dir + + Returns: + Resolved absolute Path, or None if not found + """ + effect_path = Path(path) + + # Already absolute + if effect_path.is_absolute() and effect_path.exists(): + return effect_path + + # Try relative to recipe directory + if ctx.recipe_dir: + recipe_relative = ctx.recipe_dir / path + if recipe_relative.exists(): + return recipe_relative.resolve() + + # Try relative to cwd + import os + cwd = Path(os.getcwd()) + cwd_relative = cwd / path + if cwd_relative.exists(): + return cwd_relative.resolve() + + return None + + +def _compile_effect_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile effect - either declaration or node. + + Declaration: (effect name :cid "..." 
:url "...") or legacy (effect name :hash "...") + Node: (effect effect-name) or (effect effect-name input-node) + """ + if len(expr) < 2: + raise CompileError("effect requires at least a name") + + # Check if this is a declaration (has :cid or :hash) + _, kwargs = _parse_kwargs(expr, 2) + + # Support both :cid (new) and :hash (legacy) + effect_cid = kwargs.get("cid") or kwargs.get("hash") + + if effect_cid or "path" in kwargs: + # Declaration + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + # Handle temporal flag - could be Symbol('true') or Python bool + temporal = kwargs.get("temporal", False) + if isinstance(temporal, Symbol): + temporal = temporal.name.lower() == "true" + + effect_path = kwargs.get("path") + + # Compute cid from file content if path provided and no cid + if effect_path and not effect_cid: + resolved_path = _resolve_effect_path(effect_path, ctx) + if resolved_path and resolved_path.exists(): + effect_cid = compute_file_cid(resolved_path) + effect_path = str(resolved_path) # Store absolute path + + ctx.registry["effects"][name] = { + "cid": effect_cid, + "path": effect_path, + "url": kwargs.get("url"), + "temporal": temporal, + } + return None + + # Otherwise it's a node - delegate to effect node compiler + return _compile_effect_node(expr, ctx) + + +def _compile_analyzer_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile analyzer declaration. + + Declaration: (analyzer name :path "..." 
:cid "...") + + Example: + (analyzer beats :path "../analyzers/beats/analyzer.py") + """ + if len(expr) < 2: + raise CompileError("analyzer requires at least a name") + + _, kwargs = _parse_kwargs(expr, 2) + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + ctx.registry["analyzers"][name] = { + "cid": kwargs.get("cid"), + "path": kwargs.get("path"), + "url": kwargs.get("url"), + } + return None + + +def _compile_construct_decl(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile construct declaration. + + Declaration: (construct name :path "...") + + Example: + (construct beat-alternate :path "constructs/beat-alternate.sexp") + """ + if len(expr) < 2: + raise CompileError("construct requires at least a name") + + _, kwargs = _parse_kwargs(expr, 2) + + name = expr[1] + if isinstance(name, Symbol): + name = name.name + + ctx.registry["constructs"][name] = { + "path": kwargs.get("path"), + "cid": kwargs.get("cid"), + "url": kwargs.get("url"), + } + return None + + +def _compile_construct_call(expr: List, ctx: CompilerContext) -> str: + """ + Compile a call to a user-defined construct. + + Creates a CONSTRUCT node that will be expanded at plan time. 
def _compile_construct_call(expr: List, ctx: CompilerContext) -> str:
    """
    Compile a call to a user-defined construct.

    Creates a CONSTRUCT node that will be expanded at plan time.

    Positional args that name bindings are resolved to node IDs; (list ...)
    literals are resolved element-wise; other list args are compiled as
    sub-expressions when possible. Keyword-arg bindings that look like node
    IDs are added to the node's inputs so the planner sees the dependency.

    Example:
        (beat-alternate beats-data (list video-a video-b))
    """
    name = expr[0].name
    construct_info = ctx.registry["constructs"][name]

    # Get positional args and kwargs
    args, kwargs = _parse_kwargs(expr, 1)

    # Resolve input references
    resolved_args = []
    node_inputs = []  # Track actual node IDs for inputs

    for arg in args:
        if isinstance(arg, Symbol) and arg.name in ctx.bindings:
            node_id = ctx.bindings[arg.name]
            resolved_args.append(node_id)
            node_inputs.append(node_id)
        elif isinstance(arg, list) and arg and isinstance(arg[0], Symbol):
            # Check if it's a literal list expression like (list video-a video-b)
            if arg[0].name == "list":
                # Resolve each element of the list
                list_items = []
                for item in arg[1:]:
                    if isinstance(item, Symbol) and item.name in ctx.bindings:
                        list_items.append(ctx.bindings[item.name])
                        node_inputs.append(ctx.bindings[item.name])
                    else:
                        # Unresolvable elements are passed through verbatim.
                        list_items.append(item)
                resolved_args.append(list_items)
            else:
                # Try to compile as an expression
                try:
                    node_id = _compile_expr(arg, ctx)
                    if node_id:
                        resolved_args.append(node_id)
                        node_inputs.append(node_id)
                    else:
                        resolved_args.append(arg)
                except CompileError:
                    # Best-effort: keep the raw form for plan-time expansion.
                    resolved_args.append(arg)
        else:
            resolved_args.append(arg)

    # Also scan kwargs for Symbol references to nodes (like analysis nodes)
    # Helper to extract node IDs from a value (handles nested lists/dicts)
    def extract_node_ids(val):
        # A 64-char string is assumed to be a node ID (SHA3-256 hex digest,
        # as produced by CompilerContext.add_node).
        if isinstance(val, str) and len(val) == 64:
            return [val]
        elif isinstance(val, list):
            ids = []
            for item in val:
                ids.extend(extract_node_ids(item))
            return ids
        elif isinstance(val, dict):
            ids = []
            for v in val.values():
                ids.extend(extract_node_ids(v))
            return ids
        return []

    for key, value in kwargs.items():
        if isinstance(value, Symbol) and value.name in ctx.bindings:
            binding_value = ctx.bindings[value.name]
            # If it's a node ID (string hash), add to inputs
            if isinstance(binding_value, str) and len(binding_value) == 64:
                node_inputs.append(binding_value)
            # Also scan lists/dicts for node IDs (e.g., video_infos list)
            elif isinstance(binding_value, (list, dict)):
                node_inputs.extend(extract_node_ids(binding_value))

    node_id = ctx.add_node(
        "CONSTRUCT",
        {
            "construct_name": name,
            "construct_path": construct_info.get("path"),
            "args": resolved_args,
            # Include bindings so reducer lambda can reference video sources etc.
            "bindings": dict(ctx.bindings),
            **kwargs,
        },
        inputs=node_inputs,
    )
    return node_id
template definitions + - (def name value) bindings + + For web-based systems: + - :cid loads from L1 local cache or L2 shared cache + - :path is for local development + + Example library file (libs/standard-analyzers.sexp): + ;; Standard audio analyzers + (analyzer beats :path "../artdag-analyzers/beats/analyzer.py") + (analyzer bass :path "../artdag-analyzers/bass/analyzer.py") + (analyzer energy :path "../artdag-analyzers/energy/analyzer.py") + + Example usage: + (include :path "libs/standard-analyzers.sexp") + (include :path "libs/all-effects.sexp") + ;; Now beats, bass, energy analyzers and all effects are available + """ + from pathlib import Path + from .parser import parse_all + from .evaluator import evaluate + + _, kwargs = _parse_kwargs(expr, 1) + + # Name is optional - check if first arg is a symbol (name) or keyword + name = None + if len(expr) >= 2 and isinstance(expr[1], Symbol) and not str(expr[1].name).startswith(":"): + name = expr[1].name + _, kwargs = _parse_kwargs(expr, 2) + + path = kwargs.get("path") + cid = kwargs.get("cid") + + if not path and not cid: + raise CompileError("include requires :path or :cid") + + content = None + + if cid: + # Load from content-addressed cache (L1 local / L2 shared) + content = _load_from_cache(cid, ctx) + + if content is None and path: + # Load from local path + include_path = Path(path) + + # Try relative to recipe directory first + if hasattr(ctx, 'recipe_dir') and ctx.recipe_dir: + recipe_relative = ctx.recipe_dir / path + if recipe_relative.exists(): + include_path = recipe_relative + + # Try relative to cwd + if not include_path.exists(): + import os + cwd = Path(os.getcwd()) + include_path = cwd / path + + if not include_path.exists(): + raise CompileError(f"Include file not found: {path}") + + content = include_path.read_text() + + # Track included file by CID for upload/caching + include_cid = compute_content_cid(content) + ctx.registry["includes"][str(include_path.resolve())] = include_cid + + if 
content is None: + raise CompileError(f"Could not load include: path={path}, cid={cid}") + + # Parse the included file + sexp_list = parse_all(content) + if not isinstance(sexp_list, list): + sexp_list = [sexp_list] + + # Build an environment from current bindings + env = dict(ctx.bindings) + + for sexp in sexp_list: + if isinstance(sexp, list) and sexp and isinstance(sexp[0], Symbol): + form = sexp[0].name + + if form == "def": + # (def name value) - evaluate and add to bindings + if len(sexp) != 3: + raise CompileError(f"Invalid def in include: {sexp}") + def_name = sexp[1] + if isinstance(def_name, Symbol): + def_name = def_name.name + def_value = evaluate(sexp[2], env) + env[def_name] = def_value + ctx.bindings[def_name] = def_value + + elif form == "analyzer": + # (analyzer name :path "..." [:cid "..."]) + _compile_analyzer_decl(sexp, ctx) + + elif form == "effect": + # (effect name :path "..." [:cid "..."]) + _compile_effect_decl(sexp, ctx) + + elif form == "construct": + # (construct name :path "..." [:cid "..."]) + _compile_construct_decl(sexp, ctx) + + elif form == "deftemplate": + # (deftemplate name (params...) body...) + _compile_deftemplate(sexp, ctx) + + else: + # Try to evaluate as expression + result = evaluate(sexp, env) + # If a name was provided, bind the last result + if name and result is not None: + ctx.bindings[name] = result + else: + # Evaluate as expression (e.g., bare list literal) + result = evaluate(sexp, env) + if name and result is not None: + ctx.bindings[name] = result + + return None + + +def _load_from_cache(cid: str, ctx: CompilerContext) -> Optional[str]: + """ + Load content from L1 (local) or L2 (shared) cache by CID. + + Cache hierarchy: + L1: Local file cache (~/.artdag/cache/{cid}) + L2: Shared/network cache (IPFS, HTTP gateway, etc.) + + Returns file content as string, or None if not found. 
def _load_from_cache(cid: str, ctx: CompilerContext) -> Optional[str]:
    """
    Load content from L1 (local) or L2 (shared) cache by CID.

    Cache hierarchy:
        L1: Local file cache (~/.artdag/cache/{cid})
        L2: Shared/network cache (IPFS, HTTP gateway, etc.)

    An L2 hit is written back into L1 so subsequent loads are local.

    Returns file content as string, or None if not found.
    """
    from pathlib import Path
    import os

    # L1: Local cache directory
    cache_dir = Path(os.path.expanduser("~/.artdag/cache"))
    l1_path = cache_dir / cid

    if l1_path.exists():
        return l1_path.read_text()

    # L2: Try shared cache sources
    content = _load_from_l2(cid, ctx)

    if content:
        # Store in L1 for future use
        cache_dir.mkdir(parents=True, exist_ok=True)
        l1_path.write_text(content)

    return content


def _load_from_l2(cid: str, ctx: CompilerContext) -> Optional[str]:
    """
    Load content from L2 shared cache.

    Supports:
        - IPFS gateways (if CID starts with 'bafy' or 'Qm')
        - HTTP URLs (if configured in ctx.l2_sources)
        - Custom backends (extensible)

    Each source is tried in order with a 10s timeout; network errors fall
    through to the next source.

    Returns content as string, or None if not available.
    """
    import urllib.request
    import urllib.error

    # IPFS gateway (public, for development)
    if cid.startswith("bafy") or cid.startswith("Qm"):
        gateways = [
            f"https://ipfs.io/ipfs/{cid}",
            f"https://dweb.link/ipfs/{cid}",
            f"https://cloudflare-ipfs.com/ipfs/{cid}",
        ]
        for gateway_url in gateways:
            try:
                with urllib.request.urlopen(gateway_url, timeout=10) as response:
                    return response.read().decode('utf-8')
            except (urllib.error.URLError, urllib.error.HTTPError):
                continue

    # Custom L2 sources from context (e.g., private cache server)
    # NOTE(review): l2_sources is read via getattr — it is not a declared
    # CompilerContext field; confirm where it is attached.
    l2_sources = getattr(ctx, 'l2_sources', [])
    for source in l2_sources:
        try:
            url = f"{source}/{cid}"
            with urllib.request.urlopen(url, timeout=10) as response:
                return response.read().decode('utf-8')
        except (urllib.error.URLError, urllib.error.HTTPError):
            continue

    return None
from command-line param), don't override + # This allows recipes to specify defaults that command-line params can override + if name.name in ctx.bindings: + return None + + body = expr[2] + + # Check if body is a simple value (number, string, etc.) + if isinstance(body, (int, float, str, bool)): + ctx.bindings[name.name] = body + return None + + node_id = _compile_expr(body, ctx) + + # Multi-scan dict emit: expand field bindings + if isinstance(node_id, dict) and node_id.get("_multi_scan"): + for field_name, field_node_id in node_id["fields"].items(): + binding_name = f"{name.name}-{field_name}" + ctx.bindings[binding_name] = field_node_id + if field_node_id in ctx.nodes: + ctx.nodes[field_node_id]["name"] = binding_name + return None + + # If result is a simple value (from evaluated pure function), store it directly + # This includes lists, tuples, dicts from pure functions like `list` + if isinstance(node_id, (int, float, str, bool, list, tuple, dict)): + ctx.bindings[name.name] = node_id + return None + + if node_id is None: + raise CompileError(f"def body must produce a node or value") + + # Store binding for reference resolution + ctx.bindings[name.name] = node_id + + # Also store the name on the node so planner can reference it + if node_id in ctx.nodes: + ctx.nodes[node_id]["name"] = name.name + + return None + + +def _compile_stage(expr: List, ctx: CompilerContext) -> Optional[str]: + """ + Compile (stage :name :requires [...] :inputs [...] :outputs [...] body...). + + Stage form enables explicit dependency declaration, parallel execution, + and variable scoping. 
def _compile_stage(expr: List, ctx: CompilerContext) -> Optional[str]:
    """
    Compile (stage :name :requires [...] :inputs [...] :outputs [...] body...).

    Stage form enables explicit dependency declaration, parallel execution,
    and variable scoping.

    Example:
        (stage :analyze-a
          :outputs [beats-a]
          (def beats-a (-> audio-a (analyze beats))))

        (stage :plan-a
          :requires [:analyze-a]
          :inputs [beats-a]
          :outputs [segments-a]
          (def segments-a (make-segments :beats beats-a)))

    Returns the node ID of the last body expression that produced one,
    or None. Raises CompileError on malformed stage forms, undefined
    requires/inputs/outputs, or circular stage dependencies.
    """
    if len(expr) < 2:
        raise CompileError("stage requires at least a name")

    # Parse stage name: the first element after 'stage' is a standalone
    # keyword (or symbol), NOT a key-value pair.
    stage_name = None
    start_idx = 1

    if len(expr) > 1:
        first_arg = expr[1]
        if isinstance(first_arg, Keyword):
            stage_name = first_arg.name
            start_idx = 2
        elif isinstance(first_arg, Symbol):
            stage_name = first_arg.name
            start_idx = 2

    if stage_name is None:
        raise CompileError("stage requires a name (e.g., (stage :analyze-a ...))")

    # Remaining elements are keyword options plus body expressions.
    args, kwargs = _parse_kwargs(expr, start_idx)

    # :requires — names of previously defined stages (keyword/symbol/string).
    requires = []
    if "requires" in kwargs:
        req_val = kwargs["requires"]
        if isinstance(req_val, list):
            for r in req_val:
                if isinstance(r, Keyword):
                    requires.append(r.name)
                elif isinstance(r, Symbol):
                    requires.append(r.name)
                elif isinstance(r, str):
                    requires.append(r)
                else:
                    raise CompileError(f"Invalid require: {r}")
        else:
            raise CompileError(":requires must be a list")

    # :inputs — binding names this stage consumes from required stages.
    inputs = []
    if "inputs" in kwargs:
        inp_val = kwargs["inputs"]
        if isinstance(inp_val, list):
            for i in inp_val:
                if isinstance(i, Symbol):
                    inputs.append(i.name)
                elif isinstance(i, str):
                    inputs.append(i)
                else:
                    raise CompileError(f"Invalid input: {i}")
        else:
            raise CompileError(":inputs must be a list")

    # :outputs — binding names this stage promises to define.
    outputs = []
    if "outputs" in kwargs:
        out_val = kwargs["outputs"]
        if isinstance(out_val, list):
            for o in out_val:
                if isinstance(o, Symbol):
                    outputs.append(o.name)
                elif isinstance(o, str):
                    outputs.append(o)
                else:
                    raise CompileError(f"Invalid output: {o}")
        else:
            raise CompileError(":outputs must be a list")

    # Validate requires - must reference defined stages
    for req in requires:
        if req not in ctx.defined_stages:
            raise CompileError(
                f"Stage '{stage_name}' requires undefined stage '{req}'"
            )

    # Validate inputs - must come from a required stage's outputs or from
    # bindings made before any stage was entered.
    for inp in inputs:
        found = False
        for req in requires:
            if inp in ctx.defined_stages[req].output_bindings:
                found = True
                break
        if not found and inp not in ctx.pre_stage_bindings:
            raise CompileError(
                f"Stage '{stage_name}' declares input '{inp}' "
                f"which is not an output of any required stage or pre-stage binding"
            )

    # Check for circular dependencies (simple DFS; a more thorough check
    # would use topological sort).
    visited = set()
    def check_cycle(stage: str, path: List[str]):
        if stage in path:
            cycle = " -> ".join(path + [stage])
            raise CompileError(f"Circular stage dependency: {cycle}")
        if stage in visited:
            return
        visited.add(stage)
        if stage in ctx.defined_stages:
            for req in ctx.defined_stages[stage].requires:
                check_cycle(req, path + [stage])

    for req in requires:
        check_cycle(req, [stage_name])

    # Save context state before entering stage (restored below).
    prev_stage = ctx.current_stage
    prev_stage_node_ids = ctx.stage_node_ids

    # Enter stage context
    ctx.current_stage = stage_name
    ctx.stage_node_ids = []

    # Build the scoped binding environment visible inside this stage:
    # pre-stage bindings plus declared inputs pulled from required stages.
    stage_ctx_bindings = dict(ctx.pre_stage_bindings)

    for inp in inputs:
        for req in requires:
            if inp in ctx.defined_stages[req].output_bindings:
                stage_ctx_bindings[inp] = ctx.defined_stages[req].output_bindings[inp]
                break

    # Swap the binding environment for the duration of the stage body.
    prev_bindings = ctx.bindings
    ctx.bindings = stage_ctx_bindings

    # Collect body expressions: everything after the stage name that is not
    # a keyword option (keywords are skipped together with their value).
    # Start from index 2 (after 'stage' and stage name).
    body_exprs = []
    i = 2
    while i < len(expr):
        item = expr[i]
        if isinstance(item, Keyword):
            # Skip keyword and its value
            i += 2
        elif isinstance(item, (list, Symbol)):
            # Include both list expressions and symbol references
            body_exprs.append(item)
            i += 1
        else:
            i += 1

    last_result = None
    for body_expr in body_exprs:
        result = _compile_expr(body_expr, ctx)
        if result is not None:
            last_result = result

    # Collect output bindings; every declared output must have been defined.
    output_bindings = {}
    for out in outputs:
        if out in ctx.bindings:
            output_bindings[out] = ctx.bindings[out]
        else:
            raise CompileError(
                f"Stage '{stage_name}' declares output '{out}' "
                f"but it was not defined in the stage body"
            )

    # Record the compiled stage for later :requires validation and planning.
    compiled_stage = CompiledStage(
        name=stage_name,
        requires=requires,
        inputs=inputs,
        outputs=outputs,
        node_ids=ctx.stage_node_ids,
        output_bindings=output_bindings,
    )

    # Register the stage
    ctx.defined_stages[stage_name] = compiled_stage
    ctx.stage_bindings[stage_name] = output_bindings

    # Restore context state
    ctx.current_stage = prev_stage
    ctx.stage_node_ids = prev_stage_node_ids
    ctx.bindings = prev_bindings

    # Make stage outputs available to subsequent stages via bindings
    ctx.bindings.update(output_bindings)

    return last_result
+ """ + if len(expr) < 2: + raise CompileError("-> requires at least one expression") + + prev_node_id = None + + for i, sub_expr in enumerate(expr[1:]): + if prev_node_id is not None: + # Inject previous node as first input + sub_expr = _inject_input(sub_expr, prev_node_id) + + prev_node_id = _compile_expr(sub_expr, ctx) + + if prev_node_id is None: + raise CompileError(f"Expression {i} in -> chain produced no node") + + return prev_node_id + + +def _inject_input(expr: Any, input_id: str) -> List: + """Inject an input node ID into an expression.""" + if not isinstance(expr, list): + # Symbol reference - wrap in a node that takes input + if isinstance(expr, Symbol): + # Assume it's an effect name + return [Symbol("effect"), expr, Symbol(f"__input_{input_id}")] + raise CompileError(f"Cannot inject input into {expr}") + + # For node expressions, we'll handle the input in the compiler + # Mark it with a special __prev__ reference + return expr + [Symbol("__prev__"), input_id] + + +def _resolve_input(arg: Any, ctx: CompilerContext, prev_id: str = None) -> str: + """Resolve an argument to a node ID.""" + if isinstance(arg, Symbol): + if arg.name == "__prev__": + if prev_id is None: + raise CompileError("__prev__ used outside threading context") + return prev_id + if arg.name.startswith("__input_"): + return arg.name[8:] # Strip __input_ prefix + if arg.name in ctx.bindings: + return ctx.bindings[arg.name] + raise CompileError(f"Undefined reference: {arg.name}") + + if isinstance(arg, str): + # Direct node ID + return arg + + if isinstance(arg, list): + # Nested expression + return _compile_expr(arg, ctx) + + raise CompileError(f"Cannot resolve input: {arg}") + + +def _extract_prev_id(args: List, kwargs: Dict) -> Tuple[List, Dict, Optional[str]]: + """Extract __prev__ marker from args if present.""" + prev_id = None + new_args = [] + + i = 0 + while i < len(args): + if isinstance(args[i], Symbol) and args[i].name == "__prev__": + if i + 1 < len(args): + prev_id = args[i 
+ 1] + i += 2 + continue + new_args.append(args[i]) + i += 1 + + return new_args, kwargs, prev_id + + +def _compile_source(expr: List, ctx: CompilerContext) -> str: + """ + Compile (source asset-name), (source :input "name" ...), or (source :path "file.mkv" ...). + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, _ = _extract_prev_id(args, kwargs) + + if "input" in kwargs: + # Variable input - :input can be followed by a name string + input_val = kwargs["input"] + if isinstance(input_val, str): + # (source :input "User Video" :description "...") + name = input_val + else: + # (source :input true :name "User Video") + name = kwargs.get("name", "Input") + config = { + "input": True, + "name": name, + "description": kwargs.get("description", ""), + } + elif "path" in kwargs: + # Local file path - for development/testing + # (source :path "dog.mkv" :description "Input video") + path = kwargs["path"] + config = { + "path": path, + "description": kwargs.get("description", ""), + } + elif args: + # Asset reference + asset_name = args[0] + if isinstance(asset_name, Symbol): + asset_name = asset_name.name + config = {"asset": asset_name} + else: + raise CompileError("source requires asset name, :input flag, or :path") + + return ctx.add_node("SOURCE", config) + + +def _compile_effect_node(expr: List, ctx: CompilerContext) -> str: + """ + Compile (effect effect-name [input-nodes...] :param value ...). 
def _compile_effect_node(expr: List, ctx: CompilerContext) -> str:
    """
    Compile (effect effect-name [input-nodes...] :param value ...).

    Single input:
        (effect rotate video :angle 45)
        (-> video (effect rotate :angle 45))

    Multi-input (blend, layer, etc.):
        (effect blend video-a video-b :mode "overlay")
        (-> video-a (effect blend video-b :mode "overlay"))

    Parameters can be literals or bind expressions:
        (effect brightness video :level (bind analysis :energy :range [0 1]))

    Returns the new EFFECT node's ID.
    """
    args, kwargs = _parse_kwargs(expr, 1)
    args, kwargs, prev_id = _extract_prev_id(args, kwargs)

    if not args:
        raise CompileError("effect requires effect name")

    effect_name = args[0]
    if isinstance(effect_name, Symbol):
        effect_name = effect_name.name

    config = {"effect": effect_name}

    # Look up effect info from registry: dict entries carry path/cid,
    # plain-string entries are just a path.
    effects_registry = ctx.registry.get("effects", {})
    if effect_name in effects_registry:
        effect_info = effects_registry[effect_name]
        if isinstance(effect_info, dict):
            if "path" in effect_info:
                config["effect_path"] = effect_info["path"]
            if "cid" in effect_info and effect_info["cid"]:
                config["effect_cid"] = effect_info["cid"]
        elif isinstance(effect_info, str):
            config["effect_path"] = effect_info

    # Include full effects_registry with cids for workers to fetch dependencies.
    # Only include effects that have cids (content-addressed).
    effects_with_cids = {}
    for name, info in effects_registry.items():
        if isinstance(info, dict) and info.get("cid"):
            effects_with_cids[name] = info["cid"]
    if effects_with_cids:
        config["effects_registry"] = effects_with_cids

    # Process parameter values, looking for bind expressions.
    # Also track analysis references so workers know what data they need.
    # NOTE(review): :hash and :url kwargs are silently dropped here —
    # presumably reserved keys; confirm against callers.
    analysis_refs = set()
    for k, v in kwargs.items():
        if k not in ("hash", "url"):
            processed = _process_value(v, ctx)
            config[k] = processed
            # Extract analysis references from bind expressions
            _extract_analysis_refs(processed, analysis_refs)

    if analysis_refs:
        config["analysis_refs"] = list(analysis_refs)

    # Collect inputs - first from threading (prev_id), then from additional args.
    inputs = []
    if prev_id:
        inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id))
    for arg in args[1:]:
        # Handle list of inputs: (effect blend [video-a video-b] :mode "overlay").
        # A list whose head is a Symbol is a nested expression, not an input list.
        if isinstance(arg, list) and arg and not isinstance(arg[0], Symbol):
            for item in arg:
                inputs.append(_resolve_input(item, ctx, prev_id))
        else:
            inputs.append(_resolve_input(arg, ctx, prev_id))

    # Auto-detect multi-input effects
    if len(inputs) > 1:
        config["multi_input"] = True

    return ctx.add_node("EFFECT", config, inputs)


def _extract_analysis_refs(value: Any, refs: set) -> None:
    """Extract analysis node references from a processed value.

    Bind expressions contain references to analysis nodes. This function
    extracts those references (recursively, through nested dicts and lists)
    into *refs* so workers know which analysis data they need.
    """
    if isinstance(value, dict):
        # Check if this is a bind expression (has _binding flag or a
        # bind/ref/source key).
        if value.get("_binding") or "bind" in value or "ref" in value or "source" in value:
            ref = value.get("source") or value.get("ref") or value.get("bind")
            if ref:
                refs.add(ref)
        # Recursively check nested dicts
        for v in value.values():
            _extract_analysis_refs(v, refs)
    elif isinstance(value, list):
        for item in value:
            _extract_analysis_refs(item, refs)
isinstance(val, dict) and val.get("_binding") else float(val) + _extract_analysis_refs(config.get("end"), analysis_refs) + if "duration" in kwargs: + val = _process_value(kwargs["duration"], ctx) + if val is not None: + config["duration"] = val if isinstance(val, dict) and val.get("_binding") else float(val) + _extract_analysis_refs(config.get("duration"), analysis_refs) + + if analysis_refs: + config["analysis_refs"] = list(analysis_refs) + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("SEGMENT", config, inputs) + + +def _compile_resize(expr: List, ctx: CompilerContext) -> str: + """ + Compile (resize width height :mode "linear" [input]). + + Resize is now an EFFECT that uses the sexp resize-frame effect. + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + if len(args) < 2: + raise CompileError("resize requires width and height") + + # Create EFFECT node with resize effect + # Note: param names match resize.sexp (target-w, target-h to avoid primitive conflict) + config = { + "effect": "resize-frame", + "effect_path": "sexp_effects/effects/resize-frame.sexp", + "target-w": int(args[0]), + "target-h": int(args[1]), + "mode": kwargs.get("mode", "linear"), + } + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args[2:]: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("EFFECT", config, inputs) + + +def _compile_sequence(expr: List, ctx: CompilerContext) -> str: + """ + Compile (sequence node1 node2 ... :resize-mode :fit :priority :width). 
+ + Options: + :transition - transition between clips (default: cut) + :resize-mode - fit | crop | stretch | cover (default: none) + :priority - width | height (which dimension to match exactly) + :target-width - explicit target width + :target-height - explicit target height + :pad-color - color for fit mode padding (default: black) + :crop-gravity - center | top | bottom | left | right (default: center) + """ + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + config = { + "transition": kwargs.get("transition", {"type": "cut"}), + } + + # Add normalize config if specified + resize_mode = kwargs.get("resize-mode") + if isinstance(resize_mode, (Symbol, Keyword)): + resize_mode = resize_mode.name + if resize_mode: + config["resize_mode"] = resize_mode + + priority = kwargs.get("priority") + if isinstance(priority, (Symbol, Keyword)): + priority = priority.name + if priority: + config["priority"] = priority + + if kwargs.get("target-width"): + config["target_width"] = kwargs["target-width"] + if kwargs.get("target-height"): + config["target_height"] = kwargs["target-height"] + + pad_color = kwargs.get("pad-color") + if isinstance(pad_color, (Symbol, Keyword)): + pad_color = pad_color.name + config["pad_color"] = pad_color or "black" + + crop_gravity = kwargs.get("crop-gravity") + if isinstance(crop_gravity, (Symbol, Keyword)): + crop_gravity = crop_gravity.name + config["crop_gravity"] = crop_gravity or "center" + + inputs = [] + if prev_id: + inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id)) + for arg in args: + inputs.append(_resolve_input(arg, ctx, prev_id)) + + return ctx.add_node("SEQUENCE", config, inputs) + + +def _compile_mux(expr: List, ctx: CompilerContext) -> str: + """Compile (mux video-node audio-node).""" + args, kwargs = _parse_kwargs(expr, 1) + args, kwargs, prev_id = _extract_prev_id(args, kwargs) + + config = { + "video_stream": 0, + "audio_stream": 1, + "shortest": 
def _compile_mux(expr: List, ctx: CompilerContext) -> str:
    """Compile (mux video-node audio-node).

    The first input is taken as the video stream (index 0), the second as
    the audio stream (index 1); :shortest defaults to True.
    """
    args, kwargs = _parse_kwargs(expr, 1)
    args, kwargs, prev_id = _extract_prev_id(args, kwargs)

    config = {
        "video_stream": 0,
        "audio_stream": 1,
        "shortest": kwargs.get("shortest", True),
    }

    inputs = []
    if prev_id:
        inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id))
    for arg in args:
        inputs.append(_resolve_input(arg, ctx, prev_id))

    if len(inputs) < 2:
        raise CompileError("mux requires video and audio inputs")

    return ctx.add_node("MUX", config, inputs)


def _compile_slice_on(expr: List, ctx: CompilerContext) -> str:
    """
    Compile slice-on with either legacy or lambda syntax.

    Legacy syntax:
        (slice-on video analysis :times path :effect fx :pattern pat)

    Lambda syntax:
        (slice-on analysis
          :times times
          :init 0
          :fn (lambda [acc i start end]
                {:source video
                 :effects (if (odd? i) [invert] [])
                 :acc (inc acc)}))

    Args:
        video: input video node (legacy) or omitted (lambda)
        analysis: analysis node with times array
        :times - path to times array in analysis
        :effect - effect to apply (legacy, optional)
        :pattern - all, odd, even, alternate (legacy, default: all)
        :init - initial accumulator value (lambda)
        :fn - reducer lambda function (lambda)
        :videos - optional list of sources for multi-source composition
    """
    # NOTE(review): Lambda is imported here but only _parse_lambda uses it;
    # the import looks redundant — confirm before removing.
    from .parser import Lambda

    args, kwargs = _parse_kwargs(expr, 1)
    args, kwargs, prev_id = _extract_prev_id(args, kwargs)

    # Check for lambda mode
    reducer_fn = kwargs.get("fn")

    # Parse lambda if it's a list
    if isinstance(reducer_fn, list):
        reducer_fn = _parse_lambda(reducer_fn)

    # Lambda mode: only analysis input required (sources come from fn).
    # Legacy mode: requires video and analysis inputs.
    if reducer_fn is not None:
        # Lambda mode - just need analysis input
        if len(args) < 1:
            raise CompileError("slice-on requires analysis input")
        analysis_input = _resolve_input(args[0], ctx, prev_id)
        inputs = [analysis_input]
    else:
        # Legacy mode - need video and analysis inputs
        if len(args) < 2:
            raise CompileError("slice-on requires video and analysis inputs")
        video_input = _resolve_input(args[0], ctx, prev_id)
        analysis_input = _resolve_input(args[1], ctx, prev_id)
        inputs = [video_input, analysis_input]

    times_path = kwargs.get("times", "times")
    if isinstance(times_path, Symbol):
        times_path = times_path.name

    config = {
        "times_path": times_path,
        "fn": reducer_fn,
        "init": kwargs.get("init", 0),
        # Include bindings so the lambda can reference video sources etc.
        "bindings": dict(ctx.bindings),
    }

    # Optional :videos list for multi-source composition mode
    videos_list = kwargs.get("videos")
    if videos_list is not None:
        if not isinstance(videos_list, list):
            raise CompileError(":videos must be a list")
        resolved_videos = []
        for v in videos_list:
            resolved_videos.append(_resolve_input(v, ctx, None))
        config["videos"] = resolved_videos
        # Add to inputs so the planner knows about the dependencies.
        for vid in resolved_videos:
            if vid not in inputs:
                inputs.append(vid)

    return ctx.add_node("SLICE_ON", config, inputs)


def _parse_lambda(expr: List):
    """Parse a lambda expression list into a Lambda object.

    Accepts (lambda [params...] body) or (fn [params...] body); params may
    be symbols or plain strings.
    NOTE(review): only expr[2] is taken as the body — any additional body
    forms are silently ignored (no implicit do); confirm this is intended.
    """
    from .parser import Lambda, Symbol

    if not expr or not isinstance(expr[0], Symbol):
        raise CompileError("Invalid lambda expression")

    name = expr[0].name
    if name not in ("lambda", "fn"):
        raise CompileError(f"Expected lambda or fn, got {name}")

    if len(expr) < 3:
        raise CompileError("lambda requires params and body")

    params = expr[1]
    if not isinstance(params, list):
        raise CompileError("lambda params must be a list")

    param_names = []
    for p in params:
        if isinstance(p, Symbol):
            param_names.append(p.name)
        elif isinstance(p, str):
            param_names.append(p)
        else:
            raise CompileError(f"Invalid lambda param: {p}")

    return Lambda(param_names, expr[2])
def _compile_analyze(expr: List, ctx: CompilerContext) -> str:
    """
    Compile (analyze analyzer-name :param value ...).

    Example:
        (analyze beats)
        (analyze beats :min-bpm 120 :max-bpm 180)
    """
    args, kwargs = _parse_kwargs(expr, 1)
    args, kwargs, prev_id = _extract_prev_id(args, kwargs)

    if not args:
        raise CompileError("analyze requires analyzer name")

    analyzer = args[0]
    if isinstance(analyzer, Symbol):
        analyzer = analyzer.name

    # The registry entry supplies the analyzer implementation path and CID.
    entry = ctx.registry.get("analyzers", {}).get(analyzer, {})

    config = {
        "analyzer": analyzer,
        "analyzer_path": entry.get("path"),
        "cid": entry.get("cid"),
    }
    # Remaining keyword params are forwarded to the analyzer unchanged.
    config.update(kwargs)

    inputs = []
    if prev_id:
        inputs.append(prev_id if isinstance(prev_id, str) else str(prev_id))
    # Positional args after the analyzer name are input nodes.
    inputs.extend(_resolve_input(a, ctx, prev_id) for a in args[1:])

    return ctx.add_node("ANALYZE", config, inputs)


def _compile_bind(expr: List, ctx: CompilerContext) -> Dict[str, Any]:
    """
    Compile (bind source feature :option value ...).

    Returns a binding specification dict (not a node ID).

    Examples:
        (bind analysis :energy)
        (bind analysis :energy :range [0 1])
        (bind analysis :beats :on-event 1.0 :decay 0.1)
        (bind analysis :energy :range [0 1] :smooth 0.05 :noise 0.1 :seed 42)
    """
    args, kwargs = _parse_kwargs(expr, 1)

    if len(args) < 2:
        raise CompileError("bind requires source and feature: (bind source :feature ...)")

    source, feature = args[0], args[1]

    # A symbol source resolves through bindings when possible; otherwise
    # its name is kept as an unresolved reference.
    source_ref = None
    if isinstance(source, Symbol):
        source_ref = ctx.bindings.get(source.name, source.name)

    # The feature must be a keyword (a symbol is tolerated).
    if isinstance(feature, (Keyword, Symbol)):
        feature_name = feature.name
    else:
        raise CompileError(f"bind feature must be a keyword, got {feature}")

    binding = {
        "_binding": True,  # Marker for binding resolution
        "source": source_ref,
        "feature": feature_name,
    }

    # :range is validated as a two-element [lo hi] pair.
    if "range" in kwargs:
        range_val = kwargs["range"]
        if isinstance(range_val, list) and len(range_val) == 2:
            binding["range"] = [float(range_val[0]), float(range_val[1])]
        else:
            raise CompileError("bind :range must be [lo hi]")

    # Scalar float modifiers, mapped from kwarg name to binding key.
    for option, key in (
        ("smooth", "smooth"),
        ("offset", "offset"),
        ("on-event", "on_event"),
        ("decay", "decay"),
        ("noise", "noise"),
    ):
        if option in kwargs:
            binding[key] = float(kwargs[option])

    if "seed" in kwargs:
        binding["seed"] = int(kwargs["seed"])

    return binding
+ + Supported expressions: + (bind source feature :range [lo hi]) - bind to analysis data + (+ a b), (- a b), (* a b), (/ a b), (mod a b) - math operations + time - current frame time in seconds + frame - current frame number + """ + # Math operators that create runtime expressions + MATH_OPS = {'+', '-', '*', '/', 'mod', 'min', 'max', 'abs', 'sin', 'cos', + 'if', '<', '>', '<=', '>=', '=', + 'rand', 'rand-int', 'rand-range', + 'floor', 'ceil', 'nth'} + + if isinstance(value, Symbol): + # Special runtime symbols + if value.name == "time": + return {"_expr": True, "op": "time"} + if value.name == "frame": + return {"_expr": True, "op": "frame"} + # Resolve symbol from bindings + if value.name in ctx.bindings: + return ctx.bindings[value.name] + # Return as-is if not found (could be an effect reference, etc.) + return value + + if isinstance(value, list) and len(value) > 0: + head = value[0] + head_name = head.name if isinstance(head, Symbol) else None + + if head_name == "bind": + return _compile_bind(value, ctx) + + # Handle lambda expressions - parse but don't compile + if head_name in ("lambda", "fn"): + return _parse_lambda(value) + + # Handle dict expressions - keyword-value pairs for runtime dict construction + if head_name == "dict": + keys = [] + vals = [] + i = 1 + while i < len(value): + if isinstance(value[i], Keyword): + keys.append(value[i].name) + if i + 1 < len(value): + vals.append(_process_value(value[i + 1], ctx)) + i += 2 + else: + i += 1 + return {"_expr": True, "op": "dict", "keys": keys, "args": vals} + + # Handle math expressions - preserve for runtime evaluation + if head_name in MATH_OPS: + processed_args = [_process_value(arg, ctx) for arg in value[1:]] + return {"_expr": True, "op": head_name, "args": processed_args} + + # Could be other nested expressions + return _compile_expr(value, ctx) + + return value + + +def _compile_scan_expr(value: Any, ctx: CompilerContext) -> Any: + """ + Compile an expression for use in scan step/emit. 
+ + Like _process_value but treats unbound symbols as runtime variable + references (for acc, dict fields like rem/hue, etc.). + """ + SCAN_OPS = { + '+', '-', '*', '/', 'mod', 'min', 'max', 'abs', 'sin', 'cos', + 'if', '<', '>', '<=', '>=', '=', + 'rand', 'rand-int', 'rand-range', + 'floor', 'ceil', 'nth', + } + + if isinstance(value, (int, float)): + return value + + if isinstance(value, Keyword): + return value.name + + if isinstance(value, Symbol): + # Known runtime symbols + if value.name in ("time", "frame"): + return {"_expr": True, "op": value.name} + # Check bindings for compile-time constants (e.g., recipe params) + if value.name in ctx.bindings: + bound = ctx.bindings[value.name] + if isinstance(bound, (int, float, str, bool)): + return bound + # Runtime variable reference (acc, rem, hue, etc.) + return {"_expr": True, "op": "var", "name": value.name} + + if isinstance(value, list) and len(value) > 0: + head = value[0] + head_name = head.name if isinstance(head, Symbol) else None + + if head_name == "dict": + # (dict :key1 val1 :key2 val2) + keys = [] + args = [] + i = 1 + while i < len(value): + if isinstance(value[i], Keyword): + keys.append(value[i].name) + if i + 1 < len(value): + args.append(_compile_scan_expr(value[i + 1], ctx)) + i += 2 + else: + i += 1 + return {"_expr": True, "op": "dict", "keys": keys, "args": args} + + if head_name in SCAN_OPS: + processed_args = [_compile_scan_expr(arg, ctx) for arg in value[1:]] + return {"_expr": True, "op": head_name, "args": processed_args} + + # Fall through to _process_value for bind expressions, etc. + return _process_value(value, ctx) + + return value + + +def _eval_const_expr(value, ctx: 'CompilerContext'): + """Evaluate a compile-time constant expression. + + Supports literals, symbol lookups in ctx.bindings, and basic arithmetic. + Used for values like scan :seed that must resolve to a number at compile time. 
+ """ + if isinstance(value, (int, float)): + return value + if isinstance(value, Symbol): + if value.name in ctx.bindings: + bound = ctx.bindings[value.name] + if isinstance(bound, (int, float)): + return bound + raise CompileError(f"Cannot resolve symbol '{value.name}' to a constant") + if isinstance(value, list) and len(value) >= 1: + head = value[0] + if isinstance(head, Symbol): + name = head.name + if name == 'next-seed' and len(value) == 2: + rng_val = _resolve_rng_value(value[1], ctx) + return _derive_seed(rng_val) + args = [_eval_const_expr(a, ctx) for a in value[1:]] + if name == '+' and len(args) >= 2: + return args[0] + args[1] + if name == '-' and len(args) >= 2: + return args[0] - args[1] + if name == '*' and len(args) >= 2: + return args[0] * args[1] + if name == '/' and len(args) >= 2: + return args[0] / args[1] if args[1] != 0 else 0 + if name == 'mod' and len(args) >= 2: + return args[0] % args[1] if args[1] != 0 else 0 + raise CompileError(f"Unsupported constant expression operator: {name}") + raise CompileError(f"Cannot evaluate as constant: {value}") + + +def _derive_seed(rng_val: dict) -> int: + """Derive next unique seed from RNG value, incrementing counter.""" + master = rng_val["master_seed"] + counter = rng_val["_counter"] + digest = hashlib.sha256(f"{master}:{counter[0]}".encode()).hexdigest()[:8] + seed = int(digest, 16) + counter[0] += 1 + return seed + + +def _resolve_rng_value(ref, ctx) -> dict: + """Resolve a reference to an RNG value dict.""" + if isinstance(ref, dict) and ref.get("_rng"): + return ref + if isinstance(ref, Symbol): + if ref.name in ctx.bindings: + val = ctx.bindings[ref.name] + if isinstance(val, dict) and val.get("_rng"): + return val + raise CompileError(f"Symbol '{ref.name}' is not an RNG value") + raise CompileError(f"Expected RNG value, got {type(ref).__name__}") + + +def _compile_make_rng(expr, ctx): + """(make-rng SEED) -> compile-time RNG value dict.""" + if len(expr) != 2: + raise CompileError("make-rng 
def _compile_next_seed(expr, ctx):
    """Compile (next-seed RNG) into a fresh integer seed drawn from RNG."""
    if len(expr) != 2:
        raise CompileError("next-seed requires exactly 1 argument: rng")
    return _derive_seed(_resolve_rng_value(expr[1], ctx))


def _compile_scan(expr: List, ctx: CompilerContext) -> str:
    """
    Compile (scan source :seed N :init EXPR :step EXPR :emit EXPR).

    Emits a SCAN node that walks the source's analysis events, threading an
    accumulator through the :step expression and projecting each step via
    :emit.  The accumulator may be a number or a dict; dict field names are
    visible as plain variables inside the :step / :emit expressions.

    :seed accepts compile-time constant expressions (e.g. (+ seed 100)
    where seed is a template parameter), and :rng may be supplied instead
    to derive the seed from a shared RNG stream.

    When :emit is a dict of expressions, one SCAN node is created per field
    and a {"_multi_scan": True, "fields": {...}} marker is returned in
    place of a single node id.

    Example (counter accumulator):
        (scan beat-data :seed 42 :init 0
          :step (if (> acc 0) (- acc 1) (if (< (rand) 0.1) (rand-int 1 5) 0))
          :emit (if (> acc 0) 1 0))
    """
    positional, kwargs = _parse_kwargs(expr, 1)
    positional, kwargs, prev_id = _extract_prev_id(positional, kwargs)

    # Pick the upstream node feeding the scan: a threaded prev-id wins,
    # then the first positional argument; otherwise the form is malformed.
    if prev_id:
        upstream = prev_id if isinstance(prev_id, str) else str(prev_id)
    elif positional:
        upstream = _resolve_input(positional[0], ctx, None)
    else:
        raise CompileError("scan requires a source input")

    # An :rng stream (derived seed) takes precedence over a constant :seed.
    if "rng" in kwargs:
        seed = _derive_seed(_resolve_rng_value(kwargs["rng"], ctx))
    else:
        seed = _eval_const_expr(kwargs.get("seed", 0), ctx)

    for required in ("step", "emit"):
        if required not in kwargs:
            raise CompileError(f"scan requires :{required} expression")

    init_code = _compile_scan_expr(kwargs.get("init", 0), ctx)
    step_code = _compile_scan_expr(kwargs["step"], ctx)

    def scan_config(emit_code):
        # Shared node config; only the emit expression varies per field.
        return {
            "seed": int(seed),
            "init": init_code,
            "step_expr": step_code,
            "emit_expr": emit_code,
        }

    emit_raw = kwargs["emit"]
    if isinstance(emit_raw, dict):
        # Dict emit: fan out into one SCAN node per emitted field.
        fields = {}
        for field_name, field_expr in emit_raw.items():
            compiled = _compile_scan_expr(field_expr, ctx)
            fields[field_name] = ctx.add_node(
                "SCAN", scan_config(compiled), inputs=[upstream]
            )
        return {"_multi_scan": True, "fields": fields}

    emit_code = _compile_scan_expr(emit_raw, ctx)
    return ctx.add_node("SCAN", scan_config(emit_code), inputs=[upstream])
def _compile_blend_multi(expr: List, ctx: CompilerContext) -> str:
    """Compile (blend-multi :videos [...] :weights [...] :mode M :resize_mode R).

    Produces a single EFFECT node that takes N video inputs and N weight
    bindings, blending them in one pass via the blend_multi effect.

    Raises:
        CompileError: if :videos/:weights are missing, not lists, of
            mismatched length, or fewer than 2 videos are supplied.
    """
    _, kwargs = _parse_kwargs(expr, 1)

    videos = kwargs.get("videos")
    weights = kwargs.get("weights")
    mode = kwargs.get("mode", "alpha")
    resize_mode = kwargs.get("resize_mode", "fit")

    if not videos or not weights:
        raise CompileError("blend-multi requires :videos and :weights")
    if not isinstance(videos, list) or not isinstance(weights, list):
        raise CompileError("blend-multi :videos and :weights must be lists")
    if len(videos) != len(weights):
        raise CompileError(
            f"blend-multi: videos ({len(videos)}) and weights "
            f"({len(weights)}) must be same length"
        )
    if len(videos) < 2:
        raise CompileError("blend-multi requires at least 2 videos")

    # Resolve video symbols to node IDs — these become the multi-input list
    input_ids = [_resolve_input(v, ctx, None) for v in videos]

    # Wrap each weight symbol in an implicit (bind w values) so it becomes a
    # {_binding, source, feature} dict resolvable at execution time.
    weight_bindings = [
        _process_value([Symbol("bind"), w, Symbol("values")], ctx) for w in weights
    ]

    # Build EFFECT config
    config = {
        "effect": "blend_multi",
        "multi_input": True,
        "weights": weight_bindings,
        "mode": mode,
        "resize_mode": resize_mode,
    }

    # Attach effect path / cid from registry so workers can locate the code
    effects_registry = ctx.registry.get("effects", {})
    effect_info = effects_registry.get("blend_multi")
    if isinstance(effect_info, dict):
        if "path" in effect_info:
            config["effect_path"] = effect_info["path"]
        if effect_info.get("cid"):
            config["effect_cid"] = effect_info["cid"]

    # Ship the cid-bearing subset of the effects registry to workers
    effects_with_cids = {
        name: info["cid"]
        for name, info in effects_registry.items()
        if isinstance(info, dict) and info.get("cid")
    }
    if effects_with_cids:
        config["effects_registry"] = effects_with_cids

    # Extract analysis refs so workers know which analysis data they need
    analysis_refs = set()
    for wb in weight_bindings:
        _extract_analysis_refs(wb, analysis_refs)
    if analysis_refs:
        config["analysis_refs"] = list(analysis_refs)

    return ctx.add_node("EFFECT", config, input_ids)


def _compile_deftemplate(expr: List, ctx: CompilerContext) -> None:
    """Compile (deftemplate NAME (PARAMS...) BODY...).

    Stores the template definition in the registry for later invocation by
    the template-call compiler.  Definition only — emits no nodes and
    returns None.

    Raises:
        CompileError: on a malformed definition (missing parts, non-list
            params, non-symbol param names).
    """
    if len(expr) < 4:
        raise CompileError("deftemplate requires name, params, and body")

    name = expr[1]
    if isinstance(name, Symbol):
        name = name.name

    params = expr[2]
    if not isinstance(params, list):
        raise CompileError("deftemplate params must be a list")

    param_names = []
    for p in params:
        if isinstance(p, Symbol):
            param_names.append(p.name)
        else:
            raise CompileError(f"deftemplate param must be a symbol, got {p}")

    # setdefault guards against a context whose registry was built without a
    # "templates" bucket (elsewhere the registry is read defensively with
    # .get("effects", {})).
    ctx.registry.setdefault("templates", {})[name] = {
        "params": param_names,
        "body": expr[3:],
    }
    return None


def _substitute_template(expr, params_map, local_names, prefix):
    """Deep-walk an s-expression tree, substituting template parameters and
    prefixing template-local names so each expansion stays unique.

    RNG state dicts (marked with _rng) are passed through untouched so every
    expansion shares the same mutable counter.
    """
    if isinstance(expr, Symbol):
        if expr.name in params_map:
            return params_map[expr.name]
        if expr.name in local_names:
            return Symbol(prefix + expr.name)
        return expr
    if isinstance(expr, list):
        return [_substitute_template(e, params_map, local_names, prefix) for e in expr]
    if isinstance(expr, dict):
        if expr.get("_rng"):
            return expr  # preserve shared mutable counter
        return {
            k: _substitute_template(v, params_map, local_names, prefix)
            for k, v in expr.items()
        }
    return expr  # numbers, strings, keywords, etc.
def _compile_template_call(expr: List, ctx: CompilerContext) -> str:
    """Expand and compile a call to a user-defined template.

    Substitutes keyword arguments for the template's declared parameters,
    prefixes template-local (def ...) names so repeated expansions cannot
    collide, then compiles every body form.  Returns the id of the last
    node the body produced.
    """
    name = expr[0].name
    template = ctx.registry["templates"][name]
    body_forms = template["body"]

    # Keyword args supplied at the call site.
    _, kwargs = _parse_kwargs(expr, 1)

    # Every declared parameter must be provided; map name -> value.
    params_map = {}
    for pname in template["params"]:
        key = pname
        if key not in kwargs:
            raise CompileError(f"Template '{name}' missing parameter :{key}")
        params_map[pname] = kwargs[key]

    # Unique per-expansion prefix for the template's local definitions.
    prefix = f"_t{ctx.template_call_count}_"
    ctx.template_call_count += 1

    # Names introduced by (def NAME ...) forms inside the body are local.
    local_names = {
        form[1].name
        for form in body_forms
        if isinstance(form, list)
        and len(form) >= 2
        and isinstance(form[0], Symbol)
        and form[0].name == "def"
        and isinstance(form[1], Symbol)
    }

    # Substitute, compile, and remember the last node id produced.
    last_node_id = None
    for form in body_forms:
        compiled = _compile_expr(
            _substitute_template(form, params_map, local_names, prefix), ctx
        )
        if compiled is not None:
            last_node_id = compiled

    return last_node_id


def compile_string(text: str, initial_bindings: Dict[str, Any] = None, recipe_dir: Path = None) -> CompiledRecipe:
    """
    Parse and compile an S-expression recipe string in one step.

    Args:
        text: S-expression recipe source.
        initial_bindings: Optional dict of name -> value bindings injected
            before compilation; these can be referenced as variables in the
            recipe.
        recipe_dir: Directory containing the recipe file, used to resolve
            relative paths to effects etc.
    """
    return compile_recipe(parse(text), initial_bindings, recipe_dir=recipe_dir, source_text=text)


"""
Sexp effect loader.

Loads sexp effect definitions (define-effect forms) and creates frame
processors that evaluate the sexp body with primitives.

Effects must use :params syntax:

    (define-effect name
      :params (
        (param1 :type int :default 8 :range [4 32] :desc "description")
        (param2 :type string :default "value" :desc "description")
      )
      body)

For effects with no parameters, use empty :params ().  Unknown parameters
passed to effects will raise an error.

Usage:
    loader = SexpEffectLoader()
    effect_fn = loader.load_effect_file(Path("effects/ascii_art.sexp"))
    output = effect_fn(input_path, output_path, config)
"""

import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np

from .parser import parse_all, Symbol, Keyword
from .evaluator import evaluate
from .primitives import PRIMITIVES
from .compiler import ParamDef, _parse_params, CompileError

logger = logging.getLogger(__name__)
def _parse_define_effect(sexp) -> tuple:
    """Parse a (define-effect name :params (...) body) form.

    Effects MUST declare parameters with a :params block; the legacy
    ((param default) ...) syntax is rejected with an explanatory error.

    Returns:
        (name, params_with_defaults, param_defs, body) where param_defs is
        a list of ParamDef objects and params_with_defaults maps each
        declared parameter name to its default value.

    Raises:
        ValueError: on a malformed form, legacy parameter syntax, a missing
            body, or a missing :params block.
    """
    if not isinstance(sexp, list) or len(sexp) < 3:
        raise ValueError(f"Invalid define-effect form: {sexp}")

    head = sexp[0]
    if not (isinstance(head, Symbol) and head.name == "define-effect"):
        raise ValueError(f"Expected define-effect, got {head}")

    name = sexp[1]
    if isinstance(name, Symbol):
        name = name.name

    params_with_defaults = {}
    param_defs: List[ParamDef] = []
    body = None
    found_params = False

    # Walk the remaining items: keyword/value pairs first, then the body.
    i = 2
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword) and item.name == "params":
            # :params syntax
            if i + 1 >= len(sexp):
                raise ValueError(f"Effect '{name}': Missing params list after :params keyword")
            try:
                param_defs = _parse_params(sexp[i + 1])
                # Build params_with_defaults from ParamDef objects
                for pd in param_defs:
                    params_with_defaults[pd.name] = pd.default
            except CompileError as e:
                # Chain the original compile error so the real cause survives.
                raise ValueError(f"Effect '{name}': Error parsing :params: {e}") from e
            found_params = True
            i += 2
        elif isinstance(item, Keyword):
            # Skip other keywords we don't recognize (and their value)
            i += 2
        elif body is None:
            # First non-keyword item is the body
            if isinstance(item, list) and item:
                first_elem = item[0]
                # Check for legacy syntax and reject it
                if isinstance(first_elem, list) and len(first_elem) >= 2:
                    raise ValueError(
                        f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. "
                        f"Use :params block instead:\n"
                        f"  :params (\n"
                        f"    (param_name :type int :default 0 :desc \"description\")\n"
                        f"  )"
                    )
            body = item
            i += 1
        else:
            i += 1

    if body is None:
        raise ValueError(f"Effect '{name}': No body found")

    if not found_params:
        raise ValueError(
            f"Effect '{name}': Missing :params block. Effects must declare parameters.\n"
            f"For effects with no parameters, use empty :params ():\n"
            f"  (define-effect {name}\n"
            f"    :params ()\n"
            f"    body)"
        )

    return name, params_with_defaults, param_defs, body
" + f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}" + ) + + # Bind parameters (defaults + overrides from config) + for param_name, default in params_with_defaults.items(): + # Use config value if provided, otherwise default + if param_name in params: + env[param_name] = params[param_name] + elif default is not None: + env[param_name] = default + + # Evaluate the body + try: + result = evaluate(body, env) + if isinstance(result, np.ndarray): + return result, state + else: + logger.warning(f"Effect {effect_name} returned {type(result)}, expected ndarray") + return frame, state + except Exception as e: + logger.error(f"Error evaluating effect {effect_name}: {e}") + raise + + return process_frame + + +def load_sexp_effect(source: str, base_path: Optional[Path] = None) -> tuple: + """ + Load a sexp effect from source code. + + Args: + source: Sexp source code + base_path: Base path for resolving relative imports + + Returns: + (effect_name, process_frame_fn, params_with_defaults, param_defs) + where param_defs is a list of ParamDef objects for introspection + """ + exprs = parse_all(source) + + # Find define-effect form + define_effect = None + if isinstance(exprs, list): + for expr in exprs: + if isinstance(expr, list) and expr and isinstance(expr[0], Symbol): + if expr[0].name == "define-effect": + define_effect = expr + break + elif isinstance(exprs, list) and exprs and isinstance(exprs[0], Symbol): + if exprs[0].name == "define-effect": + define_effect = exprs + + if not define_effect: + raise ValueError("No define-effect form found in sexp effect") + + name, params_with_defaults, param_defs, body = _parse_define_effect(define_effect) + process_frame = _create_process_frame(name, params_with_defaults, param_defs, body) + + return name, process_frame, params_with_defaults, param_defs + + +def load_sexp_effect_file(path: Path) -> tuple: + """ + Load a sexp effect from file. 
def load_sexp_effect_file(path: Path) -> tuple:
    """Load a sexp effect from a file on disk.

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection.
    """
    return load_sexp_effect(path.read_text(), base_path=path.parent)


class SexpEffectLoader:
    """Loader for sexp effect definitions.

    Creates effect functions compatible with the EffectExecutor and caches
    loaded definitions for parameter introspection.
    """

    # Config keys that are plumbing, not effect parameters.
    _INTERNAL_KEYS = ("effect", "cid", "hash", "effect_path", "_binding")

    def __init__(self, recipe_dir: Optional[Path] = None):
        """Initialize the loader.

        Args:
            recipe_dir: Base directory for resolving relative effect paths
                (defaults to the current working directory).
        """
        self.recipe_dir = recipe_dir or Path.cwd()
        # effect_path -> (name, params_with_defaults, param_defs)
        self._loaded_effects: Dict[str, tuple] = {}

    def load_effect_path(self, effect_path: str) -> Callable:
        """Load a sexp effect from a path relative to recipe_dir.

        Args:
            effect_path: Relative path to the effect .sexp file.

        Returns:
            Effect function (input_path, output_path, config) -> output_path.
        """
        from ..effects.frame_processor import process_video

        full_path = self.recipe_dir / effect_path
        if not full_path.exists():
            raise FileNotFoundError(f"Sexp effect not found: {full_path}")

        name, process_frame_fn, params_defaults, param_defs = load_sexp_effect_file(full_path)
        logger.info(f"Loaded sexp effect: {name} from {effect_path}")

        # Remember the definition so get_effect_params() can answer later.
        self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        internal_keys = self._INTERNAL_KEYS

        def effect_fn(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
            """Run the sexp effect over a video via the frame processor."""
            # Effect parameters: declared defaults overlaid with any config
            # entries that are not internal plumbing keys.
            params = dict(params_defaults)
            params.update(
                (k, v) for k, v in config.items() if k not in internal_keys
            )

            # Pre-resolved analysis bindings travel as {_resolved_values: ...}.
            bindings = {
                key: value["_resolved_values"]
                for key, value in config.items()
                if isinstance(value, dict) and value.get("_resolved_values")
            }

            output_path.parent.mkdir(parents=True, exist_ok=True)
            actual_output = output_path.with_suffix(".mp4")

            process_video(
                input_path=input_path,
                output_path=actual_output,
                process_frame=process_frame_fn,
                params=params,
                bindings=bindings,
            )

            logger.info(f"Processed sexp effect '{name}' from {effect_path}")
            return actual_output

        return effect_fn

    def get_effect_params(self, effect_path: str) -> List[ParamDef]:
        """Return the ParamDef list describing an effect's parameters.

        Loads (and caches) the effect definition on first request.

        Args:
            effect_path: Relative path to effect .sexp file.
        """
        if effect_path not in self._loaded_effects:
            full_path = self.recipe_dir / effect_path
            if not full_path.exists():
                raise FileNotFoundError(f"Sexp effect not found: {full_path}")
            name, _, params_defaults, param_defs = load_sexp_effect_file(full_path)
            self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        return self._loaded_effects[effect_path][2]


def get_sexp_effect_loader(recipe_dir: Optional[Path] = None) -> SexpEffectLoader:
    """Get a sexp effect loader instance."""
    return SexpEffectLoader(recipe_dir)
"""
Expression evaluator for S-expression DSL.

Supports:
- Arithmetic: +, -, *, /, mod, sqrt, pow, abs, floor, ceil, round, min, max, clamp
- Comparison: =, <, >, <=, >=
- Logic: and, or, not
- Predicates: odd?, even?, zero?, nil?
- Conditionals: if, cond, case
- Data: list, dict/map construction, get
- Lambda calls
"""

from typing import Any, Dict, List, Callable
from .parser import Symbol, Keyword, Lambda, Binding


class EvalError(Exception):
    """Error during expression evaluation."""
    pass


# Registry of built-in functions, keyed by their DSL name.
BUILTINS: Dict[str, Callable] = {}


def builtin(name: str):
    """Decorator to register a builtin function under a DSL name."""
    def decorator(fn):
        BUILTINS[name] = fn
        return fn
    return decorator


@builtin("+")
def add(*args):
    """Sum of all arguments."""
    return sum(args)


@builtin("-")
def sub(a, b=None):
    """Subtraction; with a single argument acts as unary negation."""
    if b is None:
        return -a
    return a - b


@builtin("*")
def mul(*args):
    """Product of all arguments."""
    result = 1
    for a in args:
        result *= a
    return result


@builtin("/")
def div(a, b):
    return a / b


@builtin("mod")
def mod(a, b):
    return a % b


@builtin("sqrt")
def sqrt(x):
    return x ** 0.5


@builtin("pow")
def power(x, n):
    return x ** n


@builtin("abs")
def absolute(x):
    return abs(x)


@builtin("floor")
def floor_fn(x):
    import math
    return math.floor(x)


@builtin("ceil")
def ceil_fn(x):
    import math
    return math.ceil(x)


@builtin("round")
def round_fn(x, ndigits=0):
    return round(x, int(ndigits))


@builtin("min")
def min_fn(*args):
    """Minimum; a single list/tuple argument is treated as the collection."""
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return min(args[0])
    return min(args)


@builtin("max")
def max_fn(*args):
    """Maximum; a single list/tuple argument is treated as the collection."""
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return max(args[0])
    return max(args)


@builtin("clamp")
def clamp(x, lo, hi):
    """Clamp x into the inclusive range [lo, hi]."""
    return max(lo, min(hi, x))


@builtin("=")
def eq(a, b):
    return a == b


@builtin("<")
def lt(a, b):
    return a < b


@builtin(">")
def gt(a, b):
    return a > b


@builtin("<=")
def lte(a, b):
    return a <= b


@builtin(">=")
def gte(a, b):
    return a >= b


@builtin("odd?")
def is_odd(n):
    return n % 2 == 1


@builtin("even?")
def is_even(n):
    return n % 2 == 0


@builtin("zero?")
def is_zero(n):
    return n == 0


# NOTE: "nil?" is registered exactly once here (a duplicate definition was
# previously registered a second time further down; it was redundant).
@builtin("nil?")
def is_nil(x):
    return x is None


@builtin("not")
def not_fn(x):
    return not x


@builtin("inc")
def inc(n):
    return n + 1


@builtin("dec")
def dec(n):
    return n - 1


@builtin("list")
def make_list(*args):
    return list(args)


@builtin("assert")
def assert_true(condition, message="Assertion failed"):
    """Raise RuntimeError with message when condition is falsy."""
    if not condition:
        raise RuntimeError(f"Assertion error: {message}")
    return True


@builtin("get")
def get(coll, key, default=None):
    """Look up key in a dict (Keyword-aware) or index into a list."""
    if isinstance(coll, dict):
        # Try the key directly first
        result = coll.get(key, None)
        if result is not None:
            return result
        # If key is a Keyword, also try its string name (for Python dicts
        # with string keys)
        if isinstance(key, Keyword):
            result = coll.get(key.name, None)
            if result is not None:
                return result
        # Return the default
        return default
    elif isinstance(coll, list):
        return coll[key] if 0 <= key < len(coll) else default
    else:
        raise EvalError(f"get: expected dict or list, got {type(coll).__name__}: {str(coll)[:100]}")


@builtin("dict?")
def is_dict(x):
    return isinstance(x, dict)


@builtin("list?")
def is_list(x):
    return isinstance(x, list)


@builtin("number?")
def is_number(x):
    return isinstance(x, (int, float))


@builtin("string?")
def is_string(x):
    return isinstance(x, str)


@builtin("len")
def length(coll):
    return len(coll)


@builtin("first")
def first(coll):
    return coll[0] if coll else None


@builtin("last")
def last(coll):
    return coll[-1] if coll else None


@builtin("chunk-every")
def chunk_every(coll, n):
    """Split collection into chunks of n elements."""
    n = int(n)
    return [coll[i:i+n] for i in range(0, len(coll), n)]


@builtin("rest")
def rest(coll):
    return coll[1:] if coll else []


@builtin("nth")
def nth(coll, n):
    """Return coll[n], or None when n is out of bounds."""
    return coll[n] if 0 <= n < len(coll) else None
@builtin("concat")
def concat(*colls):
    """Concatenate multiple lists/sequences, skipping nil arguments."""
    merged = []
    for coll in colls:
        if coll is not None:
            merged.extend(coll)
    return merged


@builtin("cons")
def cons(x, coll):
    """Prepend x to collection."""
    if not coll:
        return [x]
    return [x] + list(coll)


@builtin("append")
def append(coll, x):
    """Append x to collection."""
    if not coll:
        return [x]
    return list(coll) + [x]


@builtin("range")
def make_range(start, end, step=1):
    """Create a list of integers from start (inclusive) to end (exclusive)."""
    return list(range(int(start), int(end), int(step)))


@builtin("zip-pairs")
def zip_pairs(coll):
    """Sliding pairs: [a,b,c,d] -> [[a,b],[b,c],[c,d]]."""
    if not coll or len(coll) < 2:
        return []
    return [[left, right] for left, right in zip(coll, coll[1:])]


@builtin("dict")
def make_dict(*pairs):
    """Create dict from key-value pairs: (dict :a 1 :b 2)."""
    built = {}
    # Walk even/odd slices in lockstep; an unpaired trailing key is ignored.
    for key, value in zip(pairs[0::2], pairs[1::2]):
        if isinstance(key, Keyword):
            key = key.name
        built[key] = value
    return built


@builtin("keys")
def keys(d):
    """Get the keys of a dict as a list."""
    if not isinstance(d, dict):
        raise EvalError(f"keys: expected dict, got {type(d).__name__}")
    return list(d)


@builtin("vals")
def vals(d):
    """Get the values of a dict as a list."""
    if not isinstance(d, dict):
        raise EvalError(f"vals: expected dict, got {type(d).__name__}")
    return list(d.values())


@builtin("merge")
def merge(*dicts):
    """Merge dicts left-to-right; later entries override earlier ones."""
    combined = {}
    for d in dicts:
        if d is None:
            continue
        if not isinstance(d, dict):
            raise EvalError(f"merge: expected dict, got {type(d).__name__}")
        combined.update(d)
    return combined


@builtin("assoc")
def assoc(d, *pairs):
    """Associate keys with values in a dict: (assoc d :a 1 :b 2)."""
    if d is None:
        updated = {}
    elif isinstance(d, dict):
        updated = dict(d)
    else:
        raise EvalError(f"assoc: expected dict or nil, got {type(d).__name__}")

    # Same pairing rule as (dict ...): an unpaired trailing key is ignored.
    for key, value in zip(pairs[0::2], pairs[1::2]):
        if isinstance(key, Keyword):
            key = key.name
        updated[key] = value
    return updated


@builtin("dissoc")
def dissoc(d, *keys_to_remove):
    """Remove keys from a dict: (dissoc d :a :b)."""
    if d is None:
        return {}
    if not isinstance(d, dict):
        raise EvalError(f"dissoc: expected dict or nil, got {type(d).__name__}")

    trimmed = dict(d)
    for key in keys_to_remove:
        name = key.name if isinstance(key, Keyword) else key
        trimmed.pop(name, None)
    return trimmed


@builtin("into")
def into(target, coll):
    """Convert a collection into the type of the target collection.

    (into [] {:a 1 :b 2}) -> [["a" 1] ["b" 2]]
    (into {} [[:a 1] [:b 2]]) -> {"a": 1, "b": 2}
    (into [] [1 2 3])       -> [1 2 3]
    """
    if isinstance(target, list):
        if isinstance(coll, dict):
            return [[k, v] for k, v in coll.items()]
        if isinstance(coll, (list, tuple)):
            return list(coll)
        raise EvalError(f"into: cannot convert {type(coll).__name__} into list")

    if isinstance(target, dict):
        if isinstance(coll, dict):
            return dict(coll)
        if isinstance(coll, (list, tuple)):
            converted = {}
            for entry in coll:
                if not (isinstance(entry, (list, tuple)) and len(entry) >= 2):
                    raise EvalError(f"into: expected [key value] pairs, got {entry}")
                key = entry[0]
                if isinstance(key, Keyword):
                    key = key.name
                converted[key] = entry[1]
            return converted
        raise EvalError(f"into: cannot convert {type(coll).__name__} into dict")

    raise EvalError(f"into: unsupported target type {type(target).__name__}")
def _invoke_predicate(pred, item):
    """Evaluate a one-argument lambda against item, honoring its closure."""
    scope = dict(pred.closure) if pred.closure else {}
    scope[pred.params[0]] = item
    from . import evaluator
    return evaluator.evaluate(pred.body, scope)


@builtin("filter")
def filter_fn(pred, coll):
    """Filter collection by predicate. Pred must be a lambda."""
    if not isinstance(pred, Lambda):
        raise EvalError(f"filter: expected lambda as predicate, got {type(pred).__name__}")
    return [item for item in coll if _invoke_predicate(pred, item)]


@builtin("some")
def some(pred, coll):
    """Return first truthy value of (pred item) for items in coll, or nil."""
    if not isinstance(pred, Lambda):
        raise EvalError(f"some: expected lambda as predicate, got {type(pred).__name__}")
    for item in coll:
        found = _invoke_predicate(pred, item)
        if found:
            return found
    return None


@builtin("every?")
def every(pred, coll):
    """Return true if (pred item) is truthy for all items in coll."""
    if not isinstance(pred, Lambda):
        raise EvalError(f"every?: expected lambda as predicate, got {type(pred).__name__}")
    return all(_invoke_predicate(pred, item) for item in coll)


@builtin("empty?")
def is_empty(coll):
    """Return true if collection is nil or has no elements."""
    if coll is None:
        return True
    return len(coll) == 0


@builtin("contains?")
def contains(coll, key):
    """Check if collection contains key (for dicts) or element (for lists)."""
    if isinstance(coll, dict):
        name = key.name if isinstance(key, Keyword) else key
        return name in coll
    if isinstance(coll, (list, tuple)):
        return key in coll
    return False
conditional + if name == "if": + if len(expr) < 3: + raise EvalError("if requires condition and then-branch") + cond_result = evaluate(expr[1], env) + if cond_result: + return evaluate(expr[2], env) + elif len(expr) > 3: + return evaluate(expr[3], env) + return None + + # cond - multi-way conditional + # Supports both Clojure style: (cond test1 result1 test2 result2 :else default) + # and Scheme style: (cond (test1 result1) (test2 result2) (else default)) + if name == "cond": + clauses = expr[1:] + # Check if Clojure style (flat list) or Scheme style (nested pairs) + # Scheme style: first clause is (test result) - exactly 2 elements + # Clojure style: first clause is a test expression like (= x 0) - 3+ elements + first_is_scheme_clause = ( + clauses and + isinstance(clauses[0], list) and + len(clauses[0]) == 2 and + not (isinstance(clauses[0][0], Symbol) and clauses[0][0].name in ('=', '<', '>', '<=', '>=', '!=', 'not=', 'and', 'or')) + ) + if first_is_scheme_clause: + # Scheme style: ((test result) ...) + for clause in clauses: + if not isinstance(clause, list) or len(clause) < 2: + raise EvalError("cond clause must be (test result)") + test = clause[0] + # Check for else/default + if isinstance(test, Symbol) and test.name in ("else", ":else"): + return evaluate(clause[1], env) + if isinstance(test, Keyword) and test.name == "else": + return evaluate(clause[1], env) + if evaluate(test, env): + return evaluate(clause[1], env) + else: + # Clojure style: test1 result1 test2 result2 ... 
+ i = 0 + while i < len(clauses) - 1: + test = clauses[i] + result = clauses[i + 1] + # Check for :else + if isinstance(test, Keyword) and test.name == "else": + return evaluate(result, env) + if isinstance(test, Symbol) and test.name == ":else": + return evaluate(result, env) + if evaluate(test, env): + return evaluate(result, env) + i += 2 + return None + + # case - switch on value + # (case expr val1 result1 val2 result2 :else default) + if name == "case": + if len(expr) < 2: + raise EvalError("case requires expression to match") + match_val = evaluate(expr[1], env) + clauses = expr[2:] + i = 0 + while i < len(clauses) - 1: + test = clauses[i] + result = clauses[i + 1] + # Check for :else / else + if isinstance(test, Keyword) and test.name == "else": + return evaluate(result, env) + if isinstance(test, Symbol) and test.name in (":else", "else"): + return evaluate(result, env) + # Evaluate test value and compare + test_val = evaluate(test, env) + if match_val == test_val: + return evaluate(result, env) + i += 2 + return None + + # and - short-circuit + if name == "and": + result = True + for arg in expr[1:]: + result = evaluate(arg, env) + if not result: + return result + return result + + # or - short-circuit + if name == "or": + result = False + for arg in expr[1:]: + result = evaluate(arg, env) + if result: + return result + return result + + # let and let* - local bindings (both bind sequentially in this impl) + if name in ("let", "let*"): + if len(expr) < 3: + raise EvalError(f"{name} requires bindings and body") + bindings = expr[1] + + local_env = dict(env) + + if isinstance(bindings, list): + # Check if it's ((name value) ...) 
style (Lisp let* style) + if bindings and isinstance(bindings[0], list): + for binding in bindings: + if len(binding) != 2: + raise EvalError(f"{name} binding must be (name value)") + var_name = binding[0] + if isinstance(var_name, Symbol): + var_name = var_name.name + value = evaluate(binding[1], local_env) + local_env[var_name] = value + # Vector-style [name value ...] + elif len(bindings) % 2 == 0: + for i in range(0, len(bindings), 2): + var_name = bindings[i] + if isinstance(var_name, Symbol): + var_name = var_name.name + value = evaluate(bindings[i + 1], local_env) + local_env[var_name] = value + else: + raise EvalError(f"{name} bindings must be [name value ...] or ((name value) ...)") + else: + raise EvalError(f"{name} bindings must be a list") + + return evaluate(expr[2], local_env) + + # lambda / fn - create function with closure + if name in ("lambda", "fn"): + if len(expr) < 3: + raise EvalError("lambda requires params and body") + params = expr[1] + if not isinstance(params, list): + raise EvalError("lambda params must be a list") + param_names = [] + for p in params: + if isinstance(p, Symbol): + param_names.append(p.name) + elif isinstance(p, str): + param_names.append(p) + else: + raise EvalError(f"Invalid param: {p}") + # Capture current environment as closure + return Lambda(param_names, expr[2], dict(env)) + + # quote - return unevaluated + if name == "quote": + return expr[1] if len(expr) > 1 else None + + # bind - create binding to analysis data + # (bind analysis-var) + # (bind analysis-var :range [0.3 1.0]) + # (bind analysis-var :range [0 100] :transform sqrt) + if name == "bind": + if len(expr) < 2: + raise EvalError("bind requires analysis reference") + analysis_ref = expr[1] + if isinstance(analysis_ref, Symbol): + symbol_name = analysis_ref.name + # Look up the symbol in environment + if symbol_name in env: + resolved = env[symbol_name] + # If resolved is actual analysis data (dict with times/values or + # S-expression list with 
Keywords), keep the symbol name as reference + # for later lookup at execution time + if isinstance(resolved, dict) and ("times" in resolved or "values" in resolved): + analysis_ref = symbol_name # Use name as reference, not the data + elif isinstance(resolved, list) and any(isinstance(x, Keyword) for x in resolved): + # Parsed S-expression analysis data ([:times [...] :duration ...]) + analysis_ref = symbol_name + else: + analysis_ref = resolved + else: + raise EvalError(f"bind: undefined symbol '{symbol_name}' - must reference analysis data") + + # Parse optional :range [min max] and :transform + range_min, range_max = 0.0, 1.0 + transform = None + i = 2 + while i < len(expr): + if isinstance(expr[i], Keyword): + kw = expr[i].name + if kw == "range" and i + 1 < len(expr): + range_val = evaluate(expr[i + 1], env) # Evaluate to get actual value + if isinstance(range_val, list) and len(range_val) >= 2: + range_min = float(range_val[0]) + range_max = float(range_val[1]) + i += 2 + elif kw == "transform" and i + 1 < len(expr): + t = expr[i + 1] + if isinstance(t, Symbol): + transform = t.name + elif isinstance(t, str): + transform = t + i += 2 + else: + i += 1 + else: + i += 1 + + return Binding(analysis_ref, range_min=range_min, range_max=range_max, transform=transform) + + # Vector literal [a b c] + if name == "vec" or name == "vector": + return [evaluate(e, env) for e in expr[1:]] + + # map - (map fn coll) + if name == "map": + if len(expr) != 3: + raise EvalError("map requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"map requires lambda, got {type(fn)}") + result = [] + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = item + result.append(evaluate(fn.body, local_env)) + return result + + # map-indexed - (map-indexed fn coll) + if name == "map-indexed": + if len(expr) != 3: + raise 
EvalError("map-indexed requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"map-indexed requires lambda, got {type(fn)}") + if len(fn.params) < 2: + raise EvalError("map-indexed lambda needs (i item) params") + result = [] + for i, item in enumerate(coll): + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = i + local_env[fn.params[1]] = item + result.append(evaluate(fn.body, local_env)) + return result + + # reduce - (reduce fn init coll) + if name == "reduce": + if len(expr) != 4: + raise EvalError("reduce requires fn, init, and collection") + fn = evaluate(expr[1], env) + acc = evaluate(expr[2], env) + coll = evaluate(expr[3], env) + if not isinstance(fn, Lambda): + raise EvalError(f"reduce requires lambda, got {type(fn)}") + if len(fn.params) < 2: + raise EvalError("reduce lambda needs (acc item) params") + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = acc + local_env[fn.params[1]] = item + acc = evaluate(fn.body, local_env) + return acc + + # for-each - (for-each fn coll) - iterate with side effects + if name == "for-each": + if len(expr) != 3: + raise EvalError("for-each requires fn and collection") + fn = evaluate(expr[1], env) + coll = evaluate(expr[2], env) + if not isinstance(fn, Lambda): + raise EvalError(f"for-each requires lambda, got {type(fn)}") + for item in coll: + local_env = {} + if fn.closure: + local_env.update(fn.closure) + local_env.update(env) + local_env[fn.params[0]] = item + evaluate(fn.body, local_env) + return None + + # Function call + fn = evaluate(head, env) + args = [evaluate(arg, env) for arg in expr[1:]] + + # Call builtin + if callable(fn): + return fn(*args) + + # Call lambda + if isinstance(fn, Lambda): + if len(args) != len(fn.params): + raise EvalError(f"Lambda expects {len(fn.params)} 
def check_python_package(package: str) -> bool:
    """
    Return True if *package* is importable.

    The probe runs in a subprocess so a broken or partially-installed
    package cannot pollute or crash the current process.

    Args:
        package: Module name to try importing (e.g. "datamoshing").

    Returns:
        True if the import succeeded in the child interpreter.
    """
    # Use the interpreter that is running this code rather than whatever
    # "python3" happens to resolve to on PATH — the latter may not exist
    # (e.g. on Windows) or may be a different environment entirely.
    import sys

    try:
        result = subprocess.run(
            [sys.executable, "-c", f"import {package}"],
            capture_output=True,
            timeout=5,
        )
        return result.returncode == 0
    except Exception:
        # Missing interpreter, timeout, etc. — treat as "not installed".
        return False
def run_datamosh(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run the datamosh effect using the best available external tool.

    Preference order: ffglitch (ffgac) with a generated motion-vector
    corruption script, then the Python "datamoshing" package CLI.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters (corruption, block_size, etc.)

    Returns:
        (success, error_message) — error_message is "" on success.
    """
    tools = get_available_tools()

    # Fraction of motion vectors to corrupt; interpolated into the script.
    corruption = params.get("corruption", 0.3)

    # Try ffglitch first
    if tools.get("datamosh"):
        ffgac = tools["datamosh"]
        script_path = None
        try:
            # ffglitch approach: corrupt motion vectors of P-frames.
            # This is a simplified version - full datamosh needs custom scripts.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                f.write(f"""
// Datamosh script - corrupt motion vectors
let corruption = {corruption};

export function glitch_frame(frame, stream) {{
    if (frame.pict_type === 'P' && Math.random() < corruption) {{
        // Corrupt motion vectors
        let dominated = frame.mv?.forward?.overflow;
        if (dominated) {{
            for (let i = 0; i < dominated.length; i++) {{
                if (Math.random() < corruption) {{
                    dominated[i] = [
                        Math.floor(Math.random() * 64 - 32),
                        Math.floor(Math.random() * 64 - 32)
                    ];
                }}
            }}
        }}
    }}
    return frame;
}}
""")
                script_path = f.name

            cmd = [
                ffgac,
                "-i", str(input_path),
                "-s", script_path,
                "-o", str(output_path),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]

        except subprocess.TimeoutExpired:
            return False, "Datamosh timeout"
        except Exception as e:
            return False, str(e)
        finally:
            # BUGFIX: previously the temp script was only unlinked on the
            # non-exception path, leaking a file on timeout/error.
            if script_path:
                Path(script_path).unlink(missing_ok=True)

    # Fall back to Python datamoshing package
    if tools.get("datamosh_python"):
        try:
            cmd = [
                "python3", "-m", "datamoshing",
                str(input_path),
                str(output_path),
                "--mode", "iframe_removal",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    return False, "No datamosh tool available. Install ffglitch or: pip install datamoshing"
def run_pixelsort_video(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    fps: float = 30,
) -> Tuple[bool, str]:
    """
    Run pixelsort on a video by processing each frame.

    Three-stage pipeline, all in a temporary directory:
    1. ffmpeg extracts frames at *fps* as PNGs,
    2. run_pixelsort processes each frame individually,
    3. ffmpeg reassembles the sorted frames, copying audio from the original.

    Args:
        input_path: Input video file.
        output_path: Output video file.
        params: Pixelsort parameters, forwarded to run_pixelsort per frame.
        fps: Frame rate used both for extraction and reassembly.

    Returns:
        (success, error_message) — error_message is "" on success.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)
        # ffmpeg printf-style sequence patterns for frame file names.
        frames_in = tmpdir / "frame_%04d.png"
        # NOTE(review): frames_out is never used below — per-frame output
        # names are built inline in the loop instead.
        frames_out = tmpdir / "sorted_%04d.png"

        # Extract frames
        extract_cmd = [
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-vf", f"fps={fps}",
            str(frames_in),
        ]
        result = subprocess.run(extract_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to extract frames"

        # Process each frame; abort on the first per-frame failure.
        frame_files = sorted(tmpdir.glob("frame_*.png"))
        for i, frame in enumerate(frame_files):
            out_frame = tmpdir / f"sorted_{i:04d}.png"
            success, error = run_pixelsort(frame, out_frame, params)
            if not success:
                return False, f"Frame {i}: {error}"

        # Reassemble
        # Get audio from original ("1:a?" maps audio only if present).
        reassemble_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", str(tmpdir / "sorted_%04d.png"),
            "-i", str(input_path),
            "-map", "0:v", "-map", "1:a?",
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ]
        result = subprocess.run(reassemble_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to reassemble video"

    return True, ""
if __name__ == "__main__":
    # Diagnostic entry point: report which optional external tools are present.
    print("Available external tools:")
    for tool_name, tool_path in get_available_tools().items():
        status = tool_path if tool_path else "NOT INSTALLED"
        print(f"  {tool_name}: {status}")
+ +Usage: + compiler = FFmpegCompiler() + + # Compile an effect with static params + filter_str = compiler.compile_effect("brightness", {"amount": 50}) + # -> "eq=brightness=0.196" + + # Compile with dynamic binding to analysis data + filter_str, sendcmd = compiler.compile_effect_with_binding( + "brightness", + {"amount": {"_bind": "bass-data", "range_min": 0, "range_max": 100}}, + analysis_data={"bass-data": {"times": [...], "values": [...]}}, + segment_start=0.0, + segment_duration=5.0, + ) + # -> ("eq=brightness=0.5", "0.0 [eq] brightness 0.5;\n0.05 [eq] brightness 0.6;...") +""" + +import math +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +# FFmpeg filter mappings for common effects +# Maps effect name -> {filter: str, params: {param_name: {ffmpeg_param, scale, offset}}} +EFFECT_MAPPINGS = { + "invert": { + "filter": "negate", + "params": {}, + }, + "grayscale": { + "filter": "colorchannelmixer", + "static": "0.3:0.4:0.3:0:0.3:0.4:0.3:0:0.3:0.4:0.3", + "params": {}, + }, + "sepia": { + "filter": "colorchannelmixer", + "static": "0.393:0.769:0.189:0:0.349:0.686:0.168:0:0.272:0.534:0.131", + "params": {}, + }, + "brightness": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0}, + }, + }, + "contrast": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0}, + }, + }, + "saturation": { + "filter": "eq", + "params": { + "amount": {"ffmpeg_param": "saturation", "scale": 1.0, "offset": 0}, + }, + }, + "hue_shift": { + "filter": "hue", + "params": { + "degrees": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0}, + }, + }, + "blur": { + "filter": "gblur", + "params": { + "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0}, + }, + }, + "sharpen": { + "filter": "unsharp", + "params": { + "amount": {"ffmpeg_param": "la", "scale": 1.0, "offset": 0}, + }, + }, + "pixelate": { + # Scale down then up to create pixelation effect + 
"filter": "scale", + "static": "iw/8:ih/8:flags=neighbor,scale=iw*8:ih*8:flags=neighbor", + "params": {}, + }, + "vignette": { + "filter": "vignette", + "params": { + "strength": {"ffmpeg_param": "a", "scale": 1.0, "offset": 0}, + }, + }, + "noise": { + "filter": "noise", + "params": { + "amount": {"ffmpeg_param": "alls", "scale": 1.0, "offset": 0}, + }, + }, + "flip": { + "filter": "hflip", # Default horizontal + "params": {}, + }, + "mirror": { + "filter": "hflip", + "params": {}, + }, + "rotate": { + "filter": "rotate", + "params": { + "angle": {"ffmpeg_param": "a", "scale": math.pi/180, "offset": 0}, # degrees to radians + }, + }, + "zoom": { + "filter": "zoompan", + "params": { + "factor": {"ffmpeg_param": "z", "scale": 1.0, "offset": 0}, + }, + }, + "posterize": { + # Use lutyuv to quantize levels (approximate posterization) + "filter": "lutyuv", + "static": "y=floor(val/32)*32:u=floor(val/32)*32:v=floor(val/32)*32", + "params": {}, + }, + "threshold": { + # Use geq for thresholding + "filter": "geq", + "static": "lum='if(gt(lum(X,Y),128),255,0)':cb=128:cr=128", + "params": {}, + }, + "edge_detect": { + "filter": "edgedetect", + "params": { + "low": {"ffmpeg_param": "low", "scale": 1/255, "offset": 0}, + "high": {"ffmpeg_param": "high", "scale": 1/255, "offset": 0}, + }, + }, + "swirl": { + "filter": "lenscorrection", # Approximate with lens distortion + "params": { + "strength": {"ffmpeg_param": "k1", "scale": 0.1, "offset": 0}, + }, + }, + "fisheye": { + "filter": "lenscorrection", + "params": { + "strength": {"ffmpeg_param": "k1", "scale": 1.0, "offset": 0}, + }, + }, + "wave": { + # Wave displacement using geq - need r/g/b for RGB mode + "filter": "geq", + "static": "r='r(X+10*sin(Y/20),Y)':g='g(X+10*sin(Y/20),Y)':b='b(X+10*sin(Y/20),Y)'", + "params": {}, + }, + "rgb_split": { + # Chromatic aberration using geq + "filter": "geq", + "static": "r='p(X+5,Y)':g='p(X,Y)':b='p(X-5,Y)'", + "params": {}, + }, + "scanlines": { + "filter": "drawgrid", + "params": { 
+ "spacing": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0}, + }, + }, + "film_grain": { + "filter": "noise", + "params": { + "intensity": {"ffmpeg_param": "alls", "scale": 100, "offset": 0}, + }, + }, + "crt": { + "filter": "vignette", # Simplified - just vignette for CRT look + "params": {}, + }, + "bloom": { + "filter": "gblur", # Simplified bloom = blur overlay + "params": { + "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0}, + }, + }, + "color_cycle": { + "filter": "hue", + "params": { + "speed": {"ffmpeg_param": "h", "scale": 360.0, "offset": 0, "time_expr": True}, + }, + "time_based": True, # Uses time expression + }, + "strobe": { + # Strobe using select to drop frames + "filter": "select", + "static": "'mod(n,4)'", + "params": {}, + }, + "echo": { + # Echo using tmix + "filter": "tmix", + "static": "frames=4:weights='1 0.5 0.25 0.125'", + "params": {}, + }, + "trails": { + # Trails using tblend + "filter": "tblend", + "static": "all_mode=average", + "params": {}, + }, + "kaleidoscope": { + # 4-way mirror kaleidoscope using FFmpeg filter chain + # Crops top-left quadrant, mirrors horizontally, then vertically + "filter": "crop", + "complex": True, + "static": "iw/2:ih/2:0:0[q];[q]split[q1][q2];[q1]hflip[qr];[q2][qr]hstack[top];[top]split[t1][t2];[t2]vflip[bot];[t1][bot]vstack", + "params": {}, + }, + "emboss": { + "filter": "convolution", + "static": "-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2", + "params": {}, + }, + "neon_glow": { + # Edge detect + negate for neon-like effect + "filter": "edgedetect", + "static": "mode=colormix:high=0.1", + "params": {}, + }, + "ascii_art": { + # Requires Python frame processing - no FFmpeg equivalent + "filter": None, + "python_primitive": "ascii_art_frame", + "params": { + "char_size": {"default": 8}, + "alphabet": {"default": "standard"}, + "color_mode": {"default": "color"}, + }, + }, + "ascii_zones": { + # Requires Python frame processing - zone-based ASCII + 
"filter": None, + "python_primitive": "ascii_zones_frame", + "params": { + "char_size": {"default": 8}, + "zone_threshold": {"default": 128}, + }, + }, + "datamosh": { + # External tool: ffglitch or datamoshing CLI, falls back to Python + "filter": None, + "external_tool": "datamosh", + "python_primitive": "datamosh_frame", + "params": { + "block_size": {"default": 32}, + "corruption": {"default": 0.3}, + }, + }, + "pixelsort": { + # External tool: pixelsort CLI (Rust or Python), falls back to Python + "filter": None, + "external_tool": "pixelsort", + "python_primitive": "pixelsort_frame", + "params": { + "sort_by": {"default": "lightness"}, + "threshold_low": {"default": 50}, + "threshold_high": {"default": 200}, + "angle": {"default": 0}, + }, + }, + "ripple": { + # Use geq for ripple displacement + "filter": "geq", + "static": "lum='lum(X+5*sin(hypot(X-W/2,Y-H/2)/10),Y+5*cos(hypot(X-W/2,Y-H/2)/10))'", + "params": {}, + }, + "tile_grid": { + # Use tile filter for grid + "filter": "tile", + "static": "2x2", + "params": {}, + }, + "outline": { + "filter": "edgedetect", + "params": {}, + }, + "color-adjust": { + "filter": "eq", + "params": { + "brightness": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0}, + "contrast": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0}, + }, + }, +} + + +class FFmpegCompiler: + """Compiles sexp effects to FFmpeg filters with sendcmd support.""" + + def __init__(self, effect_mappings: Dict = None): + self.mappings = effect_mappings or EFFECT_MAPPINGS + + def get_full_mapping(self, effect_name: str) -> Optional[Dict]: + """Get full mapping for an effect (including external tools and python primitives).""" + mapping = self.mappings.get(effect_name) + if not mapping: + # Try with underscores/hyphens converted + normalized = effect_name.replace("-", "_").replace(" ", "_").lower() + mapping = self.mappings.get(normalized) + return mapping + + def get_mapping(self, effect_name: str) -> Optional[Dict]: + """Get FFmpeg 
filter mapping for an effect (returns None for non-FFmpeg effects).""" + mapping = self.get_full_mapping(effect_name) + # Return None if no mapping or no FFmpeg filter + if mapping and mapping.get("filter") is None: + return None + return mapping + + def has_external_tool(self, effect_name: str) -> Optional[str]: + """Check if effect uses an external tool. Returns tool name or None.""" + mapping = self.get_full_mapping(effect_name) + if mapping: + return mapping.get("external_tool") + return None + + def has_python_primitive(self, effect_name: str) -> Optional[str]: + """Check if effect uses a Python primitive. Returns primitive name or None.""" + mapping = self.get_full_mapping(effect_name) + if mapping: + return mapping.get("python_primitive") + return None + + def is_complex_filter(self, effect_name: str) -> bool: + """Check if effect uses a complex filter chain.""" + mapping = self.get_full_mapping(effect_name) + return bool(mapping and mapping.get("complex")) + + def compile_effect( + self, + effect_name: str, + params: Dict[str, Any], + ) -> Optional[str]: + """ + Compile an effect to an FFmpeg filter string with static params. + + Returns None if effect has no FFmpeg mapping. 
+ """ + mapping = self.get_mapping(effect_name) + if not mapping: + return None + + filter_name = mapping["filter"] + + # Handle static filters (no params) + if "static" in mapping: + return f"{filter_name}={mapping['static']}" + + if not mapping.get("params"): + return filter_name + + # Build param string + filter_params = [] + for param_name, param_config in mapping["params"].items(): + if param_name in params: + value = params[param_name] + # Skip if it's a binding (handled separately) + if isinstance(value, dict) and ("_bind" in value or "_binding" in value): + continue + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + offset = param_config.get("offset", 0) + # Handle various value types + if isinstance(value, (int, float)): + ffmpeg_value = value * scale + offset + filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + elif isinstance(value, str): + filter_params.append(f"{ffmpeg_param}={value}") + elif isinstance(value, list) and value and isinstance(value[0], (int, float)): + ffmpeg_value = value[0] * scale + offset + filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + + if filter_params: + return f"{filter_name}={':'.join(filter_params)}" + return filter_name + + def compile_effect_with_bindings( + self, + effect_name: str, + params: Dict[str, Any], + analysis_data: Dict[str, Dict], + segment_start: float, + segment_duration: float, + sample_interval: float = 0.04, # ~25 fps + ) -> Tuple[Optional[str], Optional[str], List[str]]: + """ + Compile an effect with dynamic bindings to a filter + sendcmd script. 
+ + Returns: + (filter_string, sendcmd_script, bound_param_names) + - filter_string: Initial FFmpeg filter (may have placeholder values) + - sendcmd_script: Script content for sendcmd filter + - bound_param_names: List of params that have bindings + """ + mapping = self.get_mapping(effect_name) + if not mapping: + return None, None, [] + + filter_name = mapping["filter"] + static_params = [] + bound_params = [] + sendcmd_lines = [] + + # Handle time-based effects (use FFmpeg expressions with 't') + if mapping.get("time_based"): + for param_name, param_config in mapping.get("params", {}).items(): + if param_name in params: + value = params[param_name] + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + if isinstance(value, (int, float)): + # Create time expression: h='t*speed*scale' + static_params.append(f"{ffmpeg_param}='t*{value}*{scale}'") + else: + static_params.append(f"{ffmpeg_param}='t*{scale}'") + if static_params: + filter_str = f"{filter_name}={':'.join(static_params)}" + else: + filter_str = f"{filter_name}=h='t*360'" # Default rotation + return filter_str, None, [] + + # Process each param + for param_name, param_config in mapping.get("params", {}).items(): + if param_name not in params: + continue + + value = params[param_name] + ffmpeg_param = param_config["ffmpeg_param"] + scale = param_config.get("scale", 1.0) + offset = param_config.get("offset", 0) + + # Check if it's a binding + if isinstance(value, dict) and ("_bind" in value or "_binding" in value): + bind_ref = value.get("_bind") or value.get("_binding") + range_min = value.get("range_min", 0.0) + range_max = value.get("range_max", 1.0) + transform = value.get("transform") + + # Get analysis data + analysis = analysis_data.get(bind_ref) + if not analysis: + # Try without -data suffix + analysis = analysis_data.get(bind_ref.replace("-data", "")) + + if analysis and "times" in analysis and "values" in analysis: + times = analysis["times"] + values = 
analysis["values"] + + # Generate sendcmd entries for this segment + segment_end = segment_start + segment_duration + t = 0.0 # Time relative to segment start + + while t < segment_duration: + abs_time = segment_start + t + + # Find analysis value at this time + raw_value = self._interpolate_value(times, values, abs_time) + + # Apply transform + if transform == "sqrt": + raw_value = math.sqrt(max(0, raw_value)) + elif transform == "pow2": + raw_value = raw_value ** 2 + elif transform == "log": + raw_value = math.log(max(0.001, raw_value)) + + # Map to range + mapped_value = range_min + raw_value * (range_max - range_min) + + # Apply FFmpeg scaling + ffmpeg_value = mapped_value * scale + offset + + # Add sendcmd line (time relative to segment) + sendcmd_lines.append(f"{t:.3f} [{filter_name}] {ffmpeg_param} {ffmpeg_value:.4f};") + + t += sample_interval + + bound_params.append(param_name) + # Use initial value for the filter string + initial_value = self._interpolate_value(times, values, segment_start) + initial_mapped = range_min + initial_value * (range_max - range_min) + initial_ffmpeg = initial_mapped * scale + offset + static_params.append(f"{ffmpeg_param}={initial_ffmpeg:.4f}") + else: + # No analysis data, use range midpoint + mid_value = (range_min + range_max) / 2 + ffmpeg_value = mid_value * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + else: + # Static value - handle various types + if isinstance(value, (int, float)): + ffmpeg_value = value * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + elif isinstance(value, str): + # String value - use as-is (e.g., for direction parameters) + static_params.append(f"{ffmpeg_param}={value}") + elif isinstance(value, list) and value: + # List - try to use first numeric element + first = value[0] + if isinstance(first, (int, float)): + ffmpeg_value = first * scale + offset + static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}") + # Skip other types + + # 
Handle static filters + if "static" in mapping: + filter_str = f"{filter_name}={mapping['static']}" + elif static_params: + filter_str = f"{filter_name}={':'.join(static_params)}" + else: + filter_str = filter_name + + # Combine sendcmd lines + sendcmd_script = "\n".join(sendcmd_lines) if sendcmd_lines else None + + return filter_str, sendcmd_script, bound_params + + def _interpolate_value( + self, + times: List[float], + values: List[float], + target_time: float, + ) -> float: + """Interpolate a value from analysis data at a given time.""" + if not times or not values: + return 0.5 + + # Find surrounding indices + if target_time <= times[0]: + return values[0] + if target_time >= times[-1]: + return values[-1] + + # Binary search for efficiency + lo, hi = 0, len(times) - 1 + while lo < hi - 1: + mid = (lo + hi) // 2 + if times[mid] <= target_time: + lo = mid + else: + hi = mid + + # Linear interpolation + t0, t1 = times[lo], times[hi] + v0, v1 = values[lo], values[hi] + + if t1 == t0: + return v0 + + alpha = (target_time - t0) / (t1 - t0) + return v0 + alpha * (v1 - v0) + + +def generate_sendcmd_filter( + effects: List[Dict], + analysis_data: Dict[str, Dict], + segment_start: float, + segment_duration: float, +) -> Tuple[str, Optional[Path]]: + """ + Generate FFmpeg filter chain with sendcmd for dynamic effects. 
+ + Args: + effects: List of effect configs with name and params + analysis_data: Analysis data keyed by name + segment_start: Segment start time in source + segment_duration: Segment duration + + Returns: + (filter_chain_string, sendcmd_file_path or None) + """ + import tempfile + + compiler = FFmpegCompiler() + filters = [] + all_sendcmd_lines = [] + + for effect in effects: + effect_name = effect.get("effect") + params = {k: v for k, v in effect.items() if k != "effect"} + + filter_str, sendcmd, _ = compiler.compile_effect_with_bindings( + effect_name, + params, + analysis_data, + segment_start, + segment_duration, + ) + + if filter_str: + filters.append(filter_str) + if sendcmd: + all_sendcmd_lines.append(sendcmd) + + if not filters: + return "", None + + filter_chain = ",".join(filters) + + # NOTE: sendcmd disabled - FFmpeg's sendcmd filter has compatibility issues. + # Bindings use their initial value (sampled at segment start time). + # This is acceptable since each segment is only ~8 seconds. + # The binding value is still music-reactive (varies per segment). + sendcmd_path = None + + return filter_chain, sendcmd_path diff --git a/artdag/sexp/parser.py b/artdag/sexp/parser.py new file mode 100644 index 0000000..8f7b4a4 --- /dev/null +++ b/artdag/sexp/parser.py @@ -0,0 +1,425 @@ +""" +S-expression parser for ArtDAG recipes and plans. 
@dataclass
class Symbol:
    """An unquoted symbol/identifier in a parsed S-expression tree."""
    # The raw identifier text, e.g. "recipe" or "->".
    name: str

    def __repr__(self):
        return f"Symbol({self.name!r})"

    def __eq__(self, other):
        # A Symbol equals another Symbol with the same name and, for
        # convenience, the bare string of its name; anything else is unequal.
        candidate = other.name if isinstance(other, Symbol) else other
        return isinstance(candidate, str) and candidate == self.name

    def __hash__(self):
        # Hash like the underlying string so hashing stays consistent with
        # __eq__'s string-equality behavior.
        return hash(self.name)
class ParseError(Exception):
    """Error during S-expression parsing.

    The message is augmented with a human-readable location; the raw
    offset/line/column are kept as attributes for programmatic use.
    """

    def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1):
        located = f"{message} at line {line}, column {col}"
        super().__init__(located)
        self.position = position
        self.line = line
        self.col = col
start_col = self.line, self.col + + # Single-character tokens (parens, brackets, braces) + if char in '()[]{}': + self._advance() + return char + + # String + if char == '"': + match = self.STRING.match(self.text, self.pos) + if not match: + raise ParseError("Unterminated string", self.pos, self.line, self.col) + self._advance(match.end() - self.pos) + # Parse escape sequences + content = match.group()[1:-1] + content = content.replace('\\n', '\n') + content = content.replace('\\t', '\t') + content = content.replace('\\"', '"') + content = content.replace('\\\\', '\\') + return content + + # Keyword + if char == ':': + match = self.KEYWORD.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Keyword(match.group()[1:]) # Strip leading colon + raise ParseError(f"Invalid keyword", self.pos, self.line, self.col) + + # Number (must check before symbol due to - prefix) + if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and + (self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')): + match = self.NUMBER.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + num_str = match.group() + if '.' in num_str or 'e' in num_str or 'E' in num_str: + return float(num_str) + return int(num_str) + + # Symbol + match = self.SYMBOL.match(self.text, self.pos) + if match: + self._advance(match.end() - self.pos) + return Symbol(match.group()) + + raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col) + + +def parse(text: str) -> Any: + """ + Parse an S-expression string into Python data structures. 
+ + Returns: + Parsed S-expression as nested Python structures: + - Lists become Python lists + - Symbols become Symbol objects + - Keywords become Keyword objects + - Strings become Python strings + - Numbers become int/float + + Example: + >>> parse('(recipe "test" :version "1.0")') + [Symbol('recipe'), 'test', Keyword('version'), '1.0'] + """ + tokenizer = Tokenizer(text) + result = _parse_expr(tokenizer) + + # Check for trailing content + if tokenizer.peek() is not None: + raise ParseError("Unexpected content after expression", + tokenizer.pos, tokenizer.line, tokenizer.col) + + return result + + +def parse_all(text: str) -> List[Any]: + """ + Parse multiple S-expressions from a string. + + Returns list of parsed expressions. + """ + tokenizer = Tokenizer(text) + results = [] + + while tokenizer.peek() is not None: + results.append(_parse_expr(tokenizer)) + + return results + + +def _parse_expr(tokenizer: Tokenizer) -> Any: + """Parse a single expression.""" + token = tokenizer.next_token() + + if token is None: + raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col) + + # List + if token == '(': + return _parse_list(tokenizer, ')') + + # Vector (sugar for list) + if token == '[': + return _parse_list(tokenizer, ']') + + # Map/dict: {:key1 val1 :key2 val2} + if token == '{': + return _parse_map(tokenizer) + + # Unexpected closers + if token in (')', ']', '}'): + raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col) + + # Atom + return token + + +def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]: + """Parse a list until the closing delimiter.""" + items = [] + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError(f"Unterminated list, expected {closer!r}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == closer: + tokenizer.next_token() # Consume closer + return items + + items.append(_parse_expr(tokenizer)) + + +def _parse_map(tokenizer: 
Tokenizer) -> Dict[str, Any]: + """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}.""" + result = {} + + while True: + char = tokenizer.peek() + + if char is None: + raise ParseError("Unterminated map, expected '}'", + tokenizer.pos, tokenizer.line, tokenizer.col) + + if char == '}': + tokenizer.next_token() # Consume closer + return result + + # Parse key (should be a keyword like :key) + key_token = _parse_expr(tokenizer) + if isinstance(key_token, Keyword): + key = key_token.name + elif isinstance(key_token, str): + key = key_token + else: + raise ParseError(f"Map key must be keyword or string, got {type(key_token).__name__}", + tokenizer.pos, tokenizer.line, tokenizer.col) + + # Parse value + value = _parse_expr(tokenizer) + result[key] = value + + +def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str: + """ + Serialize a Python data structure back to S-expression format. + + Args: + expr: The expression to serialize + indent: Current indentation level (for pretty printing) + pretty: Whether to use pretty printing with newlines + + Returns: + S-expression string + """ + if isinstance(expr, list): + if not expr: + return "()" + + if pretty: + return _serialize_pretty(expr, indent) + else: + items = [serialize(item, indent, False) for item in expr] + return "(" + " ".join(items) + ")" + + if isinstance(expr, Symbol): + return expr.name + + if isinstance(expr, Keyword): + return f":{expr.name}" + + if isinstance(expr, Lambda): + params = " ".join(expr.params) + body = serialize(expr.body, indent, pretty) + return f"(fn [{params}] {body})" + + if isinstance(expr, Binding): + # analysis_ref can be a string, node ID, or dict - serialize it properly + if isinstance(expr.analysis_ref, str): + ref_str = f'"{expr.analysis_ref}"' + else: + ref_str = serialize(expr.analysis_ref, indent, pretty) + s = f"(bind {ref_str} :range [{expr.range_min} {expr.range_max}]" + if expr.transform: + s += f" :transform {expr.transform}" + return 
s + ")" + + if isinstance(expr, str): + # Escape special characters + escaped = expr.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n').replace('\t', '\\t') + return f'"{escaped}"' + + if isinstance(expr, bool): + return "true" if expr else "false" + + if isinstance(expr, (int, float)): + return str(expr) + + if expr is None: + return "nil" + + if isinstance(expr, dict): + # Serialize dict as property list: {:key1 val1 :key2 val2} + items = [] + for k, v in expr.items(): + items.append(f":{k}") + items.append(serialize(v, indent, pretty)) + return "{" + " ".join(items) + "}" + + raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}") + + +def _serialize_pretty(expr: List, indent: int) -> str: + """Pretty-print a list expression with smart formatting.""" + if not expr: + return "()" + + prefix = " " * indent + inner_prefix = " " * (indent + 1) + + # Check if this is a simple list that fits on one line + simple = serialize(expr, indent, False) + if len(simple) < 60 and '\n' not in simple: + return simple + + # Start building multiline output + head = serialize(expr[0], indent + 1, False) + parts = [f"({head}"] + + i = 1 + while i < len(expr): + item = expr[i] + + # Group keyword-value pairs on same line + if isinstance(item, Keyword) and i + 1 < len(expr): + key = serialize(item, 0, False) + val = serialize(expr[i + 1], indent + 1, False) + + # If value is short, put on same line + if len(val) < 50 and '\n' not in val: + parts.append(f"{inner_prefix}{key} {val}") + else: + # Value is complex, serialize it pretty + val_pretty = serialize(expr[i + 1], indent + 1, True) + parts.append(f"{inner_prefix}{key} {val_pretty}") + i += 2 + else: + # Regular item + item_str = serialize(item, indent + 1, True) + parts.append(f"{inner_prefix}{item_str}") + i += 1 + + return "\n".join(parts) + ")" diff --git a/artdag/sexp/planner.py b/artdag/sexp/planner.py new file mode 100644 index 0000000..ecd6595 --- /dev/null +++ b/artdag/sexp/planner.py @@ -0,0 
"""
Execution plan generation from S-expression recipes.

The planner:
1. Takes a compiled recipe + input content hashes
2. Runs analyzers to get concrete data (beat times, etc.)
3. Expands dynamic nodes (SLICE_ON) into primitive operations
4. Resolves all registry references to content hashes
5. Generates an execution plan with pre-computed cache IDs

Plans are S-expressions with all references resolved to hashes,
ready for distribution to Celery workers.
"""

import hashlib
import importlib.util
import json
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Callable

from .parser import Symbol, Keyword, Binding, serialize
from .compiler import CompiledRecipe


# Node types that can be collapsed into a single FFmpeg filter chain
COLLAPSIBLE_TYPES = {"EFFECT", "SEGMENT"}

# Node types that are boundaries (sources, merges, or special processing)
BOUNDARY_TYPES = {"SOURCE", "SEQUENCE", "MUX", "ANALYZE", "SCAN", "LIST"}

# Node types that need expansion during planning
EXPANDABLE_TYPES = {"SLICE_ON", "CONSTRUCT"}


def _load_module(module_path: Path, module_name: str = "module"):
    """Load a Python module from file path.

    NOTE(review): spec_from_file_location may return None for a bad path,
    which would surface as an AttributeError below — confirm callers
    always pass an existing file.
    """
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _run_analyzer(
    analyzer_path: Path,
    input_path: Path,
    params: Dict[str, Any],
) -> Dict[str, Any]:
    """Run an analyzer module's analyze() entry point and return its results."""
    analyzer = _load_module(analyzer_path, "analyzer")
    return analyzer.analyze(input_path, params)


def _pre_execute_segment(
    node: Dict,
    input_path: Path,
    work_dir: Path,
) -> Path:
    """
    Pre-execute a SEGMENT node during planning.

    This is needed when ANALYZE depends on a SEGMENT output.
    Returns path to the segmented file.

    Args:
        node: SEGMENT node dict; config may contain start/duration/end.
        input_path: Media file to cut.
        work_dir: Directory for the output file.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails (check=True).
    """
    import subprocess

    config = node.get("config", {})
    start = config.get("start", 0)
    duration = config.get("duration")
    end = config.get("end")

    # Detect if input is audio-only by extension
    suffix = input_path.suffix.lower()
    is_audio = suffix in ('.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a')

    if is_audio:
        output_ext = ".m4a"  # Use m4a for aac codec
    else:
        output_ext = ".mp4"

    output_path = work_dir / f"segment_{node['id'][:16]}{output_ext}"

    cmd = ["ffmpeg", "-y", "-i", str(input_path)]
    if start:
        cmd.extend(["-ss", str(start)])
    # duration takes precedence over end; end is converted to a duration
    if duration:
        cmd.extend(["-t", str(duration)])
    elif end:
        cmd.extend(["-t", str(end - start)])

    if is_audio:
        cmd.extend(["-c:a", "aac", str(output_path)])
    else:
        cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18",
                    "-c:a", "aac", str(output_path)])

    subprocess.run(cmd, check=True, capture_output=True)
    return output_path


def _serialize_for_hash(obj) -> str:
    """Serialize any value to a canonical S-expression string for hashing.

    Dict keys are sorted so logically-equal dicts hash identically.
    Unknown types fall back to str(obj).
    """
    from .parser import Lambda

    if obj is None:
        return "nil"
    # bool before int/float: bool is a subclass of int.
    if isinstance(obj, bool):
        return "true" if obj else "false"
    if isinstance(obj, (int, float)):
        return str(obj)
    if isinstance(obj, str):
        escaped = obj.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    if isinstance(obj, Symbol):
        return obj.name
    if isinstance(obj, Keyword):
        return f":{obj.name}"
    if isinstance(obj, Lambda):
        params = " ".join(obj.params)
        body = _serialize_for_hash(obj.body)
        return f"(fn [{params}] {body})"
    if isinstance(obj, Binding):
        # analysis_ref can be a string, node ID, or dict - serialize it properly
        if isinstance(obj.analysis_ref, str):
            ref_str = f'"{obj.analysis_ref}"'
        else:
            ref_str = _serialize_for_hash(obj.analysis_ref)
        return f"(bind {ref_str} :range [{obj.range_min} {obj.range_max}])"
    if isinstance(obj, dict):
        items = []
        for k, v in sorted(obj.items()):
            items.append(f":{k} {_serialize_for_hash(v)}")
        return "{" + " ".join(items) + "}"
    if isinstance(obj, list):
        items = [_serialize_for_hash(x) for x in obj]
        return "(" + " ".join(items) + ")"
    return str(obj)
def _stable_hash(data: Any, cluster_key: Optional[str] = None) -> str:
    """Create stable SHA3-256 hash from data using S-expression serialization.

    When cluster_key is given, it is folded into the hashed payload so the
    same data hashes differently per cluster.
    """
    if cluster_key:
        data = {"_cluster_key": cluster_key, "_data": data}
    sexp_str = _serialize_for_hash(data)
    return hashlib.sha3_256(sexp_str.encode()).hexdigest()


@dataclass
class PlanStep:
    """A step in the execution plan."""
    step_id: str
    node_type: str
    config: Dict[str, Any]
    inputs: List[str]  # List of input step_ids
    cache_id: str
    level: int = 0
    stage: Optional[str] = None  # Stage this step belongs to

    def to_sexp(self) -> List:
        """Convert to S-expression.

        Output shape: (step <id> :cache-id <cid> [...] (<node-type> ...)).
        Field order is fixed — downstream hashing depends on it.
        """
        sexp = [Symbol("step"), self.step_id]

        # Add cache-id
        sexp.extend([Keyword("cache-id"), self.cache_id])

        # Add level if > 0
        if self.level > 0:
            sexp.extend([Keyword("level"), self.level])

        # Add stage info if present
        if self.stage:
            sexp.extend([Keyword("stage"), self.stage])

        # Add the node expression
        node_sexp = [Symbol(self.node_type.lower())]

        # Add config as keywords
        for key, value in self.config.items():
            # Convert Binding to sexp form
            if isinstance(value, Binding):
                value = [Symbol("bind"), value.analysis_ref,
                         Keyword("range"), [value.range_min, value.range_max]]
            node_sexp.extend([Keyword(key), value])

        # Add inputs if any
        if self.inputs:
            node_sexp.extend([Keyword("inputs"), self.inputs])

        sexp.append(node_sexp)
        return sexp


@dataclass
class StagePlan:
    """A stage in the execution plan."""
    stage_name: str
    steps: List[PlanStep]
    requires: List[str]  # Names of required stages
    output_bindings: Dict[str, str]  # binding_name -> cache_id of output
    level: int = 0  # Stage level for parallel execution


@dataclass
class ExecutionPlanSexp:
    """Execution plan as S-expression."""
    plan_id: str
    steps: List[PlanStep]
    output_step_id: str
    source_hash: str = ""  # CID of recipe source
    params: Dict[str, Any] = field(default_factory=dict)  # Resolved parameter values
    params_hash: str = ""  # Hash of params for quick comparison
    inputs: Dict[str, str] = field(default_factory=dict)  # name -> hash
    analysis: Dict[str, Dict] = field(default_factory=dict)  # name -> {times, values}
    metadata: Dict[str, Any] = field(default_factory=dict)
    stage_plans: List[StagePlan] = field(default_factory=list)  # Stage-level plans
    stage_order: List[str] = field(default_factory=list)  # Topologically sorted stage names
    stage_levels: Dict[str, int] = field(default_factory=dict)  # stage_name -> level
    effects_registry: Dict[str, Dict] = field(default_factory=dict)  # effect_name -> {path, cid, ...}
    minimal_primitives: bool = False  # If True, interpreter uses only core primitives

    def to_sexp(self) -> List:
        """Convert entire plan to S-expression.

        Section order (id, params, inputs, analysis, stages, effects-registry,
        flags, steps, output) is fixed; empty sections are omitted.
        """
        sexp = [Symbol("plan")]

        # Metadata - purely content-addressed
        sexp.extend([Keyword("id"), self.plan_id])
        sexp.extend([Keyword("source-cid"), self.source_hash])  # CID of recipe source

        # Parameters
        if self.params:
            sexp.extend([Keyword("params-hash"), self.params_hash])
            params_sexp = [Symbol("params")]
            for name, value in self.params.items():
                params_sexp.append([Symbol(name), value])
            sexp.append(params_sexp)

        # Input bindings
        if self.inputs:
            inputs_sexp = [Symbol("inputs")]
            for name, hash_val in self.inputs.items():
                inputs_sexp.append([Symbol(name), hash_val])
            sexp.append(inputs_sexp)

        # Analysis data (for effect parameter bindings)
        if self.analysis:
            analysis_sexp = [Symbol("analysis")]
            for name, data in self.analysis.items():
                track_sexp = [Symbol(name)]
                # Cached tracks are referenced by cache-id instead of inlining data
                if isinstance(data, dict) and "_cache_id" in data:
                    track_sexp.extend([Keyword("cache-id"), data["_cache_id"]])
                else:
                    if "times" in data:
                        track_sexp.extend([Keyword("times"), data["times"]])
                    if "values" in data:
                        track_sexp.extend([Keyword("values"), data["values"]])
                analysis_sexp.append(track_sexp)
            sexp.append(analysis_sexp)

        # Stage information
        if self.stage_plans:
            stages_sexp = [Symbol("stages")]
            for stage_plan in self.stage_plans:
                stage_sexp = [
                    Keyword("name"), stage_plan.stage_name,
                    Keyword("level"), stage_plan.level,
                ]
                if stage_plan.requires:
                    stage_sexp.extend([Keyword("requires"), stage_plan.requires])
                if stage_plan.output_bindings:
                    outputs_sexp = []
                    for name, cache_id in stage_plan.output_bindings.items():
                        outputs_sexp.append([Symbol(name), Keyword("cache-id"), cache_id])
                    stage_sexp.extend([Keyword("outputs"), outputs_sexp])
                stages_sexp.append(stage_sexp)
            sexp.append(stages_sexp)

        # Effects registry - for loading explicitly declared effects
        if self.effects_registry:
            registry_sexp = [Symbol("effects-registry")]
            for name, info in self.effects_registry.items():
                effect_sexp = [Symbol(name)]
                if info.get("path"):
                    effect_sexp.extend([Keyword("path"), info["path"]])
                if info.get("cid"):
                    effect_sexp.extend([Keyword("cid"), info["cid"]])
                registry_sexp.append(effect_sexp)
            sexp.append(registry_sexp)

        # Minimal primitives flag
        if self.minimal_primitives:
            sexp.extend([Keyword("minimal-primitives"), True])

        # Steps
        for step in self.steps:
            sexp.append(step.to_sexp())

        # Output reference
        sexp.extend([Keyword("output"), self.output_step_id])

        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize plan to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)
+ """ + nodes_by_id = {n["id"]: n for n in nodes} + list_nodes = {n["id"]: n for n in nodes if n["type"] == "LIST"} + + if not list_nodes: + return nodes + + # Determine which LIST nodes are referenced by SEQUENCE vs other node types + list_consumed_by_seq = set() + list_referenced_by_other = set() + for node in nodes: + if node["type"] == "LIST": + continue + for inp in node.get("inputs", []): + if inp in list_nodes: + if node["type"] == "SEQUENCE": + list_consumed_by_seq.add(inp) + else: + list_referenced_by_other.add(inp) + + result = [] + for node in nodes: + if node["type"] == "LIST": + if node["id"] in list_referenced_by_other: + # Promote to SEQUENCE — non-SEQUENCE nodes reference this LIST + result.append({ + "id": node["id"], + "type": "SEQUENCE", + "config": node.get("config", {}), + "inputs": node.get("inputs", []), + }) + # Otherwise skip (consumed by SEQUENCE expansion or unreferenced) + continue + + if node["type"] == "SEQUENCE": + # Expand any LIST inputs + new_inputs = [] + for inp in node.get("inputs", []): + if inp in list_nodes: + # Replace LIST with its contents + new_inputs.extend(list_nodes[inp].get("inputs", [])) + else: + new_inputs.append(inp) + + # Create updated node + result.append({ + **node, + "inputs": new_inputs, + }) + else: + result.append(node) + + return result + + +def _collapse_effect_chains(nodes: List[Dict], registry: Dict = None) -> List[Dict]: + """ + Collapse sequential effect chains into single COMPOUND nodes. + + A chain is a sequence of single-input collapsible nodes where: + - Each node has exactly one input + - No node in the chain is referenced by multiple other nodes + - The chain ends at a boundary or multi-ref node + - No node in the chain is marked as temporal + + Effects can declare :temporal true to prevent collapsing (e.g., reverse). + + Returns a new node list with chains collapsed. 
def _collapse_effect_chains(nodes: List[Dict], registry: Dict = None) -> List[Dict]:
    """
    Collapse sequential effect chains into single COMPOUND nodes.

    A chain is a sequence of single-input collapsible nodes where:
    - Each node has exactly one input
    - No node in the chain is referenced by multiple other nodes
    - The chain ends at a boundary or multi-ref node
    - No node in the chain is marked as temporal

    Effects can declare :temporal true to prevent collapsing (e.g., reverse).

    Returns a new node list with chains collapsed.  Note: the result is in
    reverse topological order (outputs first), not the input order.
    """
    if not nodes:
        return nodes

    registry = registry or {}
    nodes_by_id = {n["id"]: n for n in nodes}

    # Build reference counts: how many nodes reference each node as input
    ref_count = {n["id"]: 0 for n in nodes}
    for node in nodes:
        for inp in node.get("inputs", []):
            if inp in ref_count:
                ref_count[inp] += 1

    # Track which nodes are consumed by chains
    consumed = set()
    # NOTE(review): compound_nodes is never used below — dead variable.
    compound_nodes = []

    def is_temporal(node: Dict) -> bool:
        """Check if a node is temporal (needs complete input)."""
        config = node.get("config", {})
        # Check node-level temporal flag
        if config.get("temporal"):
            return True
        # Check effect registry for temporal flag
        if node["type"] == "EFFECT":
            effect_name = config.get("effect")
            if effect_name:
                effect_meta = registry.get("effects", {}).get(effect_name, {})
                if effect_meta.get("temporal"):
                    return True
        return False

    def is_collapsible(node_id: str) -> bool:
        """Check if a node can be part of a chain."""
        if node_id in consumed:
            return False
        node = nodes_by_id.get(node_id)
        if not node:
            return False
        if node["type"] not in COLLAPSIBLE_TYPES:
            return False
        # Temporal effects can't be collapsed
        if is_temporal(node):
            return False
        # Effects CAN be collapsed if they have an FFmpeg mapping
        # Only fall back to Python interpreter if no mapping exists
        config = node.get("config", {})
        if node["type"] == "EFFECT":
            effect_name = config.get("effect")
            # Import here to avoid circular imports
            from .ffmpeg_compiler import FFmpegCompiler
            compiler = FFmpegCompiler()
            if compiler.get_mapping(effect_name):
                return True  # Has FFmpeg mapping, can collapse
            elif config.get("effect_path"):
                return False  # No FFmpeg mapping, has Python path, can't collapse
        return True

    def is_chain_boundary(node_id: str) -> bool:
        """Check if a node is a chain boundary (can't be collapsed into)."""
        node = nodes_by_id.get(node_id)
        if not node:
            return True  # Unknown node is a boundary
        # Boundary if: it's a boundary type, or referenced by multiple nodes
        return node["type"] in BOUNDARY_TYPES or ref_count.get(node_id, 0) > 1

    def collect_chain(start_id: str) -> List[str]:
        """Collect a chain of collapsible nodes starting from start_id.

        Walks backwards through single-input nodes; the returned list is
        in [end, ..., start] order (output end first).
        """
        chain = [start_id]
        current = start_id

        while True:
            node = nodes_by_id[current]
            inputs = node.get("inputs", [])

            # Must have exactly one input
            if len(inputs) != 1:
                break

            next_id = inputs[0]

            # Stop if next is a boundary or already consumed
            if is_chain_boundary(next_id) or not is_collapsible(next_id):
                break

            # Stop if next is referenced by others besides current
            if ref_count.get(next_id, 0) > 1:
                break

            chain.append(next_id)
            current = next_id

        return chain

    # Process nodes in reverse order (from outputs toward inputs)
    # This ensures we find complete chains starting from their end
    # First, topologically sort to get dependency order
    sorted_ids = []
    visited = set()

    def topo_visit(node_id: str):
        # Depth-first post-order: dependencies appear before dependents.
        if node_id in visited:
            return
        visited.add(node_id)
        node = nodes_by_id.get(node_id)
        if node:
            for inp in node.get("inputs", []):
                topo_visit(inp)
        sorted_ids.append(node_id)

    for node in nodes:
        topo_visit(node["id"])

    # Process in reverse topological order (outputs first)
    result_nodes = []

    for node_id in reversed(sorted_ids):
        node = nodes_by_id[node_id]

        if node_id in consumed:
            continue

        if not is_collapsible(node_id):
            # Keep boundary nodes as-is
            result_nodes.append(node)
            continue

        # Check if this node is the start of a chain (output end)
        # A node is a chain start if it's collapsible and either:
        # - Referenced by a boundary node
        # - Referenced by multiple nodes
        # - Is the output node
        # For now, collect chain going backwards from this node
        chain = collect_chain(node_id)

        if len(chain) == 1:
            # Single node, no collapse needed
            result_nodes.append(node)
            continue

        # Collapse the chain into a COMPOUND node
        # Chain is [end, ..., start] order (backwards from output)
        # The compound node:
        # - Has the same ID as the chain end (for reference stability)
        # - Takes input from what the chain start originally took
        # - Has a filter_chain config with all the nodes in order
        chain_start = chain[-1]  # First to execute
        chain_end = chain[0]  # Last to execute

        start_node = nodes_by_id[chain_start]
        # NOTE(review): end_node is unused — chain_end's ID is used directly.
        end_node = nodes_by_id[chain_end]

        # Build filter chain config (in execution order: start to end)
        filter_chain = []
        for chain_node_id in reversed(chain):
            chain_node = nodes_by_id[chain_node_id]
            filter_chain.append({
                "type": chain_node["type"],
                "config": chain_node.get("config", {}),
            })

        compound_node = {
            "id": chain_end,  # Keep the end ID for reference stability
            "type": "COMPOUND",
            "config": {
                "filter_chain": filter_chain,
                # Include effects registry so executor can load only declared effects
                "effects_registry": registry.get("effects", {}),
            },
            "inputs": start_node.get("inputs", []),
            "name": f"compound_{len(filter_chain)}_effects",
        }

        result_nodes.append(compound_node)

        # Mark all chain nodes as consumed
        for chain_node_id in chain:
            consumed.add(chain_node_id)

    return result_nodes
i) [invert] []) + :acc (inc acc)})) + + When all beats produce composition-mode results (layers + compositor) + with the same layer structure, consecutive beats are automatically merged + into fewer compositions with time-varying parameter bindings. This can + reduce thousands of nodes to a handful. + + Args: + node: The SLICE_ON node to expand + analysis_data: Analysis results containing times array + registry: Recipe registry with effect definitions + sources: Map of source names to node IDs + cluster_key: Optional cluster key for hashing + named_analysis: Mutable dict to inject synthetic analysis tracks into + + Returns: + List of expanded nodes (segments, effects, sequence) + """ + from .evaluator import evaluate, EvalError + from .parser import Lambda, Symbol + + config = node.get("config", {}) + node_inputs = node.get("inputs", []) + sources = sources or {} + + # Extract times + times_path = config.get("times_path", "times") + times = analysis_data + for key in times_path.split("."): + times = times[key] + + if not times: + raise ValueError(f"No times found at path '{times_path}' in analysis") + + # Default video input (first input after analysis) + default_video = node_inputs[0] if node_inputs else None + + expanded_nodes = [] + sequence_inputs = [] + base_id = node["id"][:8] + + # Check for lambda-based reducer + reducer_fn = config.get("fn") + + if isinstance(reducer_fn, Lambda): + # Lambda mode - evaluate function for each slice + acc = config.get("init", 0) + slice_times = list(zip([0] + times[:-1], times)) + + # Frame-accurate timing calculation + # Align ALL times to frame boundaries to prevent accumulating drift + fps = (encoding or {}).get("fps", 30) + frame_duration = 1.0 / fps + + # Get total duration from analysis data (beats analyzer includes this) + # Falls back to config target_duration for backwards compatibility + total_duration = analysis_data.get("duration") or config.get("target_duration") + + # Pre-compute frame-aligned cumulative times + 
cumulative_frames = [0] # Start at frame 0 + for t in times: + # Round to nearest frame boundary + frames = round(t * fps) + cumulative_frames.append(frames) + + # If total duration known, ensure last segment extends to it exactly + if total_duration is not None: + target_frames = round(total_duration * fps) + if target_frames > cumulative_frames[-1]: + cumulative_frames[-1] = target_frames + + # Pre-compute frame-aligned start times and durations for each slice + frame_aligned_starts = [] + frame_aligned_durations = [] + for i in range(len(cumulative_frames) - 1): + start_frames = cumulative_frames[i] + end_frames = cumulative_frames[i + 1] + frame_aligned_starts.append(start_frames * frame_duration) + frame_aligned_durations.append((end_frames - start_frames) * frame_duration) + + # Phase 1: Evaluate all lambdas upfront + videos = config.get("videos", []) + all_results = [] + all_timings = [] # (seg_start, seg_duration) per valid beat + original_indices = [] # original beat index for each result + + for i, (start, end) in enumerate(slice_times): + if start >= end: + continue + + # Build environment with sources, effects, and builtins + env = dict(sources) + + # Add effect names so they can be referenced as symbols + for effect_name in registry.get("effects", {}): + env[effect_name] = effect_name + + # Make :videos list available to lambda + if videos: + env["videos"] = videos + + env["acc"] = acc + env["i"] = i + env["start"] = start + env["end"] = end + + # Evaluate the reducer + result = evaluate([reducer_fn, Symbol("acc"), Symbol("i"), + Symbol("start"), Symbol("end")], env) + + if not isinstance(result, dict): + raise ValueError(f"Reducer must return a dict, got {type(result)}") + + # Extract accumulator + acc = result.get("acc", acc) + + # Segment timing: use frame-aligned values to prevent drift + # Lambda can override with explicit start/duration/end + if result.get("start") is not None or result.get("duration") is not None or result.get("end") is not 
None: + # Explicit timing from lambda - use as-is + seg_start = result.get("start", start) + seg_duration = result.get("duration") + if seg_duration is None: + if result.get("end") is not None: + seg_duration = result["end"] - seg_start + else: + seg_duration = end - start + else: + # Default: use frame-aligned start and duration to prevent accumulated drift + seg_start = frame_aligned_starts[i] if i < len(frame_aligned_starts) else start + seg_duration = frame_aligned_durations[i] if i < len(frame_aligned_durations) else (end - start) + + all_results.append(result) + all_timings.append((seg_start, seg_duration)) + original_indices.append(i) + + # Phase 2: Merge or expand + all_composition = ( + len(all_results) > 1 + and all("layers" in r for r in all_results) + and named_analysis is not None + ) + + if all_composition: + # All beats are composition mode — try to merge consecutive + # beats with the same layer structure + _merge_composition_beats( + all_results, all_timings, base_id, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, + ) + else: + # Fallback: expand each beat individually + for idx, result in enumerate(all_results): + orig_i = original_indices[idx] + seg_start, seg_duration = all_timings[idx] + + if "layers" in result: + # COMPOSITION MODE — multi-source with per-layer effects + compositor + _expand_composition_beat( + result, orig_i, base_id, videos, registry, + seg_start, seg_duration, expanded_nodes, sequence_inputs, + ) + else: + # SINGLE-SOURCE MODE (existing behavior) + source_name = result.get("source") + effects = result.get("effects", []) + + # Resolve source to node ID + if isinstance(source_name, Symbol): + source_name = source_name.name + valid_node_ids = set(sources.values()) + if source_name in sources: + video_input = sources[source_name] + elif source_name in valid_node_ids: + video_input = source_name + else: + video_input = default_video + + # Create SEGMENT node + segment_id = f"{base_id}_seg_{orig_i:04d}" + 
segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": seg_start, + "duration": seg_duration, + }, + "inputs": [video_input], + } + expanded_nodes.append(segment_node) + + # Apply effects chain + current_input = segment_id + for j, effect in enumerate(effects): + effect_name, effect_params = _parse_effect_spec(effect) + if not effect_name: + continue + + effect_id = f"{base_id}_fx_{orig_i:04d}_{j}" + effect_entry = registry.get("effects", {}).get(effect_name, {}) + + effect_config = { + "effect": effect_name, + "effect_path": effect_entry.get("path"), + } + effect_config.update(effect_params) + + effect_node = { + "id": effect_id, + "type": "EFFECT", + "config": effect_config, + "inputs": [current_input], + } + expanded_nodes.append(effect_node) + current_input = effect_id + + sequence_inputs.append(current_input) + + else: + # Legacy mode - :effect and :pattern + effect_name = config.get("effect") + effect_path = config.get("effect_path") + pattern = config.get("pattern", "all") + video_input = default_video + + if not video_input: + raise ValueError("SLICE_ON requires video input") + + slice_times = list(zip([0] + times[:-1], times)) + + for i, (start, end) in enumerate(slice_times): + if start >= end: + continue + + # Determine if effect should be applied + apply_effect = False + if effect_name: + if pattern == "all": + apply_effect = True + elif pattern == "odd": + apply_effect = (i % 2 == 1) + elif pattern == "even": + apply_effect = (i % 2 == 0) + elif pattern == "alternate": + apply_effect = (i % 2 == 1) + + # Create SEGMENT node + segment_id = f"{base_id}_seg_{i:04d}" + segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": start, + "duration": end - start, + }, + "inputs": [video_input], + } + expanded_nodes.append(segment_node) + + if apply_effect: + effect_id = f"{base_id}_fx_{i:04d}" + effect_config = {"effect": effect_name} + if effect_path: + effect_config["effect_path"] = effect_path + + 
def _parse_effect_spec(effect):
    """Extract ``(effect_name, effect_params)`` from an effect spec.

    Accepts a parser Symbol, a plain string, or a dict of the form
    ``{"effect": name, **params}``.  Any other shape yields ``(None, {})``,
    which callers treat as "skip this entry".
    """
    from .parser import Symbol

    name = None
    params = {}

    if isinstance(effect, Symbol):
        name = effect.name
    elif isinstance(effect, str):
        name = effect
    elif isinstance(effect, dict):
        name = effect.get("effect")
        if isinstance(name, Symbol):
            name = name.name
        # Everything except the effect name itself is a parameter.
        params = {key: val for key, val in effect.items() if key != "effect"}

    return name, params
def _expand_composition_beat(result, beat_idx, base_id, videos, registry,
                             seg_start, seg_duration, expanded_nodes, sequence_inputs):
    """
    Expand one composition-mode beat into plan nodes.

    Emits one SEGMENT per layer, each layer's EFFECT chain, and a final
    multi-input compositor EFFECT whose id is appended to sequence_inputs.

    Args:
        result: Lambda result dict with 'layers' and optional 'compose'
        beat_idx: Beat index for ID generation
        base_id: Base ID prefix
        videos: List of video node IDs from :videos config
        registry: Recipe registry with effect definitions
        seg_start: Segment start time
        seg_duration: Segment duration
        expanded_nodes: List to append generated nodes to
        sequence_inputs: List to append final composition node ID to
    """
    compose_spec = result.get("compose", {})
    final_layer_ids = []

    for li, layer in enumerate(result["layers"]):
        # Numeric refs index into the :videos list; anything else is a node id.
        ref = layer.get("video")
        layer_input = videos[int(ref)] if isinstance(ref, (int, float)) else str(ref)

        seg_id = f"{base_id}_seg_{beat_idx:04d}_L{li}"
        expanded_nodes.append({
            "id": seg_id,
            "type": "SEGMENT",
            "config": {"start": seg_start, "duration": seg_duration},
            "inputs": [seg_id][:0] + [layer_input],
        })

        # Per-layer effect chain; each effect consumes the previous output.
        chain_tail = seg_id
        for fi, spec in enumerate(layer.get("effects", [])):
            fx_name, fx_params = _parse_effect_spec(spec)
            if not fx_name:
                continue
            entry = registry.get("effects", {}).get(fx_name, {})
            cfg = {"effect": fx_name, "effect_path": entry.get("path")}
            cfg.update(fx_params)
            fx_id = f"{base_id}_fx_{beat_idx:04d}_L{li}_{fi}"
            expanded_nodes.append({
                "id": fx_id,
                "type": "EFFECT",
                "config": cfg,
                "inputs": [chain_tail],
            })
            chain_tail = fx_id
        final_layer_ids.append(chain_tail)

    # Compositor node taking every layer tail as input.
    comp_name = compose_spec.get("effect", "blend_multi")
    comp_entry = registry.get("effects", {}).get(comp_name, {})
    comp_cfg = {
        "effect": comp_name,
        "effect_path": comp_entry.get("path"),
        "multi_input": True,
    }
    comp_cfg.update({k: v for k, v in compose_spec.items() if k != "effect"})

    comp_id = f"{base_id}_comp_{beat_idx:04d}"
    expanded_nodes.append({
        "id": comp_id,
        "type": "EFFECT",
        "config": comp_cfg,
        "inputs": final_layer_ids,
    })
    sequence_inputs.append(comp_id)
def _fingerprint_composition(result):
    """Create a hashable fingerprint of a composition beat's layer structure.

    Beats with the same fingerprint have the same video refs, effect names,
    and compositor type — only parameter values differ. Such beats can be
    merged into a single composition with time-varying bindings.
    """
    layers = result.get("layers", [])
    compose = result.get("compose", {})

    per_layer = [
        (
            layer.get("video"),
            tuple(_parse_effect_spec(e)[0] for e in layer.get("effects", [])),
        )
        for layer in layers
    ]

    compositor = compose.get("effect", "blend_multi")
    # Include static compose params (excluding list-valued params like weights)
    static_pairs = sorted(
        (key, val) for key, val in compose.items()
        if key not in ("effect", "weights") and isinstance(val, (str, int, float, bool))
    )

    return (len(layers), tuple(per_layer), compositor, tuple(static_pairs))
+ """ + import sys + + # Compute fingerprints + fingerprints = [_fingerprint_composition(r) for r in all_results] + + # Group consecutive beats with the same fingerprint + groups = [] # list of (start_idx, end_idx_exclusive) + group_start = 0 + for i in range(1, len(fingerprints)): + if fingerprints[i] != fingerprints[group_start]: + groups.append((group_start, i)) + group_start = i + groups.append((group_start, len(fingerprints))) + + print(f" Composition merging: {len(all_results)} beats -> {len(groups)} groups", file=sys.stderr) + + for group_idx, (g_start, g_end) in enumerate(groups): + group_size = g_end - g_start + + if group_size == 1: + # Single beat — use standard expansion + result = all_results[g_start] + seg_start, seg_duration = all_timings[g_start] + _expand_composition_beat( + result, g_start, base_id, videos, registry, + seg_start, seg_duration, expanded_nodes, sequence_inputs, + ) + else: + # Merge group into one composition with time-varying bindings + _merge_composition_group( + all_results, all_timings, + list(range(g_start, g_end)), + base_id, group_idx, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, + ) + + +def _merge_composition_group( + all_results, all_timings, group_indices, + base_id, group_idx, videos, registry, + expanded_nodes, sequence_inputs, named_analysis, +): + """Merge a group of same-structure composition beats into one composition. 
def _merge_composition_group(
    all_results, all_timings, group_indices,
    base_id, group_idx, videos, registry,
    expanded_nodes, sequence_inputs, named_analysis,
):
    """Merge a group of same-structure composition beats into one composition.

    Creates:
      - One SEGMENT per layer (spanning full group duration)
      - One EFFECT per layer with time-varying params via synthetic analysis tracks
      - One compositor EFFECT with time-varying weights via synthetic tracks

    Bug fix: synthetic track names now include the group index
    (``G{group_idx:03d}``, mirroring the generated node ids).  Previously two
    merged groups of the same SLICE_ON that varied the same parameter wrote
    the same key into ``named_analysis``, so the later group's track silently
    overwrote the earlier group's data and the first composition read the
    wrong times/values.
    """
    import sys

    first = all_results[group_indices[0]]
    layers = first["layers"]
    compose_spec = first.get("compose", {})
    num_layers = len(layers)

    # Group timing: from the first beat's start to the last beat's end.
    first_start = all_timings[group_indices[0]][0]
    last_start, last_dur = all_timings[group_indices[-1]]
    group_duration = (last_start + last_dur) - first_start

    # Beat start times for synthetic tracks (absolute times)
    beat_times = [float(all_timings[i][0]) for i in group_indices]

    # Unique prefix for this group's synthetic tracks (prevents cross-group
    # collisions in named_analysis; 'syn_' keeps the symbol parser happy
    # since base_id may start with digits).
    track_prefix = f"syn_{base_id}_G{group_idx:03d}"

    print(f" Group {group_idx}: {len(group_indices)} beats, "
          f"{first_start:.1f}s -> {first_start + group_duration:.1f}s "
          f"({num_layers} layers)", file=sys.stderr)

    def _passthrough_binding(track_name):
        """Build a pass-through binding dict referencing a synthetic track."""
        return {
            "_binding": True,
            "source": track_name,
            "feature": "values",
            "range": [0.0, 1.0],  # pass-through
        }

    # --- Per-layer segments and effects ---
    layer_outputs = []
    for layer_idx in range(num_layers):
        layer = layers[layer_idx]

        # Resolve video input: integer index into videos list, or node ID
        video_ref = layer.get("video")
        if isinstance(video_ref, (int, float)):
            video_input = videos[int(video_ref)]
        else:
            video_input = str(video_ref)

        # SEGMENT for this layer (full group duration)
        segment_id = f"{base_id}_seg_G{group_idx:03d}_L{layer_idx}"
        expanded_nodes.append({
            "id": segment_id,
            "type": "SEGMENT",
            "config": {"start": first_start, "duration": group_duration},
            "inputs": [video_input],
        })

        # Per-layer EFFECT chain
        current = segment_id
        effects = layer.get("effects", [])
        for fx_idx, effect in enumerate(effects):
            effect_name, first_params = _parse_effect_spec(effect)
            if not effect_name:
                continue

            effect_id = f"{base_id}_fx_G{group_idx:03d}_L{layer_idx}_{fx_idx}"
            effect_entry = registry.get("effects", {}).get(effect_name, {})
            fx_config = {
                "effect": effect_name,
                "effect_path": effect_entry.get("path"),
            }

            # For each param, check if it varies across beats
            for param_name, first_val in first_params.items():
                values = []
                for bi in group_indices:
                    beat_layer = all_results[bi]["layers"][layer_idx]
                    beat_effects = beat_layer.get("effects", [])
                    if fx_idx < len(beat_effects):
                        _, beat_params = _parse_effect_spec(beat_effects[fx_idx])
                        values.append(float(beat_params.get(param_name, first_val)))
                    else:
                        values.append(float(first_val))

                if all(v == values[0] for v in values):
                    # Constant across the group — inline the scalar.
                    fx_config[param_name] = values[0]
                else:
                    # Varying — create a synthetic analysis track and bind it.
                    track_name = f"{track_prefix}_L{layer_idx}_fx{fx_idx}_{param_name}"
                    named_analysis[track_name] = {
                        "times": beat_times,
                        "values": values,
                    }
                    fx_config[param_name] = _passthrough_binding(track_name)

            expanded_nodes.append({
                "id": effect_id,
                "type": "EFFECT",
                "config": fx_config,
                "inputs": [current],
            })
            current = effect_id

        layer_outputs.append(current)

    # --- Compositor ---
    compose_name = compose_spec.get("effect", "blend_multi")
    compose_id = f"{base_id}_comp_G{group_idx:03d}"
    compose_entry = registry.get("effects", {}).get(compose_name, {})
    compose_config = {
        "effect": compose_name,
        "effect_path": compose_entry.get("path"),
        "multi_input": True,
    }

    for k, v in compose_spec.items():
        if k == "effect":
            continue

        if isinstance(v, list):
            # List param (e.g., weights) — check each element independently
            merged_list = []
            for elem_idx in range(len(v)):
                elem_values = []
                for bi in group_indices:
                    beat_compose = all_results[bi].get("compose", {})
                    beat_v = beat_compose.get(k, v)
                    if isinstance(beat_v, list) and elem_idx < len(beat_v):
                        elem_values.append(float(beat_v[elem_idx]))
                    else:
                        elem_values.append(float(v[elem_idx]))

                if all(ev == elem_values[0] for ev in elem_values):
                    merged_list.append(elem_values[0])
                else:
                    track_name = f"{track_prefix}_comp_{k}_{elem_idx}"
                    named_analysis[track_name] = {
                        "times": beat_times,
                        "values": elem_values,
                    }
                    merged_list.append(_passthrough_binding(track_name))
            compose_config[k] = merged_list
        elif isinstance(v, (int, float)):
            # Scalar param — check if it varies
            values = []
            for bi in group_indices:
                beat_compose = all_results[bi].get("compose", {})
                values.append(float(beat_compose.get(k, v)))

            if all(val == values[0] for val in values):
                compose_config[k] = values[0]
            else:
                track_name = f"{track_prefix}_comp_{k}"
                named_analysis[track_name] = {
                    "times": beat_times,
                    "values": values,
                }
                compose_config[k] = _passthrough_binding(track_name)
        else:
            # String or other — keep as-is
            compose_config[k] = v

    expanded_nodes.append({
        "id": compose_id,
        "type": "EFFECT",
        "config": compose_config,
        "inputs": layer_outputs,
    })
    sequence_inputs.append(compose_id)
= f"syn_{base_id}_comp_{k}_{elem_idx}" + named_analysis[track_name] = { + "times": beat_times, + "values": elem_values, + } + merged_list.append({ + "_binding": True, + "source": track_name, + "feature": "values", + "range": [0.0, 1.0], + }) + compose_config[k] = merged_list + elif isinstance(v, (int, float)): + # Scalar param — check if it varies + values = [] + for bi in group_indices: + beat_compose = all_results[bi].get("compose", {}) + values.append(float(beat_compose.get(k, v))) + + if all(val == values[0] for val in values): + compose_config[k] = values[0] + else: + track_name = f"syn_{base_id}_comp_{k}" + named_analysis[track_name] = { + "times": beat_times, + "values": values, + } + compose_config[k] = { + "_binding": True, + "source": track_name, + "feature": "values", + "range": [0.0, 1.0], + } + else: + # String or other — keep as-is + compose_config[k] = v + + expanded_nodes.append({ + "id": compose_id, + "type": "EFFECT", + "config": compose_config, + "inputs": layer_outputs, + }) + sequence_inputs.append(compose_id) + + +def _parse_construct_params(params_list: list) -> tuple: + """ + Parse :params block in a construct definition. 
+ + Syntax: + ( + (param_name :type string :default "value" :desc "description") + ) + + Returns: + (param_names, param_defaults) where param_names is a list of strings + and param_defaults is a dict of param_name -> default_value + """ + param_names = [] + param_defaults = {} + + for param_def in params_list: + if not isinstance(param_def, list) or len(param_def) < 1: + continue + + # First element is the parameter name + first = param_def[0] + if isinstance(first, Symbol): + param_name = first.name + elif isinstance(first, str): + param_name = first + else: + continue + + param_names.append(param_name) + + # Parse keyword arguments + default = None + i = 1 + while i < len(param_def): + item = param_def[i] + if isinstance(item, Keyword): + if i + 1 >= len(param_def): + break + kw_value = param_def[i + 1] + + if item.name == "default": + default = kw_value + # We could also parse :type, :range, :choices, :desc here + i += 2 + else: + i += 1 + + param_defaults[param_name] = default + + return param_names, param_defaults + + +def _expand_construct( + node: Dict, + registry: Dict, + sources: Dict[str, str], + analysis_data: Dict[str, Dict], + recipe_dir: Path, + cluster_key: str = None, + encoding: Dict = None, +) -> List[Dict]: + """ + Expand a user-defined CONSTRUCT node. + + Loads the construct definition from .sexp file, evaluates it with + the provided arguments, and converts the result into segment nodes. 
def _expand_construct(
    node: Dict,
    registry: Dict,
    sources: Dict[str, str],
    analysis_data: Dict[str, Dict],
    recipe_dir: Path,
    cluster_key: str = None,
    encoding: Dict = None,
) -> List[Dict]:
    """
    Expand a user-defined CONSTRUCT node.

    Loads the construct definition from .sexp file, evaluates it with
    the provided arguments, and converts the result into segment nodes.

    Args:
        node: The CONSTRUCT node to expand
        registry: Recipe registry
        sources: Map of source names to node IDs
        analysis_data: Analysis results (analysis_id -> {times, values})
        recipe_dir: Recipe directory for resolving paths
        cluster_key: Optional cluster key for hashing
        encoding: Encoding config

    Returns:
        List of expanded nodes (segments, effects, list)

    Raises:
        ValueError: missing construct file, no define-construct form,
            legacy parameter syntax, unknown keyword argument, undefined
            symbol, missing body, or a non-list evaluation result.
    """
    from .parser import parse_all, Symbol
    from .evaluator import evaluate

    config = node.get("config", {})
    construct_name = config.get("construct_name")
    construct_path = config.get("construct_path")
    args = config.get("args", [])

    # Load construct definition
    full_path = recipe_dir / construct_path
    if not full_path.exists():
        raise ValueError(f"Construct file not found: {full_path}")

    print(f" Loading construct: {construct_name} from {construct_path}", file=sys.stderr)

    construct_text = full_path.read_text()
    construct_sexp = parse_all(construct_text)

    # Parse define-construct: (define-construct name "desc" (params...) body)
    if not isinstance(construct_sexp, list):
        construct_sexp = [construct_sexp]

    # Process imports (effect, construct declarations) in the construct file
    # These extend the registry for this construct's scope
    local_registry = dict(registry)  # Copy parent registry
    construct_def = None

    for expr in construct_sexp:
        if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
            form_name = expr[0].name

            if form_name == "effect":
                # (effect name :path "...")
                effect_name = expr[1].name if isinstance(expr[1], Symbol) else expr[1]
                # Parse kwargs (keyword/value pairs following the name)
                i = 2
                kwargs = {}
                while i < len(expr):
                    if isinstance(expr[i], Keyword):
                        kwargs[expr[i].name] = expr[i + 1] if i + 1 < len(expr) else None
                        i += 2
                    else:
                        i += 1
                local_registry.setdefault("effects", {})[effect_name] = {
                    "path": kwargs.get("path"),
                    "cid": kwargs.get("cid"),
                }
                print(f" Construct imports effect: {effect_name}", file=sys.stderr)

            elif form_name == "define-construct":
                construct_def = expr

    if not construct_def:
        raise ValueError(f"No define-construct found in {construct_path}")

    # Use local_registry instead of registry from here
    registry = local_registry

    # Parse define-construct - requires :params syntax:
    # (define-construct name
    #   :params (
    #     (param1 :type string :default "value" :desc "description")
    #   )
    #   body)
    #
    # Legacy syntax (define-construct name "desc" (param1 param2) body) is not supported.
    def_name = construct_def[1].name if isinstance(construct_def[1], Symbol) else construct_def[1]

    params = []  # List of param names
    param_defaults = {}  # param_name -> default value
    body = None
    found_params = False

    idx = 2
    while idx < len(construct_def):
        item = construct_def[idx]
        if isinstance(item, Keyword) and item.name == "params":
            # :params syntax
            if idx + 1 >= len(construct_def):
                raise ValueError(f"Construct '{def_name}': Missing params list after :params keyword")
            params_list = construct_def[idx + 1]
            params, param_defaults = _parse_construct_params(params_list)
            found_params = True
            idx += 2
        elif isinstance(item, Keyword):
            # Skip other keywords (like :desc)
            idx += 2
        elif isinstance(item, str):
            # Skip description strings (but warn about legacy format)
            print(f" Warning: Description strings in define-construct are deprecated", file=sys.stderr)
            idx += 1
        elif body is None:
            # First non-keyword, non-string item is the body
            if isinstance(item, list) and item:
                first_elem = item[0]
                # Check for legacy params syntax and reject it
                if isinstance(first_elem, Symbol) and first_elem.name not in ("let", "let*", "if", "when", "do", "begin", "->", "map", "filter", "fn", "reduce", "nth"):
                    # Could be legacy params if all items are just symbols
                    if all(isinstance(p, Symbol) for p in item):
                        raise ValueError(
                            f"Construct '{def_name}': Legacy parameter syntax (param1 param2) is not supported. "
                            f"Use :params block instead."
                        )
                body = item
            idx += 1
        else:
            idx += 1

    if body is None:
        raise ValueError(f"No body found in define-construct {def_name}")

    # Build environment with sources and analysis data
    env = dict(sources)

    # Add bindings from compiler (video-a, video-b, etc.)
    if "bindings" in config:
        env.update(config["bindings"])

    # Add effect names so they can be referenced as symbols
    for effect_name in registry.get("effects", {}):
        env[effect_name] = effect_name

    # Map analysis node IDs to their data with :times and :values
    for analysis_id, data in analysis_data.items():
        # Find the name this analysis was bound to
        for name, node_id in sources.items():
            if node_id == analysis_id or name.endswith("-data"):
                env[name] = data
        # Also expose the data under the raw node id
        env[analysis_id] = data

    # Apply param defaults first (for :params syntax)
    for param_name, default_value in param_defaults.items():
        if default_value is not None:
            env[param_name] = default_value

    # Bind positional args to params (overrides defaults)
    param_names = [p.name if isinstance(p, Symbol) else p for p in params]
    for i, param in enumerate(param_names):
        if i < len(args):
            arg = args[i]
            # Resolve node IDs to their data if it's analysis
            if isinstance(arg, str) and arg in analysis_data:
                env[param] = analysis_data[arg]
            else:
                env[param] = arg

    # Helper to resolve node IDs to analysis data recursively
    def resolve_value(val):
        """Resolve node IDs and symbols in a value, including inside dicts/lists."""
        if isinstance(val, str) and val in analysis_data:
            return analysis_data[val]
        elif isinstance(val, str) and val in env:
            return env[val]
        elif isinstance(val, Symbol):
            if val.name in env:
                return env[val.name]
            return val
        elif isinstance(val, dict):
            return {k: resolve_value(v) for k, v in val.items()}
        elif isinstance(val, list):
            return [resolve_value(v) for v in val]
        return val

    # Validate and bind keyword arguments from the config (excluding internal keys)
    # These may be S-expressions that need evaluation (e.g., lambdas)
    # or Symbols that need resolution from bindings
    internal_keys = {"construct_name", "construct_path", "args", "bindings"}
    known_params = set(param_names) | set(param_defaults.keys())
    for key, value in config.items():
        if key not in internal_keys:
            # Convert key to valid identifier (replace - with _) for checking
            param_key = key.replace("-", "_")
            if param_key not in known_params:
                raise ValueError(
                    f"Construct '{def_name}': Unknown parameter '{key}'. "
                    f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}"
                )
            # Evaluate if it's an expression (list starting with Symbol)
            if isinstance(value, list) and value and isinstance(value[0], Symbol):
                env[param_key] = evaluate(value, env)
            elif isinstance(value, Symbol):
                # Resolve Symbol from env/bindings, then resolve any node IDs in the value
                if value.name in env:
                    env[param_key] = resolve_value(env[value.name])
                else:
                    raise ValueError(f"Undefined symbol in construct arg: {value.name}")
            else:
                # Resolve node IDs inside dicts/lists
                env[param_key] = resolve_value(value)

    # Evaluate construct body
    print(f" Evaluating construct with params: {param_names}", file=sys.stderr)
    segments = evaluate(body, env)

    if not isinstance(segments, list):
        raise ValueError(f"Construct must return a list of segments, got {type(segments)}")

    print(f" Construct produced {len(segments)} segments", file=sys.stderr)

    # Convert segment descriptors to plan nodes
    expanded_nodes = []
    sequence_inputs = []
    base_id = node["id"][:8]

    for i, seg in enumerate(segments):
        if not isinstance(seg, dict):
            continue

        source_ref = seg.get("source")
        start = seg.get("start", 0)
        print(f" DEBUG segment {i}: source={str(source_ref)[:20]}... start={start}", file=sys.stderr)
        end = seg.get("end")
        # NOTE(review): falsy duration/end (e.g. 0) falls through to 1.0 —
        # presumably intentional for open-ended segments; confirm.
        duration = seg.get("duration") or (end - start if end else 1.0)
        effects = seg.get("effects", [])

        # Resolve source reference to node ID
        source_id = sources.get(source_ref, source_ref) if isinstance(source_ref, str) else source_ref

        # Create segment node
        segment_id = f"{base_id}_seg_{i:04d}"
        segment_node = {
            "id": segment_id,
            "type": "SEGMENT",
            "config": {
                "start": start,
                "duration": duration,
            },
            "inputs": [source_id] if source_id else [],
        }
        expanded_nodes.append(segment_node)

        # Add effects if specified
        if effects:
            prev_id = segment_id
            for j, eff in enumerate(effects):
                effect_name = eff.get("effect") if isinstance(eff, dict) else eff
                effect_id = f"{base_id}_fx_{i:04d}_{j:02d}"
                # Look up effect_path from registry (prevents collapsing Python effects)
                effect_entry = registry.get("effects", {}).get(effect_name, {})
                effect_config = {
                    "effect": effect_name,
                    **{k: v for k, v in (eff.items() if isinstance(eff, dict) else []) if k != "effect"},
                }
                if effect_entry.get("path"):
                    effect_config["effect_path"] = effect_entry["path"]
                effect_node = {
                    "id": effect_id,
                    "type": "EFFECT",
                    "config": effect_config,
                    "inputs": [prev_id],
                }
                expanded_nodes.append(effect_node)
                prev_id = effect_id
            sequence_inputs.append(prev_id)
        else:
            sequence_inputs.append(segment_id)

    # Create LIST node (keeps the original node ID for reference stability)
    list_node = {
        "id": node["id"],
        "type": "LIST",
        "config": {},
        "inputs": sequence_inputs,
    }
    expanded_nodes.append(list_node)

    return expanded_nodes
start={start}", file=sys.stderr) + end = seg.get("end") + duration = seg.get("duration") or (end - start if end else 1.0) + effects = seg.get("effects", []) + + # Resolve source reference to node ID + source_id = sources.get(source_ref, source_ref) if isinstance(source_ref, str) else source_ref + + # Create segment node + segment_id = f"{base_id}_seg_{i:04d}" + segment_node = { + "id": segment_id, + "type": "SEGMENT", + "config": { + "start": start, + "duration": duration, + }, + "inputs": [source_id] if source_id else [], + } + expanded_nodes.append(segment_node) + + # Add effects if specified + if effects: + prev_id = segment_id + for j, eff in enumerate(effects): + effect_name = eff.get("effect") if isinstance(eff, dict) else eff + effect_id = f"{base_id}_fx_{i:04d}_{j:02d}" + # Look up effect_path from registry (prevents collapsing Python effects) + effect_entry = registry.get("effects", {}).get(effect_name, {}) + effect_config = { + "effect": effect_name, + **{k: v for k, v in (eff.items() if isinstance(eff, dict) else []) if k != "effect"}, + } + if effect_entry.get("path"): + effect_config["effect_path"] = effect_entry["path"] + effect_node = { + "id": effect_id, + "type": "EFFECT", + "config": effect_config, + "inputs": [prev_id], + } + expanded_nodes.append(effect_node) + prev_id = effect_id + sequence_inputs.append(prev_id) + else: + sequence_inputs.append(segment_id) + + # Create LIST node + list_node = { + "id": node["id"], + "type": "LIST", + "config": {}, + "inputs": sequence_inputs, + } + expanded_nodes.append(list_node) + + return expanded_nodes + + +def _expand_nodes( + nodes: List[Dict], + registry: Dict, + recipe_dir: Path, + source_paths: Dict[str, Path], + work_dir: Path = None, + cluster_key: str = None, + on_analysis: Callable[[str, Dict], None] = None, + encoding: Dict = None, + pre_analysis: Dict[str, Dict] = None, +) -> List[Dict]: + """ + Expand dynamic nodes (SLICE_ON) by running analyzers. + + Processes nodes in dependency order: + 1. 
SOURCE nodes: resolve file paths + 2. SEGMENT nodes: pre-execute if needed for analysis + 3. ANALYZE nodes: run analyzers (or use pre_analysis), store results + 4. SLICE_ON nodes: expand using analysis results + + Args: + nodes: List of compiled nodes + registry: Recipe registry + recipe_dir: Directory for resolving relative paths + source_paths: Resolved source paths (id -> path) + work_dir: Working directory for temporary files (created if None) + cluster_key: Optional cluster key + on_analysis: Callback when analysis completes (node_id, results) + pre_analysis: Pre-computed analysis data (name -> results) + + Returns: + Tuple of (expanded_nodes, named_analysis) where: + - expanded_nodes: List with SLICE_ON replaced by primitives + - named_analysis: Dict of analyzer_name -> {times, values} + """ + import tempfile + + nodes_by_id = {n["id"]: n for n in nodes} + sorted_ids = _topological_sort(nodes) + + # Create work directory if needed + if work_dir is None: + work_dir = Path(tempfile.mkdtemp(prefix="artdag_plan_")) + + # Track outputs and analysis results + outputs = {} # node_id -> output path or analysis data + analysis_results = {} # node_id -> analysis dict + named_analysis = {} # analyzer_name -> analysis dict (for effect bindings) + pre_executed = set() # nodes pre-executed during planning + expanded = [] + expanded_ids = set() + + for node_id in sorted_ids: + node = nodes_by_id[node_id] + node_type = node["type"] + + if node_type == "SOURCE": + # Resolve source path + config = node.get("config", {}) + if "path" in config: + path = recipe_dir / config["path"] + outputs[node_id] = path.resolve() + source_paths[node_id] = outputs[node_id] + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "SEGMENT": + # Check if this segment's input is resolved + inputs = node.get("inputs", []) + if inputs and inputs[0] in outputs: + input_path = outputs[inputs[0]] + if isinstance(input_path, Path): + # Skip pre-execution if config contains unresolved 
bindings + seg_config = node.get("config", {}) + has_binding = any( + isinstance(v, Binding) or (isinstance(v, dict) and v.get("_binding")) + for v in [seg_config.get("start"), seg_config.get("duration"), seg_config.get("end")] + if v is not None + ) + if not has_binding: + # Pre-execute segment to get output path + # This is needed if ANALYZE depends on this segment + import sys + print(f" Pre-executing segment: {node_id[:16]}...", file=sys.stderr) + output_path = _pre_execute_segment(node, input_path, work_dir) + outputs[node_id] = output_path + pre_executed.add(node_id) + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "ANALYZE": + # Get or run analysis + config = node.get("config", {}) + analysis_name = node.get("name") or config.get("analyzer") + + # Check for pre-computed analysis first + if pre_analysis and analysis_name and analysis_name in pre_analysis: + import sys + print(f" Using pre-computed analysis: {analysis_name}", file=sys.stderr) + results = pre_analysis[analysis_name] + else: + # Run analyzer to get concrete data + analyzer_path = config.get("analyzer_path") + node_inputs = node.get("inputs", []) + + if not node_inputs: + raise ValueError(f"ANALYZE node {node_id} has no inputs") + + # Get input path - could be SOURCE or pre-executed SEGMENT + input_id = node_inputs[0] + input_path = outputs.get(input_id) + + if input_path is None: + raise ValueError( + f"ANALYZE input {input_id} not resolved. " + "Check that input SOURCE or SEGMENT exists." 
+ ) + + if not isinstance(input_path, Path): + raise ValueError( + f"ANALYZE input {input_id} is not a file path: {type(input_path)}" + ) + + if analyzer_path: + full_path = recipe_dir / analyzer_path + params = {k: v for k, v in config.items() + if k not in ("analyzer", "analyzer_path", "cid")} + import sys + print(f" Running analyzer: {config.get('analyzer', 'unknown')}", file=sys.stderr) + results = _run_analyzer(full_path, input_path, params) + else: + raise ValueError(f"ANALYZE node {node_id} missing analyzer_path") + + analysis_results[node_id] = results + outputs[node_id] = results + + # Store by name for effect binding resolution + if analysis_name: + named_analysis[analysis_name] = results + + if on_analysis: + on_analysis(node_id, results) + + # Keep ANALYZE node in plan (it produces a JSON artifact) + expanded.append(node) + expanded_ids.add(node_id) + + elif node_type == "SLICE_ON": + # Expand into primitives using analysis results + inputs = node.get("inputs", []) + config = node.get("config", {}) + + # Lambda mode can have just 1 input (analysis), legacy needs 2 (video + analysis) + has_lambda = "fn" in config + if has_lambda: + if len(inputs) < 1: + raise ValueError(f"SLICE_ON {node_id} requires analysis input") + analysis_id = inputs[0] # First input is analysis + else: + if len(inputs) < 2: + raise ValueError(f"SLICE_ON {node_id} requires video and analysis inputs") + analysis_id = inputs[1] + + if analysis_id not in analysis_results: + raise ValueError( + f"SLICE_ON {node_id} analysis input {analysis_id} not found" + ) + + # Build sources map: name -> node_id + # This lets the lambda reference videos by name + sources = {} + for n in nodes: + if n.get("name"): + sources[n["name"]] = n["id"] + + analysis_data = analysis_results[analysis_id] + slice_nodes = _expand_slice_on(node, analysis_data, registry, sources, cluster_key, encoding, named_analysis) + + for sn in slice_nodes: + if sn["id"] not in expanded_ids: + expanded.append(sn) + 
def create_plan(
    recipe: CompiledRecipe,
    inputs: Dict[str, str] = None,
    recipe_dir: Path = None,
    cluster_key: str = None,
    on_analysis: Callable[[str, Dict], None] = None,
    pre_analysis: Dict[str, Dict] = None,
) -> ExecutionPlanSexp:
    """
    Create an execution plan from a compiled recipe.

    Args:
        recipe: Compiled S-expression recipe
        inputs: Mapping of input names to content hashes
        recipe_dir: Directory for resolving relative paths (required for analyzers)
        cluster_key: Optional cluster key for cache isolation
        on_analysis: Callback when analysis completes (node_id, results)
        pre_analysis: Pre-computed analysis data (name -> results), skips running analyzers

    Returns:
        ExecutionPlanSexp with all cache IDs computed

    Raises:
        ValueError: if the recipe has expandable nodes but recipe_dir is None.

    Example:
        >>> recipe = compile_string('(recipe "test" (-> (source cat) (effect identity)))')
        >>> plan = create_plan(recipe, inputs={}, recipe_dir=Path("."))
        >>> print(plan.to_string())
    """
    inputs = inputs or {}

    # Compute source hash as CID (SHA256 of raw bytes) - this IS the content address
    source_hash = hashlib.sha256(recipe.source_text.encode('utf-8')).hexdigest() if recipe.source_text else ""

    # Compute params hash (use JSON + SHA256 for consistency with cache.py)
    if recipe.resolved_params:
        import json
        params_str = json.dumps(recipe.resolved_params, sort_keys=True, default=str)
        params_hash = hashlib.sha256(params_str.encode()).hexdigest()
    else:
        params_hash = ""

    # Check if recipe has expandable nodes (SLICE_ON, etc.)
    has_expandable = any(n["type"] in EXPANDABLE_TYPES for n in recipe.nodes)
    named_analysis = {}

    if has_expandable:
        if recipe_dir is None:
            raise ValueError("recipe_dir required for recipes with SLICE_ON nodes")

        # Expand dynamic nodes (runs analyzers, expands SLICE_ON)
        source_paths = {}
        expanded_nodes, named_analysis = _expand_nodes(
            recipe.nodes,
            recipe.registry,
            recipe_dir,
            source_paths,
            cluster_key=cluster_key,
            on_analysis=on_analysis,
            encoding=recipe.encoding,
            pre_analysis=pre_analysis,
        )
        # Expand LIST inputs in SEQUENCE nodes
        expanded_nodes = _expand_list_inputs(expanded_nodes)
        # Collapse effect chains after expansion
        collapsed_nodes = _collapse_effect_chains(expanded_nodes, recipe.registry)
    else:
        # No expansion needed
        collapsed_nodes = _collapse_effect_chains(recipe.nodes, recipe.registry)

    # Build node lookup from collapsed nodes
    nodes_by_id = {node["id"]: node for node in collapsed_nodes}

    # Topological sort so every step sees its inputs' cache IDs
    sorted_ids = _topological_sort(collapsed_nodes)

    # Create steps with resolved hashes
    steps = []
    cache_ids = {}  # step_id -> cache_id

    for node_id in sorted_ids:
        node = nodes_by_id[node_id]
        step = _create_step(
            node,
            recipe.registry,
            inputs,
            cache_ids,
            cluster_key,
        )
        steps.append(step)
        cache_ids[node_id] = step.cache_id

    # Compute levels (mutates steps in place)
    _compute_levels(steps, nodes_by_id)

    # Handle stage-aware planning if recipe has stages
    stage_plans = []
    stage_order = []
    stage_levels = {}

    if recipe.stages:
        # Build mapping from node_id to stage
        node_to_stage = {}
        for stage in recipe.stages:
            for node_id in stage.node_ids:
                node_to_stage[node_id] = stage.name

        # Compute stage levels (for parallel execution)
        stage_levels = _compute_stage_levels(recipe.stages)

        # Tag each step with stage info
        for step in steps:
            if step.step_id in node_to_stage:
                step.stage = node_to_stage[step.step_id]

        # Build stage plans in declared order
        for stage_name in recipe.stage_order:
            stage = next(s for s in recipe.stages if s.name == stage_name)
            stage_steps = [s for s in steps if s.stage == stage_name]

            # Build output bindings with cache IDs
            output_cache_ids = {}
            for out_name, node_id in stage.output_bindings.items():
                if node_id in cache_ids:
                    output_cache_ids[out_name] = cache_ids[node_id]

            stage_plans.append(StagePlan(
                stage_name=stage_name,
                steps=stage_steps,
                requires=stage.requires,
                output_bindings=output_cache_ids,
                level=stage_levels.get(stage_name, 0),
            ))

        stage_order = recipe.stage_order

    # Compute plan ID from source CID + steps
    plan_content = {
        "source_cid": source_hash,
        "steps": [{"id": s.step_id, "cache_id": s.cache_id} for s in steps],
        "inputs": inputs,
    }
    plan_id = _stable_hash(plan_content, cluster_key)

    return ExecutionPlanSexp(
        plan_id=plan_id,
        source_hash=source_hash,
        params=recipe.resolved_params,
        params_hash=params_hash,
        steps=steps,
        output_step_id=recipe.output_node_id,
        inputs=inputs,
        analysis=named_analysis,
        stage_plans=stage_plans,
        stage_order=stage_order,
        stage_levels=stage_levels,
        effects_registry=recipe.registry.get("effects", {}),
        minimal_primitives=recipe.minimal_primitives,
    )
recipe.stage_order: + stage = next(s for s in recipe.stages if s.name == stage_name) + stage_steps = [s for s in steps if s.stage == stage_name] + + # Build output bindings with cache IDs + output_cache_ids = {} + for out_name, node_id in stage.output_bindings.items(): + if node_id in cache_ids: + output_cache_ids[out_name] = cache_ids[node_id] + + stage_plans.append(StagePlan( + stage_name=stage_name, + steps=stage_steps, + requires=stage.requires, + output_bindings=output_cache_ids, + level=stage_levels.get(stage_name, 0), + )) + + stage_order = recipe.stage_order + + # Compute plan ID from source CID + steps + plan_content = { + "source_cid": source_hash, + "steps": [{"id": s.step_id, "cache_id": s.cache_id} for s in steps], + "inputs": inputs, + } + plan_id = _stable_hash(plan_content, cluster_key) + + return ExecutionPlanSexp( + plan_id=plan_id, + source_hash=source_hash, + params=recipe.resolved_params, + params_hash=params_hash, + steps=steps, + output_step_id=recipe.output_node_id, + inputs=inputs, + analysis=named_analysis, + stage_plans=stage_plans, + stage_order=stage_order, + stage_levels=stage_levels, + effects_registry=recipe.registry.get("effects", {}), + minimal_primitives=recipe.minimal_primitives, + ) + + +def _topological_sort(nodes: List[Dict]) -> List[str]: + """Sort nodes in dependency order.""" + nodes_by_id = {n["id"]: n for n in nodes} + visited = set() + order = [] + + def visit(node_id: str): + if node_id in visited: + return + visited.add(node_id) + node = nodes_by_id.get(node_id) + if node: + for input_id in node.get("inputs", []): + visit(input_id) + order.append(node_id) + + for node in nodes: + visit(node["id"]) + + return order + + +def _create_step( + node: Dict, + registry: Dict, + inputs: Dict[str, str], + cache_ids: Dict[str, str], + cluster_key: str = None, +) -> PlanStep: + """Create a PlanStep from a node definition.""" + node_id = node["id"] + node_type = node["type"] + config = dict(node.get("config", {})) + node_inputs = 
node.get("inputs", []) + + # Resolve registry references + resolved_config = _resolve_config(config, registry, inputs) + + # Get input cache IDs (direct graph inputs) + input_cache_ids = [cache_ids[inp] for inp in node_inputs if inp in cache_ids] + + # Also include analysis_refs as dependencies (for binding resolution) + # These are implicit inputs that affect the computation result + analysis_refs = resolved_config.get("analysis_refs", []) + analysis_cache_ids = [cache_ids[ref] for ref in analysis_refs if ref in cache_ids] + + # Compute cache ID including both inputs and analysis dependencies + cache_content = { + "node_type": node_type, + "config": resolved_config, + "inputs": sorted(input_cache_ids + analysis_cache_ids), + } + cache_id = _stable_hash(cache_content, cluster_key) + + return PlanStep( + step_id=node_id, + node_type=node_type, + config=resolved_config, + inputs=node_inputs, + cache_id=cache_id, + ) + + +def _resolve_config( + config: Dict, + registry: Dict, + inputs: Dict[str, str], +) -> Dict: + """Resolve registry references in config to content hashes.""" + resolved = {} + + for key, value in config.items(): + if key == "filter_chain" and isinstance(value, list): + # Resolve each filter in the chain (for COMPOUND nodes) + resolved_chain = [] + for filter_item in value: + filter_config = filter_item.get("config", {}) + resolved_filter_config = _resolve_config(filter_config, registry, inputs) + resolved_chain.append({ + "type": filter_item["type"], + "config": resolved_filter_config, + }) + resolved["filter_chain"] = resolved_chain + + elif key == "asset" and isinstance(value, str): + # Resolve asset reference - use CID from registry + if value in registry.get("assets", {}): + resolved["cid"] = registry["assets"][value]["cid"] + else: + resolved["asset"] = value # Keep as-is if not in registry + + elif key == "effect" and isinstance(value, str): + # Resolve effect reference - keep name AND add CID/path + resolved["effect"] = value + if value in 
def _compute_levels(steps: List[PlanStep], nodes_by_id: Dict) -> None:
    """Assign a dependency depth to each step (mutates ``step.level``).

    A step's level is one more than the deepest of its dependencies,
    where dependencies are the node's graph ``inputs`` plus any
    ``analysis_refs`` in its config (binding dependencies).  Steps with
    no node entry or no dependencies sit at level 0.
    """
    memo: Dict[str, int] = {}

    def depth_of(sid: str) -> int:
        # Memoized recursive depth; unknown ids are treated as roots.
        if sid in memo:
            return memo[sid]

        node = nodes_by_id.get(sid)
        if not node:
            memo[sid] = 0
            return 0

        # Data deps plus implicit analysis-binding deps.
        deps = list(node.get("inputs", []))
        deps += node.get("config", {}).get("analysis_refs", [])

        if deps:
            depth = 1 + max(depth_of(dep) for dep in deps)
        else:
            depth = 0
        memo[sid] = depth
        return depth

    for entry in steps:
        entry.level = depth_of(entry.step_id)
+ """ + from .compiler import CompiledStage + + levels = {} + + def compute_level(stage_name: str) -> int: + if stage_name in levels: + return levels[stage_name] + + stage = next((s for s in stages if s.name == stage_name), None) + if not stage or not stage.requires: + levels[stage_name] = 0 + return 0 + + max_req = max(compute_level(req) for req in stage.requires) + levels[stage_name] = max_req + 1 + return levels[stage_name] + + for stage in stages: + compute_level(stage.name) + + return levels + + +def step_to_task_sexp(step: PlanStep) -> List: + """ + Convert a step to a minimal S-expression for Celery task. + + This is the S-expression that gets sent to a worker. + The worker hashes this to verify cache_id. + """ + sexp = [Symbol(step.node_type.lower())] + + # Add resolved config + for key, value in step.config.items(): + sexp.extend([Keyword(key), value]) + + # Add input cache IDs (not step IDs) + if step.inputs: + sexp.extend([Keyword("inputs"), step.inputs]) + + return sexp + + +def task_cache_id(task_sexp: List, cluster_key: str = None) -> str: + """ + Compute cache ID from task S-expression. + + This allows workers to verify they're executing the right task. + """ + # Serialize S-expression to canonical form + canonical = serialize(task_sexp) + return _stable_hash({"sexp": canonical}, cluster_key) diff --git a/artdag/sexp/primitives.py b/artdag/sexp/primitives.py new file mode 100644 index 0000000..65bbcc0 --- /dev/null +++ b/artdag/sexp/primitives.py @@ -0,0 +1,620 @@ +""" +Frame processing primitives for sexp effects. + +These primitives are called by sexp effect definitions and operate on +numpy arrays (frames). They're used when falling back to Python rendering +instead of FFmpeg. 
def cell_sample(frame: np.ndarray, cell_size: int = 8) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample frame into cells, returning average colors and luminances.

    Args:
        frame: Input frame (H, W, C); at least 3 channels (RGB first)
            are assumed — TODO confirm callers never pass grayscale.
        cell_size: Edge length of each square cell in pixels.

    Returns:
        (colors, luminances) — colors is (rows, cols, 3) float32,
        luminances is (rows, cols) float32 (ITU-R BT.601 weights).

    Pixels beyond the last full cell (frame dimensions not divisible by
    ``cell_size``) are ignored, as with the per-cell slicing this
    replaces.
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = h // cell_size
    cols = w // cell_size

    # Degenerate frame smaller than one cell: return empty grids
    # (the previous loop version produced empty arrays here too).
    if rows == 0 or cols == 0:
        return (
            np.zeros((rows, cols, 3), dtype=np.float32),
            np.zeros((rows, cols), dtype=np.float32),
        )

    # Vectorized cell averaging: crop to whole cells, reshape so each
    # cell's pixels land on their own axes, and reduce in one C-speed
    # pass.  Replaces a Python double loop over rows*cols cells.
    cropped = frame[: rows * cell_size, : cols * cell_size].astype(np.float32)
    cells = cropped.reshape(rows, cell_size, cols, cell_size, -1)
    means = cells.mean(axis=(1, 3))  # (rows, cols, C)

    colors = np.ascontiguousarray(means[..., :3])

    # Luminance (ITU-R BT.601)
    luminances = (
        0.299 * colors[..., 0] + 0.587 * colors[..., 1] + 0.114 * colors[..., 2]
    ).astype(np.float32)

    return colors, luminances
def luminance_to_chars(
    luminances: np.ndarray,
    alphabet: str = "standard",
    contrast: float = 1.0,
) -> List[List[str]]:
    """
    Convert luminance values to ASCII characters.

    Args:
        luminances: 2D array (or nested sequence) of luminance values (0-255)
        alphabet: Name of a character set in ASCII_ALPHABETS, or a custom
            string of glyphs ordered dark -> light
        contrast: Contrast multiplier, applied around mid-grey (128)

    Returns:
        2D list of single-character strings, one per cell
    """
    check_deps()
    chars = ASCII_ALPHABETS.get(alphabet, alphabet)
    n_chars = len(chars)

    # Vectorized contrast expansion around mid-grey, then clamp to the
    # valid luminance range (replaces a per-cell Python double loop).
    adjusted = 128.0 + (np.asarray(luminances, dtype=np.float64) - 128.0) * contrast
    adjusted = np.clip(adjusted, 0, 255)

    # Map 0..255 onto alphabet indices; astype(int) truncates toward
    # zero exactly like the scalar int() it replaces (values are >= 0).
    indices = (adjusted / 256.0 * n_chars).astype(int)
    indices = np.minimum(indices, n_chars - 1)

    return [[chars[i] for i in row] for row in indices]
def render_char_grid(
    frame: np.ndarray,
    chars: List[List[str]],
    colors: np.ndarray,
    char_size: int = 8,
    color_mode: str = "color",
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Render character grid to an image.

    Args:
        frame: Original frame (used only for output dimensions)
        chars: 2D list of characters; assumed rectangular — the column
            count is taken from the first row
        colors: Color for each cell (rows, cols, 3); used only when
            color_mode == "color"
        char_size: Size of each character cell in pixels
        color_mode: "color" (per-cell sampled color), "white", or "green"
        background: Background RGB color

    Returns:
        Rendered frame as an (H, W, 3) uint8 array
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = len(chars)
    cols = len(chars[0]) if chars else 0

    # Create output image
    img = Image.new("RGB", (w, h), background)
    draw = ImageDraw.Draw(img)

    # Try to get a monospace font: DejaVu, then Liberation (common distro
    # paths), finally PIL's built-in bitmap font as a last resort.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf", char_size)
    except (IOError, OSError):
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf", char_size)
        except (IOError, OSError):
            font = ImageFont.load_default()

    for r in range(rows):
        for c in range(cols):
            char = chars[r][c]
            # Blank cells are skipped entirely so the background shows through.
            if char == ' ':
                continue

            x = c * char_size
            y = r * char_size

            if color_mode == "color":
                color = tuple(int(v) for v in colors[r, c])
            elif color_mode == "green":
                color = (0, 255, 0)
            else:  # white
                color = (255, 255, 255)

            draw.text((x, y), char, fill=color, font=font)

    return np.array(img)
def ascii_zones_frame(
    frame: np.ndarray,
    char_size: int = 8,
    zone_threshold: int = 128,
    dark_chars: str = " .-:",
    light_chars: str = "=+*#",
) -> np.ndarray:
    """
    Zone-based ASCII art effect.

    Cells darker than ``zone_threshold`` draw from ``dark_chars``,
    brighter cells from ``light_chars``; within each zone the luminance
    is renormalized to pick a glyph from that zone's set.
    """
    check_deps()
    colors, luminances = cell_sample(frame, char_size)
    rows, cols = luminances.shape

    def pick_glyph(lum: float) -> str:
        # Select the zone's charset and the cell's position within it.
        if lum < zone_threshold:
            charset = dark_chars
            fraction = lum / zone_threshold
        else:
            charset = light_chars
            fraction = (lum - zone_threshold) / (255 - zone_threshold)
        slot = min(int(fraction * len(charset)), len(charset) - 1)
        return charset[slot]

    grid = [
        [pick_glyph(luminances[r, c]) for c in range(cols)]
        for r in range(rows)
    ]

    return render_char_grid(frame, grid, colors, char_size, "color", (0, 0, 0))
def remap(
    frame: np.ndarray,
    x_coords: np.ndarray,
    y_coords: np.ndarray,
) -> np.ndarray:
    """
    Resample ``frame`` at the given coordinate arrays.

    Uses bilinear interpolation (scipy ``map_coordinates`` with order=1
    and reflected edges); coordinates are clamped to the frame bounds
    before sampling.
    """
    check_deps()
    from scipy import ndimage

    height, width = frame.shape[:2]
    xs = np.clip(x_coords, 0, width - 1)
    ys = np.clip(y_coords, 0, height - 1)

    if frame.ndim != 3:
        # Single-channel frame: remap directly.
        return ndimage.map_coordinates(frame, [ys, xs], order=1, mode='reflect')

    # Multi-channel frame: remap each channel into a fresh buffer of the
    # same dtype as the input.
    out = np.zeros_like(frame)
    for ch in range(frame.shape[2]):
        out[:, :, ch] = ndimage.map_coordinates(
            frame[:, :, ch], [ys, xs], order=1, mode='reflect',
        )
    return out
def datamosh_frame(
    frame: np.ndarray,
    prev_frame: Optional[np.ndarray],
    block_size: int = 32,
    corruption: float = 0.3,
    max_offset: int = 50,
    color_corrupt: bool = True,
) -> np.ndarray:
    """
    Simplified datamosh effect using block displacement.

    This is a basic approximation - real datamosh works on compressed video.

    Args:
        frame: Current frame (H, W, C)
        prev_frame: Previous frame, or None on the first frame
        block_size: Edge length of the square blocks being displaced
        corruption: Probability that any given block is replaced
        max_offset: Maximum displacement (pixels) of the source block
        color_corrupt: Whether some replaced blocks also get their
            channels rolled

    Returns:
        A new frame (the input is copied, not modified in place)

    NOTE(review): draws from the global numpy RNG, so output is only
    reproducible if the caller seeds ``np.random`` — confirm callers
    that need determinism do so.
    """
    check_deps()
    # No previous frame to mosh against: pass the frame through.
    if prev_frame is None:
        return frame.copy()

    h, w = frame.shape[:2]
    result = frame.copy()

    # Process in blocks
    for y in range(0, h - block_size, block_size):
        for x in range(0, w - block_size, block_size):
            if np.random.random() < corruption:
                # Random offset
                ox = np.random.randint(-max_offset, max_offset + 1)
                oy = np.random.randint(-max_offset, max_offset + 1)

                # Source from previous frame with offset
                # (clamped so the source block stays fully inside the frame)
                src_y = np.clip(y + oy, 0, h - block_size)
                src_x = np.clip(x + ox, 0, w - block_size)

                block = prev_frame[src_y:src_y+block_size, src_x:src_x+block_size]

                # Color corruption
                if color_corrupt and np.random.random() < 0.3:
                    # Swap or shift channels
                    block = np.roll(block, np.random.randint(1, 3), axis=2)

                result[y:y+block_size, x:x+block_size] = block

    return result
def pixelsort_frame(
    frame: np.ndarray,
    sort_by: str = "lightness",
    threshold_low: float = 50,
    threshold_high: float = 200,
    angle: float = 0,
    reverse: bool = False,
) -> np.ndarray:
    """
    Apply pixel sorting effect to a frame.

    Rows are scanned for runs of pixels whose sort key falls inside
    [threshold_low, threshold_high]; each run (of length > 1) is sorted
    by that key.  ``angle`` rotates the frame first so sorting happens
    along an arbitrary direction, then rotates back.

    NOTE(review): the "hue" key is arctan2 output in radians (~ -pi..pi),
    so the default thresholds 50/200 select no pixels in that mode —
    confirm whether hue thresholds are meant to be radians.
    """
    check_deps()
    from scipy import ndimage

    # Rotate if needed
    if angle != 0:
        frame = ndimage.rotate(frame, -angle, reshape=False, mode='reflect')

    h, w = frame.shape[:2]
    result = frame.copy()

    # Compute sort key
    if sort_by == "lightness":
        key = 0.299 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2]
    elif sort_by == "hue":
        # Simple hue approximation
        key = np.arctan2(
            np.sqrt(3) * (frame[:,:,1].astype(float) - frame[:,:,2]),
            2 * frame[:,:,0].astype(float) - frame[:,:,1] - frame[:,:,2]
        )
    elif sort_by == "saturation":
        mx = frame.max(axis=2).astype(float)
        mn = frame.min(axis=2).astype(float)
        key = np.where(mx > 0, (mx - mn) / mx, 0)
    else:
        key = frame[:,:,0]  # Red channel

    # Sort each row
    for y in range(h):
        # ``row`` is a view into result; the run assignment below copies
        # the selected slice before writing back, so this is safe.
        row = result[y]
        row_key = key[y]

        # Find sortable intervals (pixels within threshold)
        mask = (row_key >= threshold_low) & (row_key <= threshold_high)

        # Find runs of True in mask
        runs = []
        start = None
        for x in range(w):
            if mask[x] and start is None:
                start = x
            elif not mask[x] and start is not None:
                # Runs of a single pixel are skipped (sorting them is a no-op).
                if x - start > 1:
                    runs.append((start, x))
                start = None
        if start is not None and w - start > 1:
            runs.append((start, w))

        # Sort each run
        for start, end in runs:
            indices = np.argsort(row_key[start:end])
            if reverse:
                indices = indices[::-1]
            result[y, start:end] = row[start:end][indices]

    # Rotate back
    if angle != 0:
        result = ndimage.rotate(result, angle, reshape=False, mode='reflect')

    return result
def map_char_grid(
    chars,
    luminances,
    fn,
):
    """
    Map a function over each cell of a character grid.

    Args:
        chars: 2D array/list of characters (rows, cols); rows may be
            strings, lists, or a numpy array
        luminances: 2D array of luminance values (a 1D sequence is also
            accepted and indexed per-row)
        fn: Function or Lambda (row, col, char, luminance) -> new_char

    Returns:
        New character grid with mapped values (list of lists)
    """
    from .parser import Lambda
    from .evaluator import evaluate

    # Handle both list and numpy array inputs
    # NOTE(review): a 1-D numpy ``chars`` would fail this 2-tuple unpack
    # even though the cell lookup below has a 1-D branch — confirm
    # whether 1-D numpy input is expected anywhere.
    if isinstance(chars, np.ndarray):
        rows, cols = chars.shape[:2]
    else:
        rows = len(chars)
        cols = len(chars[0]) if rows > 0 and isinstance(chars[0], (list, tuple, str)) else 1

    # Get luminances as 2D
    if isinstance(luminances, np.ndarray):
        lum_arr = luminances
    else:
        lum_arr = np.array(luminances)

    # Check if fn is a Lambda (from sexp) or a Python callable
    is_lambda = isinstance(fn, Lambda)

    result = []
    for r in range(rows):
        row_result = []
        for c in range(cols):
            # Get character (out-of-range columns fall back to a space)
            if isinstance(chars, np.ndarray):
                ch = chars[r, c] if len(chars.shape) > 1 else chars[r]
            elif isinstance(chars[r], str):
                ch = chars[r][c] if c < len(chars[r]) else ' '
            else:
                # NOTE(review): identical to the string branch above —
                # the two could be merged; kept separate as written.
                ch = chars[r][c] if c < len(chars[r]) else ' '

            # Get luminance
            if len(lum_arr.shape) > 1:
                lum = lum_arr[r, c]
            else:
                lum = lum_arr[r]

            # Call the function
            if is_lambda:
                # Evaluate the Lambda with arguments bound positionally
                # (row, col, char, luminance) into a copy of its closure.
                call_env = dict(fn.closure) if fn.closure else {}
                for param, val in zip(fn.params, [r, c, ch, float(lum)]):
                    call_env[param] = val
                new_ch = evaluate(fn.body, call_env)
            else:
                new_ch = fn(r, c, ch, float(lum))

            row_result.append(new_ch)
        result.append(row_result)

    return result
def alphabet_char(alphabet: str, index: int) -> str:
    """
    Get a character from an alphabet at a given index.

    Args:
        alphabet: A key in ASCII_ALPHABETS, or a literal string of glyphs
        index: Position in the alphabet; truncated to int and clamped
            into the valid range

    Returns:
        The glyph at the (clamped) index
    """
    # Named alphabets resolve through the table; anything else is taken
    # as a literal glyph string.
    glyphs = ASCII_ALPHABETS.get(alphabet, alphabet)

    # Truncate to int, then clamp to [0, len - 1].
    pos = min(max(int(index), 0), len(glyphs) - 1)
    return glyphs[pos]
@dataclass
class StepResult:
    """Result from executing a step."""
    step_id: str   # id of the step within its plan
    cache_id: str  # content-addressed id of the step's output
    status: str  # 'completed', 'cached', 'failed', 'pending'
    output_path: Optional[str] = None  # local filesystem path of the result, if resolved
    error: Optional[str] = None  # error message reported by the worker, if any
    ipfs_cid: Optional[str] = None  # IPFS CID of the output, when provided by the worker


@dataclass
class PlanResult:
    """Result from executing a complete plan."""
    plan_id: str
    status: str  # 'completed', 'failed', 'partial'
    steps_completed: int = 0  # dispatched steps that finished (incl. worker-side cache hits)
    steps_cached: int = 0  # steps skipped because their cache_id was already cached
    steps_failed: int = 0  # steps reported as failed
    output_cache_id: Optional[str] = None  # cache id of the plan's output step
    output_path: Optional[str] = None  # local path of the final output, if resolved
    output_ipfs_cid: Optional[str] = None  # IPFS CID of the final output, if provided
    step_results: Dict[str, StepResult] = field(default_factory=dict)  # step_id -> result
    error: Optional[str] = None  # message describing the first failure, when failed
def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: str = None) -> bool:
    """
    Verify that a step's cache_id matches its S-expression.

    The payload (optionally wrapped with the cluster key for cache
    isolation) is serialized to canonical JSON — sorted keys, compact
    separators — and hashed with SHA3-256; the hex digest is compared
    against ``expected_cache_id``.  Workers call this to confirm they
    are executing the correct task.
    """
    payload = {"sexp": step_sexp}
    if cluster_key:
        payload = {"_cluster_key": cluster_key, "_data": payload}

    canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha3_256(canonical.encode()).hexdigest()
    return digest == expected_cache_id
    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> PlanResult:
        """
        Schedule and execute a plan.

        Steps are executed level by level: within a level, cached steps
        are skipped and the rest are dispatched to workers as a parallel
        Celery group; the next level starts only after the group completes.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds applied to each level's group
                (NOTE(review): applied per level, not to the whole plan —
                confirm whether a plan-wide budget was intended)

        Returns:
            PlanResult with execution results
        """
        from celery import group

        logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

        # Build step lookup and group by level
        steps_by_id = {s.step_id: s for s in plan.steps}
        steps_by_level = self._group_by_level(plan.steps)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track results
        result = PlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Map step_id -> cache_id for resolving inputs
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Execute level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            logger.info(f"Level {level}: {len(level_steps)} steps")

            # Check cache for each step; hits are recorded immediately
            # and never dispatched.
            steps_to_run = []
            for step in level_steps:
                if self._is_cached(step.cache_id):
                    result.steps_cached += 1
                    result.step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                logger.info(f"Level {level}: all {len(level_steps)} steps cached")
                continue

            # Dispatch uncached steps to workers
            logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps")

            tasks = []
            for step in steps_to_run:
                # Build task arguments: the canonical S-expression plus
                # input cache ids (falling back to the raw id if unmapped).
                step_sexp = step_sexp_to_string(step)
                input_cache_ids = {
                    inp: cache_ids.get(inp, inp)
                    for inp in step.inputs
                }

                task = self._get_execute_task().s(
                    step_sexp=step_sexp,
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    plan_id=plan.plan_id,
                    input_cache_ids=input_cache_ids,
                )
                tasks.append(task)

            # Execute in parallel
            job = group(tasks)
            async_result = job.apply_async()

            try:
                step_results = async_result.get(timeout=timeout)
            except Exception as e:
                # Any group-level failure (timeout, broker error, worker
                # exception) aborts the whole plan.
                logger.error(f"Level {level} failed: {e}")
                result.status = "failed"
                result.error = f"Level {level} failed: {e}"
                return result

            # Process results
            for step_result in step_results:
                step_id = step_result.get("step_id")
                status = step_result.get("status")

                result.step_results[step_id] = StepResult(
                    step_id=step_id,
                    cache_id=step_result.get("cache_id"),
                    status=status,
                    output_path=step_result.get("output_path"),
                    error=step_result.get("error"),
                    ipfs_cid=step_result.get("ipfs_cid"),
                )

                if status in ("completed", "cached", "completed_by_other"):
                    result.steps_completed += 1
                elif status == "failed":
                    # First failure short-circuits the plan.
                    result.steps_failed += 1
                    result.status = "failed"
                    result.error = step_result.get("error")
                    return result

        # Get final output: resolve the plan's output step to its cached
        # path / CID, if the step produced a recorded result.
        output_step = steps_by_id.get(plan.output_step_id)
        if output_step:
            output_result = result.step_results.get(output_step.step_id)
            if output_result:
                result.output_cache_id = output_step.cache_id
                result.output_path = output_result.output_path
                result.output_ipfs_cid = output_result.ipfs_cid

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.steps_completed} executed, {result.steps_cached} cached"
        )
        return result
def schedule_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    celery_app=None,
    timeout: int = 3600,
) -> PlanResult:
    """
    One-shot helper: build a scheduler and execute ``plan`` on it.

    Args:
        plan: The execution plan
        cache_manager: Optional cache manager (auto-discovered when omitted)
        celery_app: Optional Celery app (auto-discovered when omitted)
        timeout: Execution timeout in seconds

    Returns:
        PlanResult
    """
    return create_scheduler(cache_manager, celery_app).schedule(plan, timeout=timeout)

        Args:
            cache_manager: L1 cache manager for step-level caching
            stage_cache: StageCache instance for stage-level caching
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.stage_cache = stage_cache
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> StagePlanResult:
        """
        Schedule and execute a plan with stage awareness.

        If the plan has stages, uses stage-level scheduling.
        Otherwise, falls back to step-level scheduling.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds for the entire plan

        Returns:
            StagePlanResult with execution results
        """
        # If no stages, use regular scheduling
        if not plan.stage_plans:
            logger.info("Plan has no stages, using step-level scheduling")
            regular_scheduler = PlanScheduler(
                cache_manager=self.cache_manager,
                celery_app=self.celery_app,
                execute_task_name=self.execute_task_name,
            )
            step_result = regular_scheduler.schedule(plan, timeout)
            # Wrap the step-level result in the stage-level result type.
            return StagePlanResult(
                plan_id=step_result.plan_id,
                status=step_result.status,
                steps_completed=step_result.steps_completed,
                steps_cached=step_result.steps_cached,
                steps_failed=step_result.steps_failed,
                output_cache_id=step_result.output_cache_id,
                output_path=step_result.output_path,
                error=step_result.error,
            )

        logger.info(
            f"Scheduling plan {plan.plan_id[:16]}... "
            f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)"
        )

        result = StagePlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Group stages by level
        stages_by_level = self._group_stages_by_level(plan.stage_plans)
        max_level = max(stages_by_level.keys()) if stages_by_level else 0

        # Track stage outputs for data flow
        stage_outputs = {}  # stage_name -> {binding_name -> cache_id}

        # Execute stage by stage level
        for level in range(max_level + 1):
            level_stages = stages_by_level.get(level, [])
            if not level_stages:
                continue

            logger.info(f"Stage level {level}: {len(level_stages)} stages")

            # Check stage cache for each stage
            stages_to_run = []
            for stage_plan in level_stages:
                if self._is_stage_cached(stage_plan.cache_id):
                    result.stages_cached += 1
                    cached_entry = self._load_cached_stage(stage_plan.cache_id)
                    if cached_entry:
                        stage_outputs[stage_plan.stage_name] = {
                            name: out.cache_id
                            for name, out in cached_entry.outputs.items()
                        }
                        # NOTE(review): the stage_results record depends on
                        # stage_outputs[stage_name], so it is only written when
                        # the cached entry loads successfully — a cache hit
                        # whose entry fails to load leaves no StageResult.
                        result.stage_results[stage_plan.stage_name] = StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="cached",
                            outputs=stage_outputs[stage_plan.stage_name],
                        )
                        logger.info(f"Stage {stage_plan.stage_name}: cached")
                else:
                    stages_to_run.append(stage_plan)

            if not stages_to_run:
                logger.info(f"Stage level {level}: all {len(level_stages)} stages cached")
                continue

            # Execute uncached stages
            # For now, execute sequentially; L1 Celery will add parallel execution
            for stage_plan in stages_to_run:
                logger.info(f"Executing stage: {stage_plan.stage_name}")

                stage_result = self._execute_stage(
                    stage_plan,
                    plan,
                    stage_outputs,
                    timeout,
                )

                result.stage_results[stage_plan.stage_name] = stage_result

                if stage_result.status == "completed":
                    result.stages_completed += 1
                    stage_outputs[stage_plan.stage_name] = stage_result.outputs

                    # Cache the stage result
                    self._cache_stage(stage_plan, stage_result)
                elif stage_result.status == "failed":
                    # First failing stage aborts the whole plan.
                    result.stages_failed += 1
                    result.status = "failed"
                    result.error = stage_result.error
                    return result

                # Accumulate step counts
                for sr in stage_result.step_results.values():
                    if sr.status == "completed":
                        result.steps_completed += 1
                    elif sr.status == "cached":
                        result.steps_cached += 1
                    elif sr.status == "failed":
                        result.steps_failed += 1

        # Get final output
        if plan.stage_plans:
            last_stage = plan.stage_plans[-1]
            if last_stage.stage_name in result.stage_results:
                stage_res = result.stage_results[last_stage.stage_name]
                result.output_cache_id = last_stage.cache_id
                # Find the output step's path from step results
                for step_res in stage_res.step_results.values():
                    if step_res.output_path:
                        result.output_path = step_res.output_path

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.stages_completed} stages executed, "
            f"{result.stages_cached} stages cached"
        )
        return result

    def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]:
        """Group stage plans by their level."""
        by_level = {}
        for stage_plan in stage_plans:
            by_level.setdefault(stage_plan.level, []).append(stage_plan)
        return by_level

    def _is_stage_cached(self, cache_id: str) -> bool:
        """Check if a stage is cached."""
        if self.stage_cache is None:
            return False
        return self.stage_cache.has_stage(cache_id)

    def _load_cached_stage(self, cache_id: str):
        """Load a cached stage entry (returns a StageCacheEntry or None)."""
        if self.stage_cache is None:
            return None
        return self.stage_cache.load_stage(cache_id)

    def _cache_stage(self, stage_plan, stage_result: StageResult) -> None:
        """Cache a stage result."""
        if self.stage_cache is None:
            return

        # Local import avoids a module-level import cycle with stage_cache.
        from .stage_cache import StageCacheEntry, StageOutput

        outputs = {}
        for name, cache_id in stage_result.outputs.items():
            # NOTE(review): all outputs are persisted as output_type="artifact";
            # scalar/analysis distinctions are not preserved here — confirm this
            # is intended before relying on output_type on reload.
            outputs[name] = StageOutput(
                cache_id=cache_id,
                output_type="artifact",
            )

        entry = StageCacheEntry(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            outputs=outputs,
        )
        self.stage_cache.save_stage(entry)

    def _execute_stage(
        self,
        stage_plan,
        plan: ExecutionPlanSexp,
        stage_outputs: Dict,
        timeout: int,
    ) -> StageResult:
        """
        Execute a single stage.

        Uses the PlanScheduler to execute the stage's steps.
        """
        # Create a mini-plan with just this stage's steps
        stage_steps = stage_plan.steps

        # Build step lookup
        steps_by_id = {s.step_id: s for s in stage_steps}
        steps_by_level = {}
        for step in stage_steps:
            steps_by_level.setdefault(step.level, []).append(step)

        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track step results
        step_results = {}
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Include outputs from previous stages
        for stage_name, outputs in stage_outputs.items():
            for binding_name, binding_cache_id in outputs.items():
                cache_ids[binding_name] = binding_cache_id

        # Execute steps level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            # Check cache for each step
            steps_to_run = []
            for step in level_steps:
                if self._is_step_cached(step.cache_id):
                    step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                continue

            # Execute steps (for now, sequentially - L1 will add Celery dispatch)
            for step in steps_to_run:
                # In a full implementation, this would dispatch to Celery
                # For now, mark as pending
                step_results[step.step_id] = StepResult(
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    status="pending",
                )

                # If Celery is configured, dispatch the task
                if self.celery_app:
                    try:
                        task_result = self._dispatch_step(step, cache_ids, timeout)
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status=task_result.get("status", "completed"),
                            output_path=task_result.get("output_path"),
                            error=task_result.get("error"),
                            ipfs_cid=task_result.get("ipfs_cid"),
                        )
                    except Exception as e:
                        # A dispatch/execution error fails the whole stage.
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status="failed",
                            error=str(e),
                        )
                        return StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="failed",
                            step_results=step_results,
                            error=str(e),
                        )

        # Build output bindings
        outputs = {}
        for out_name, node_id in stage_plan.output_bindings.items():
            # Falls back to node_id itself when no cache_id was recorded.
            outputs[out_name] = cache_ids.get(node_id, node_id)

        return StageResult(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            status="completed",
            step_results=step_results,
            outputs=outputs,
        )

    def _is_step_cached(self, cache_id: str) -> bool:
        """Check if a step is cached."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the path for a cached step."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict:
        """Dispatch a step to Celery for execution (blocks for the result)."""
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")

        task = self.celery_app.tasks[self.execute_task_name]

        step_sexp = step_sexp_to_string(step)
        input_cache_ids = {
            inp: cache_ids.get(inp, inp)
            for inp in step.inputs
        }

        async_result = task.apply_async(
            kwargs={
                "step_sexp": step_sexp,
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "input_cache_ids": input_cache_ids,
            }
        )

        # Synchronous wait; raises on timeout or remote failure.
        return async_result.get(timeout=timeout)


def create_stage_scheduler(
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
) -> StagePlanScheduler:
    """
    Create a stage-aware scheduler with the given dependencies.

    Args:
        cache_manager: L1 cache manager for step-level caching
        stage_cache: StageCache instance for stage-level caching
        celery_app: Celery application instance

    Returns:
        StagePlanScheduler
    """
    # Best-effort discovery of defaults from art-celery; absence is fine.
    if celery_app is None:
        try:
            from celery_app import app as celery_app
        except ImportError:
            pass

    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass

    return StagePlanScheduler(
        cache_manager=cache_manager,
        stage_cache=stage_cache,
        celery_app=celery_app,
    )


def schedule_staged_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
    timeout: int = 3600,
) -> StagePlanResult:
    """
    Convenience function to schedule a plan with stage awareness.

    Args:
        plan: The execution plan
        cache_manager: Optional step-level cache manager
        stage_cache: Optional stage-level cache
        celery_app: Optional Celery app
        timeout: Execution timeout

    Returns:
        StagePlanResult
    """
    scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app)
    return scheduler.schedule(plan, timeout=timeout)
diff --git a/artdag/sexp/stage_cache.py b/artdag/sexp/stage_cache.py
new file mode 100644
index 0000000..44cbe4c
--- /dev/null
+++ b/artdag/sexp/stage_cache.py
@@ -0,0 +1,412 @@
"""
Stage-level cache layer using S-expression storage.

Provides caching for stage outputs, enabling:
- Stage-level cache hits (skip entire stage execution)
- Analysis result persistence as sexp
- Cross-worker stage cache sharing (for L1 Celery integration)

All cache files use .sexp extension - no JSON in the pipeline.
+""" + +import os +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from .parser import Symbol, Keyword, parse, serialize + + +@dataclass +class StageOutput: + """A single output from a stage.""" + cache_id: Optional[str] = None # For artifacts (files, analysis data) + value: Any = None # For scalar values + output_type: str = "artifact" # "artifact", "analysis", "scalar" + + def to_sexp(self) -> List: + """Convert to S-expression.""" + sexp = [] + if self.cache_id: + sexp.extend([Keyword("cache-id"), self.cache_id]) + if self.value is not None: + sexp.extend([Keyword("value"), self.value]) + sexp.extend([Keyword("type"), Keyword(self.output_type)]) + return sexp + + @classmethod + def from_sexp(cls, sexp: List) -> 'StageOutput': + """Parse from S-expression list.""" + cache_id = None + value = None + output_type = "artifact" + + i = 0 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name + if i + 1 < len(sexp): + val = sexp[i + 1] + if key == "cache-id": + cache_id = val + elif key == "value": + value = val + elif key == "type": + if isinstance(val, Keyword): + output_type = val.name + else: + output_type = str(val) + i += 2 + else: + i += 1 + else: + i += 1 + + return cls(cache_id=cache_id, value=value, output_type=output_type) + + +@dataclass +class StageCacheEntry: + """Cached result of a stage execution.""" + stage_name: str + cache_id: str + outputs: Dict[str, StageOutput] # binding_name -> output info + completed_at: float = field(default_factory=time.time) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_sexp(self) -> List: + """ + Convert to S-expression for storage. + + Format: + (stage-result + :name "analyze-a" + :cache-id "abc123..." + :completed-at 1705678900.123 + :outputs + ((beats-a :cache-id "def456..." 
:type :analysis) + (tempo :value 120.5 :type :scalar))) + """ + sexp = [Symbol("stage-result")] + sexp.extend([Keyword("name"), self.stage_name]) + sexp.extend([Keyword("cache-id"), self.cache_id]) + sexp.extend([Keyword("completed-at"), self.completed_at]) + + if self.outputs: + outputs_sexp = [] + for name, output in self.outputs.items(): + output_sexp = [Symbol(name)] + output.to_sexp() + outputs_sexp.append(output_sexp) + sexp.extend([Keyword("outputs"), outputs_sexp]) + + if self.metadata: + sexp.extend([Keyword("metadata"), self.metadata]) + + return sexp + + def to_string(self, pretty: bool = True) -> str: + """Serialize to S-expression string.""" + return serialize(self.to_sexp(), pretty=pretty) + + @classmethod + def from_sexp(cls, sexp: List) -> 'StageCacheEntry': + """Parse from S-expression.""" + if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "stage-result": + raise ValueError("Invalid stage-result sexp") + + stage_name = None + cache_id = None + completed_at = time.time() + outputs = {} + metadata = {} + + i = 1 + while i < len(sexp): + if isinstance(sexp[i], Keyword): + key = sexp[i].name + if i + 1 < len(sexp): + val = sexp[i + 1] + if key == "name": + stage_name = val + elif key == "cache-id": + cache_id = val + elif key == "completed-at": + completed_at = float(val) + elif key == "outputs": + if isinstance(val, list): + for output_sexp in val: + if isinstance(output_sexp, list) and output_sexp: + out_name = output_sexp[0] + if isinstance(out_name, Symbol): + out_name = out_name.name + outputs[out_name] = StageOutput.from_sexp(output_sexp[1:]) + elif key == "metadata": + metadata = val if isinstance(val, dict) else {} + i += 2 + else: + i += 1 + else: + i += 1 + + if not stage_name or not cache_id: + raise ValueError("stage-result missing required fields (name, cache-id)") + + return cls( + stage_name=stage_name, + cache_id=cache_id, + outputs=outputs, + completed_at=completed_at, + metadata=metadata, + ) + + @classmethod + def 
from_string(cls, text: str) -> 'StageCacheEntry': + """Parse from S-expression string.""" + sexp = parse(text) + return cls.from_sexp(sexp) + + +class StageCache: + """ + Stage-level cache manager using S-expression files. + + Cache structure: + cache_dir/ + _stages/ + {cache_id}.sexp <- Stage result files + """ + + def __init__(self, cache_dir: Union[str, Path]): + """ + Initialize stage cache. + + Args: + cache_dir: Base cache directory + """ + self.cache_dir = Path(cache_dir) + self.stages_dir = self.cache_dir / "_stages" + self.stages_dir.mkdir(parents=True, exist_ok=True) + + def get_cache_path(self, cache_id: str) -> Path: + """Get the path for a stage cache file.""" + return self.stages_dir / f"{cache_id}.sexp" + + def has_stage(self, cache_id: str) -> bool: + """Check if a stage result is cached.""" + return self.get_cache_path(cache_id).exists() + + def load_stage(self, cache_id: str) -> Optional[StageCacheEntry]: + """ + Load a cached stage result. + + Args: + cache_id: Stage cache ID + + Returns: + StageCacheEntry if found, None otherwise + """ + path = self.get_cache_path(cache_id) + if not path.exists(): + return None + + try: + content = path.read_text() + return StageCacheEntry.from_string(content) + except Exception as e: + # Corrupted cache file - log and return None + import sys + print(f"Warning: corrupted stage cache {cache_id}: {e}", file=sys.stderr) + return None + + def save_stage(self, entry: StageCacheEntry) -> Path: + """ + Save a stage result to cache. + + Args: + entry: Stage cache entry to save + + Returns: + Path to the saved cache file + """ + path = self.get_cache_path(entry.cache_id) + content = entry.to_string(pretty=True) + path.write_text(content) + return path + + def delete_stage(self, cache_id: str) -> bool: + """ + Delete a cached stage result. 

        Args:
            cache_id: Stage cache ID

        Returns:
            True if deleted, False if not found
        """
        path = self.get_cache_path(cache_id)
        if path.exists():
            path.unlink()
            return True
        return False

    def list_stages(self) -> List[str]:
        """List all cached stage IDs (file stems of the .sexp files)."""
        return [
            p.stem for p in self.stages_dir.glob("*.sexp")
        ]

    def clear(self) -> int:
        """
        Clear all cached stages.

        Returns:
            Number of entries cleared
        """
        count = 0
        for path in self.stages_dir.glob("*.sexp"):
            path.unlink()
            count += 1
        return count


@dataclass
class AnalysisResult:
    """
    Analysis result stored as S-expression.

    Format:
      (analysis-result
        :analyzer "beats"
        :input-hash "abc123..."
        :duration 120.5
        :tempo 128.0
        :times (0.0 0.468 0.937 1.406 ...)
        :values (0.8 0.9 0.7 0.85 ...))
    """
    analyzer: str
    input_hash: str
    data: Dict[str, Any]  # Analysis data (times, values, duration, etc.)
    computed_at: float = field(default_factory=time.time)

    def to_sexp(self) -> List:
        """Convert to S-expression."""
        sexp = [Symbol("analysis-result")]
        sexp.extend([Keyword("analyzer"), self.analyzer])
        sexp.extend([Keyword("input-hash"), self.input_hash])
        sexp.extend([Keyword("computed-at"), self.computed_at])

        # Add all data fields
        for key, value in self.data.items():
            # Convert key to keyword (snake_case -> kebab-case)
            sexp.extend([Keyword(key.replace("_", "-")), value])

        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'AnalysisResult':
        """Parse from S-expression.

        Raises:
            ValueError: if the form is not an analysis-result or has no
                :analyzer field.
        """
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis-result":
            raise ValueError("Invalid analysis-result sexp")

        analyzer = None
        input_hash = None
        computed_at = time.time()
        data = {}

        i = 1
        while i < len(sexp):
            if isinstance(sexp[i], Keyword):
                key = sexp[i].name
                if i + 1 < len(sexp):
                    val = sexp[i + 1]
                    if key == "analyzer":
                        analyzer = val
                    elif key == "input-hash":
                        input_hash = val
                    elif key == "computed-at":
                        computed_at = float(val)
                    else:
                        # Convert kebab-case back to snake_case
                        data_key = key.replace("-", "_")
                        data[data_key] = val
                    i += 2
                else:
                    i += 1
            else:
                i += 1

        if not analyzer:
            raise ValueError("analysis-result missing analyzer field")

        return cls(
            analyzer=analyzer,
            input_hash=input_hash or "",
            data=data,
            computed_at=computed_at,
        )

    @classmethod
    def from_string(cls, text: str) -> 'AnalysisResult':
        """Parse from S-expression string."""
        sexp = parse(text)
        return cls.from_sexp(sexp)


def save_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
    result: AnalysisResult,
) -> Path:
    """
    Save an analysis result as S-expression.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis
        result: Analysis result to save

    Returns:
        Path to the saved file
    """
    cache_dir = Path(cache_dir)
    node_dir = cache_dir / node_id
    node_dir.mkdir(parents=True, exist_ok=True)

    path = node_dir / "analysis.sexp"
    content = result.to_string(pretty=True)
    path.write_text(content)
    return path


def load_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
) -> Optional[AnalysisResult]:
    """
    Load an analysis result from cache.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis

    Returns:
        AnalysisResult if found, None otherwise
    """
    cache_dir = Path(cache_dir)
    path = cache_dir / node_id / "analysis.sexp"

    if not path.exists():
        return None

    try:
        content = path.read_text()
        return AnalysisResult.from_string(content)
    except Exception as e:
        # Corrupted cache file - warn and treat as a cache miss.
        import sys
        print(f"Warning: corrupted analysis cache {node_id}: {e}", file=sys.stderr)
        return None
diff --git a/artdag/sexp/test_ffmpeg_compiler.py b/artdag/sexp/test_ffmpeg_compiler.py
new file mode 100644
index 0000000..1cfafe5
--- /dev/null
+++ b/artdag/sexp/test_ffmpeg_compiler.py
@@ -0,0 +1,146 @@
"""
Tests for FFmpeg filter compilation.

Validates that each filter mapping produces valid FFmpeg commands.
"""

import subprocess
import tempfile
from pathlib import Path

from .ffmpeg_compiler import FFmpegCompiler, EFFECT_MAPPINGS


def test_filter_syntax(filter_str: str, duration: float = 0.1, is_complex: bool = False) -> tuple[bool, str]:
    """
    Test if an FFmpeg filter string is valid by running it on a test pattern.

    Args:
        filter_str: The filter string to test
        duration: Duration of test video
        is_complex: If True, use -filter_complex instead of -vf

    Returns (success, error_message)
    """
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
        output_path = f.name

    try:
        if is_complex:
            # Complex filter graph needs -filter_complex and explicit output mapping
            cmd = [
                'ffmpeg', '-y',
                '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
                '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
                '-filter_complex', f'[0:v]{filter_str}[out]',
                '-map', '[out]', '-map', '1:a',
                '-c:v', 'libx264', '-preset', 'ultrafast',
                '-c:a', 'aac',
                '-t', str(duration),
                output_path
            ]
        else:
            # Simple filter uses -vf
            cmd = [
                'ffmpeg', '-y',
                '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
                '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
                '-vf', filter_str,
                '-c:v', 'libx264', '-preset', 'ultrafast',
                '-c:a', 'aac',
                '-t', str(duration),
                output_path
            ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

        if result.returncode == 0:
            return True, ""
        else:
            # Extract relevant error
            stderr = result.stderr
            for line in stderr.split('\n'):
                if 'Error' in line or 'error' in line or 'Invalid' in line:
                    return False, line.strip()
            return False, stderr[-500:] if len(stderr) > 500 else stderr
    except subprocess.TimeoutExpired:
        return False, "Timeout"
    except Exception as e:
        return False, str(e)
    finally:
        # Always remove the temp output, even on failure/timeout.
        Path(output_path).unlink(missing_ok=True)


def run_all_tests():
    """Test all effect mappings."""
    compiler = FFmpegCompiler()
    results = []

    for effect_name, mapping in EFFECT_MAPPINGS.items():
        filter_name = mapping.get("filter")

        # Skip effects with no FFmpeg equivalent (external tools or python primitives)
        if filter_name is None:
            reason = "No FFmpeg equivalent"
            if mapping.get("external_tool"):
                reason = f"External tool: {mapping['external_tool']}"
            elif mapping.get("python_primitive"):
                reason = f"Python primitive: {mapping['python_primitive']}"
            results.append((effect_name, "SKIP", reason))
            continue

        # Check if complex filter
        is_complex = mapping.get("complex", False)

        # Build filter string
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        else:
            filter_str = filter_name

        # Test it
        success, error = test_filter_syntax(filter_str, is_complex=is_complex)

        if success:
            results.append((effect_name, "PASS", filter_str))
        else:
            results.append((effect_name, "FAIL", f"{filter_str} -> {error}"))

    return results


def print_results(results):
    """Print test results."""
    passed = sum(1 for _, status, _ in results if status == "PASS")
    failed = sum(1 for _, status, _ in results if status == "FAIL")
    skipped = sum(1 for _, status, _ in results if status == "SKIP")

    print(f"\n{'='*60}")
    print(f"FFmpeg Filter Tests: {passed} passed, {failed} failed, {skipped} skipped")
    print(f"{'='*60}\n")

    # Print failures first
    if failed > 0:
        print("FAILURES:")
        for name, status, msg in results:
            if status == "FAIL":
                print(f"  {name}: {msg}")
        print()

    # Print passes
    print("PASSED:")
    for name, status, msg in results:
        if status == "PASS":
            print(f"  {name}: {msg}")

    # Print skips
    if skipped > 0:
        print("\nSKIPPED (Python fallback):")
        for name, status, msg in results:
            if status == "SKIP":
                print(f"  {name}")


if __name__ == "__main__":
    results = run_all_tests()
    print_results(results)
diff --git a/artdag/sexp/test_primitives.py b/artdag/sexp/test_primitives.py
new file mode 100644
index 0000000..193c7fd
--- /dev/null
+++ b/artdag/sexp/test_primitives.py
@@ -0,0 +1,201 @@
"""
Tests for Python primitive effects.

Tests that ascii_art, ascii_zones, and other Python primitives
can be executed via the EffectExecutor.
+""" + +import subprocess +import tempfile +from pathlib import Path + +import pytest + +try: + import numpy as np + from PIL import Image + HAS_DEPS = True +except ImportError: + HAS_DEPS = False + +from .primitives import ( + ascii_art_frame, + ascii_zones_frame, + get_primitive, + list_primitives, +) +from .ffmpeg_compiler import FFmpegCompiler + + +def create_test_video(path: Path, duration: float = 0.5, size: str = "64x64") -> bool: + """Create a short test video using ffmpeg.""" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", f"testsrc=duration={duration}:size={size}:rate=10", + "-c:v", "libx264", "-preset", "ultrafast", + str(path) + ] + result = subprocess.run(cmd, capture_output=True) + return result.returncode == 0 + + +@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available") +class TestPrimitives: + """Test primitive functions directly.""" + + def test_ascii_art_frame_basic(self): + """Test ascii_art_frame produces output of same shape.""" + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = ascii_art_frame(frame, char_size=8) + assert result.shape == frame.shape + assert result.dtype == np.uint8 + + def test_ascii_zones_frame_basic(self): + """Test ascii_zones_frame produces output of same shape.""" + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = ascii_zones_frame(frame, char_size=8) + assert result.shape == frame.shape + assert result.dtype == np.uint8 + + def test_get_primitive(self): + """Test primitive lookup.""" + assert get_primitive("ascii_art_frame") is ascii_art_frame + assert get_primitive("ascii_zones_frame") is ascii_zones_frame + assert get_primitive("nonexistent") is None + + def test_list_primitives(self): + """Test listing primitives.""" + primitives = list_primitives() + assert "ascii_art_frame" in primitives + assert "ascii_zones_frame" in primitives + assert len(primitives) > 5 + + +class TestFFmpegCompilerPrimitives: + """Test FFmpegCompiler python_primitive 
mappings.""" + + def test_has_python_primitive_ascii_art(self): + """Test ascii_art has python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame" + + def test_has_python_primitive_ascii_zones(self): + """Test ascii_zones has python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame" + + def test_has_python_primitive_ffmpeg_effect(self): + """Test FFmpeg effects don't have python_primitive.""" + compiler = FFmpegCompiler() + assert compiler.has_python_primitive("brightness") is None + assert compiler.has_python_primitive("blur") is None + + def test_compile_effect_returns_none_for_primitives(self): + """Test compile_effect returns None for primitive effects.""" + compiler = FFmpegCompiler() + assert compiler.compile_effect("ascii_art", {}) is None + assert compiler.compile_effect("ascii_zones", {}) is None + + +@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available") +class TestEffectExecutorPrimitives: + """Test EffectExecutor with Python primitives.""" + + def test_executor_loads_primitive(self): + """Test that executor finds primitive effects.""" + from ..nodes.effect import _get_python_primitive_effect + + effect_fn = _get_python_primitive_effect("ascii_art") + assert effect_fn is not None + + effect_fn = _get_python_primitive_effect("ascii_zones") + assert effect_fn is not None + + def test_executor_rejects_unknown_effect(self): + """Test that executor returns None for unknown effects.""" + from ..nodes.effect import _get_python_primitive_effect + + effect_fn = _get_python_primitive_effect("nonexistent_effect") + assert effect_fn is None + + def test_execute_ascii_art_effect(self, tmp_path): + """Test executing ascii_art effect on a video.""" + from ..nodes.effect import EffectExecutor + + # Create test video + input_path = tmp_path / "input.mp4" + if not create_test_video(input_path): + pytest.skip("Could not create 
test video") + + output_path = tmp_path / "output.mkv" + executor = EffectExecutor() + + result = executor.execute( + config={"effect": "ascii_art", "char_size": 8}, + inputs=[input_path], + output_path=output_path, + ) + + assert result.exists() + assert result.stat().st_size > 0 + + +def run_all_tests(): + """Run tests manually.""" + import sys + + # Check dependencies + if not HAS_DEPS: + print("SKIP: numpy/PIL not available") + return + + print("Testing primitives...") + + # Test primitive functions + frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + + print(" ascii_art_frame...", end=" ") + result = ascii_art_frame(frame, char_size=8) + assert result.shape == frame.shape + print("PASS") + + print(" ascii_zones_frame...", end=" ") + result = ascii_zones_frame(frame, char_size=8) + assert result.shape == frame.shape + print("PASS") + + # Test FFmpegCompiler mappings + print("\nTesting FFmpegCompiler mappings...") + compiler = FFmpegCompiler() + + print(" ascii_art python_primitive...", end=" ") + assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame" + print("PASS") + + print(" ascii_zones python_primitive...", end=" ") + assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame" + print("PASS") + + # Test executor lookup + print("\nTesting EffectExecutor...") + try: + from ..nodes.effect import _get_python_primitive_effect + + print(" _get_python_primitive_effect(ascii_art)...", end=" ") + effect_fn = _get_python_primitive_effect("ascii_art") + assert effect_fn is not None + print("PASS") + + print(" _get_python_primitive_effect(ascii_zones)...", end=" ") + effect_fn = _get_python_primitive_effect("ascii_zones") + assert effect_fn is not None + print("PASS") + + except ImportError as e: + print(f"SKIP: {e}") + + print("\n=== All tests passed ===") + + +if __name__ == "__main__": + run_all_tests() diff --git a/artdag/sexp/test_stage_cache.py b/artdag/sexp/test_stage_cache.py new file mode 100644 index 
"""
Tests for the stage cache layer.

Covers S-expression persistence of stage results and analysis data.
"""

import pytest
import tempfile
from pathlib import Path

from .stage_cache import (
    StageCache,
    StageCacheEntry,
    StageOutput,
    AnalysisResult,
    save_analysis_result,
    load_analysis_result,
)
from .parser import parse, serialize


class TestStageOutput:
    """StageOutput dataclass and its sexp serialization."""

    def test_stage_output_artifact(self):
        """An artifact-typed output records its cache id."""
        artifact = StageOutput(cache_id="abc123", output_type="artifact")
        assert artifact.cache_id == "abc123"
        assert artifact.output_type == "artifact"

    def test_stage_output_scalar(self):
        """A scalar-typed output records its value."""
        scalar = StageOutput(value=120.5, output_type="scalar")
        assert scalar.value == 120.5
        assert scalar.output_type == "scalar"

    def test_stage_output_to_sexp(self):
        """Serialized output mentions its cache id and type."""
        artifact = StageOutput(cache_id="abc123", output_type="artifact")
        text = serialize(artifact.to_sexp())

        assert "cache-id" in text
        assert "abc123" in text
        assert "type" in text
        assert "artifact" in text

    def test_stage_output_from_sexp(self):
        """A plist sexp parses back into a StageOutput."""
        parsed = StageOutput.from_sexp(parse('(:cache-id "def456" :type :analysis)'))

        assert parsed.cache_id == "def456"
        assert parsed.output_type == "analysis"


class TestStageCacheEntry:
    """StageCacheEntry sexp serialization and parsing."""

    def test_stage_cache_entry_to_sexp(self):
        """Serialized entry carries stage name, cache id and outputs."""
        entry = StageCacheEntry(
            stage_name="analyze-a",
            cache_id="stage_abc123",
            outputs={
                "beats": StageOutput(cache_id="beats_def456", output_type="analysis"),
                "tempo": StageOutput(value=120.5, output_type="scalar"),
            },
            completed_at=1705678900.123,
        )

        text = serialize(entry.to_sexp())

        assert "stage-result" in text
        assert "analyze-a" in text
        assert "stage_abc123" in text
        assert "outputs" in text
        assert "beats" in text

    def test_stage_cache_entry_roundtrip(self):
        """to_string -> from_string preserves every field."""
        entry = StageCacheEntry(
            stage_name="analyze-b",
            cache_id="stage_xyz789",
            outputs={
                "segments": StageOutput(cache_id="seg_123", output_type="artifact"),
            },
            completed_at=1705678900.0,
        )

        restored = StageCacheEntry.from_string(entry.to_string())

        assert restored.stage_name == entry.stage_name
        assert restored.cache_id == entry.cache_id
        assert "segments" in restored.outputs
        assert restored.outputs["segments"].cache_id == "seg_123"

    def test_stage_cache_entry_from_sexp(self):
        """A hand-written stage-result sexp parses correctly."""
        text = '''
        (stage-result
          :name "test-stage"
          :cache-id "cache123"
          :completed-at 1705678900.0
          :outputs ((beats :cache-id "beats123" :type :analysis)))
        '''
        entry = StageCacheEntry.from_string(text)

        assert entry.stage_name == "test-stage"
        assert entry.cache_id == "cache123"
        assert "beats" in entry.outputs
        assert entry.outputs["beats"].cache_id == "beats123"


class TestStageCache:
    """File-backed StageCache operations."""

    def test_save_and_load_stage(self):
        """A saved stage entry can be loaded back by cache id."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            saved_path = store.save_stage(
                StageCacheEntry(
                    stage_name="analyze",
                    cache_id="test_cache_id",
                    outputs={
                        "beats": StageOutput(cache_id="beats_out", output_type="analysis"),
                    },
                )
            )
            assert saved_path.exists()
            assert saved_path.suffix == ".sexp"

            restored = store.load_stage("test_cache_id")
            assert restored is not None
            assert restored.stage_name == "analyze"
            assert "beats" in restored.outputs

    def test_has_stage(self):
        """has_stage reflects whether a cache id was saved."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            assert not store.has_stage("nonexistent")

            store.save_stage(
                StageCacheEntry(stage_name="test", cache_id="exists_cache_id", outputs={})
            )

            assert store.has_stage("exists_cache_id")

    def test_delete_stage(self):
        """delete_stage removes an entry and reports success."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            store.save_stage(
                StageCacheEntry(stage_name="test", cache_id="to_delete", outputs={})
            )

            assert store.has_stage("to_delete")
            assert store.delete_stage("to_delete") is True
            assert not store.has_stage("to_delete")

    def test_list_stages(self):
        """list_stages returns every saved cache id."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            for idx in range(3):
                store.save_stage(
                    StageCacheEntry(
                        stage_name=f"stage{idx}",
                        cache_id=f"cache_{idx}",
                        outputs={},
                    )
                )

            listed = store.list_stages()
            assert len(listed) == 3
            assert "cache_0" in listed
            assert "cache_1" in listed
            assert "cache_2" in listed

    def test_clear(self):
        """clear removes all entries and returns how many were removed."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            for idx in range(3):
                store.save_stage(
                    StageCacheEntry(
                        stage_name=f"stage{idx}",
                        cache_id=f"cache_{idx}",
                        outputs={},
                    )
                )

            assert store.clear() == 3
            assert len(store.list_stages()) == 0

    def test_cache_file_extension(self):
        """Cache paths always carry the .sexp extension."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)
            assert store.get_cache_path("test_id").suffix == ".sexp"

    def test_invalid_sexp_error_handling(self):
        """Corrupted cache files yield None instead of raising."""
        with tempfile.TemporaryDirectory() as workdir:
            store = StageCache(workdir)

            # Plant corrupted content where the entry would live.
            store.get_cache_path("corrupted").write_text("this is not valid sexp )()(")

            # load_stage must degrade gracefully, not raise.
            assert store.load_stage("corrupted") is None


class TestAnalysisResult:
    """AnalysisResult serialization and cache helpers."""

    def test_analysis_result_to_sexp(self):
        """Serialized result mentions analyzer name and data keys."""
        result = AnalysisResult(
            analyzer="beats",
            input_hash="input_abc123",
            data={
                "duration": 120.5,
                "tempo": 128.0,
                "times": [0.0, 0.468, 0.937, 1.406],
                "values": [0.8, 0.9, 0.7, 0.85],
            },
        )

        text = serialize(result.to_sexp())

        assert "analysis-result" in text
        assert "beats" in text
        assert "duration" in text
        assert "tempo" in text
        assert "times" in text

    def test_analysis_result_roundtrip(self):
        """to_string -> from_string preserves analyzer, hash and data."""
        original = AnalysisResult(
            analyzer="scenes",
            input_hash="video_xyz",
            data={
                "scene_count": 5,
                "scene_times": [0.0, 10.5, 25.0, 45.2, 60.0],
            },
        )

        restored = AnalysisResult.from_string(original.to_string())

        assert restored.analyzer == original.analyzer
        assert restored.input_hash == original.input_hash
        assert restored.data["scene_count"] == 5

    def test_save_and_load_analysis_result(self):
        """The save/load helpers persist a result under a node directory."""
        with tempfile.TemporaryDirectory() as workdir:
            result = AnalysisResult(
                analyzer="beats",
                input_hash="audio_123",
                data={
                    "tempo": 120.0,
                    "times": [0.0, 0.5, 1.0],
                },
            )

            saved_path = save_analysis_result(workdir, "node_abc", result)
            assert saved_path.exists()
            assert saved_path.name == "analysis.sexp"

            restored = load_analysis_result(workdir, "node_abc")
            assert restored is not None
            assert restored.analyzer == "beats"
            assert restored.data["tempo"] == 120.0

    def test_analysis_result_kebab_case(self):
        """Data keys convert snake_case <-> kebab-case across a round-trip."""
        result = AnalysisResult(
            analyzer="test",
            input_hash="hash",
            data={
                "scene_count": 5,
                "beat_times": [1, 2, 3],
            },
        )

        text = result.to_string()
        # Kebab-case inside the serialized sexp...
        assert "scene-count" in text
        assert "beat-times" in text

        # ...and back to snake_case after parsing.
        restored = AnalysisResult.from_string(text)
        assert "scene_count" in restored.data
        assert "beat_times" in restored.data
"""
Tests for stage compilation and scoping.

Covers the CompiledStage dataclass, stage form parsing,
variable scoping, and dependency validation.
"""

import pytest

from .parser import parse, Symbol, Keyword
from .compiler import (
    compile_recipe,
    CompileError,
    CompiledStage,
    CompilerContext,
    _topological_sort_stages,
)


class TestStageCompilation:
    """Compilation of (stage ...) forms."""

    def test_parse_stage_form_basic(self):
        """A stage with a name and outputs compiles into one CompiledStage."""
        src = '''
        (recipe "test-stage"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats)))
            (-> audio (segment :times beats) (sequence))))
        '''
        rc = compile_recipe(parse(src))

        assert len(rc.stages) == 1
        assert rc.stages[0].name == "analyze"
        assert "beats" in rc.stages[0].outputs
        assert len(rc.stages[0].node_ids) > 0

    def test_parse_stage_with_requires(self):
        """:requires and :inputs are recorded on the compiled stage."""
        src = '''
        (recipe "test-requires"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [segments]
            (def segments (-> audio (segment :times beats)))
            (-> segments (sequence))))
        '''
        rc = compile_recipe(parse(src))

        assert len(rc.stages) == 2
        process_stage = next(s for s in rc.stages if s.name == "process")
        assert process_stage.requires == ["analyze"]
        assert "beats" in process_stage.inputs
        assert "segments" in process_stage.outputs

    def test_stage_outputs_recorded(self):
        """Every declared output gets a binding on the CompiledStage."""
        src = '''
        (recipe "test-outputs"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo)))
            (-> audio (segment :times beats) (sequence))))
        '''
        rc = compile_recipe(parse(src))

        stage = rc.stages[0]
        assert "beats" in stage.outputs
        assert "tempo" in stage.outputs
        assert "beats" in stage.output_bindings
        assert "tempo" in stage.output_bindings

    def test_stage_order_topological(self):
        """stage_order respects the :requires edges."""
        src = '''
        (recipe "test-order"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        rc = compile_recipe(parse(src))

        # analyze must be scheduled before output.
        assert rc.stage_order.index("analyze") < rc.stage_order.index("output")


class TestStageValidation:
    """Validation of stage dependencies and inputs."""

    def test_stage_requires_validation(self):
        """Requiring a stage that does not exist is a CompileError."""
        src = '''
        (recipe "test-bad-require"
          (def audio (source :path "test.mp3"))

          (stage :process
            :requires [:nonexistent]
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))

    def test_stage_inputs_validation(self):
        """An input must be produced by some required stage."""
        src = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [nonexistent]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(src))

    def test_undeclared_output_error(self):
        """A declared output missing from the stage body is a CompileError."""
        src = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(src))

    def test_forward_reference_detection(self):
        """Requiring a stage defined later in the recipe is rejected."""
        # Forward references are not allowed - stages must be defined
        # before they can be required.
        src = '''
        (recipe "test-forward"
          (def audio (source :path "test.mp3"))

          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio))

          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))


class TestStageScoping:
    """Variable scoping between stages."""

    def test_pre_stage_bindings_accessible(self):
        """Sources defined before any stage are visible to every stage."""
        src = '''
        (recipe "test-pre-stage"
          (def audio (source :path "test.mp3"))
          (def video (source :path "test.mp4"))

          (stage :analyze-audio
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :analyze-video
            :outputs [scenes]
            (def scenes (-> video (analyze scenes)))
            (-> video (segment :times scenes) (sequence))))
        '''
        # Must compile cleanly - audio and video are accessible to both stages.
        rc = compile_recipe(parse(src))
        assert len(rc.stages) == 2

    def test_stage_bindings_flow_through_requires(self):
        """Bindings reach dependent stages via :requires plus :inputs."""
        src = '''
        (recipe "test-binding-flow"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [result]
            (def result (-> audio (segment :times beats)))
            (-> result (sequence))))
        '''
        # Must compile cleanly - beats flows from analyze into process.
        rc = compile_recipe(parse(src))
        assert len(rc.stages) == 2


class TestTopologicalSort:
    """_topological_sort_stages ordering behavior."""

    def test_empty_stages(self):
        """An empty mapping sorts to an empty list."""
        assert _topological_sort_stages({}) == []

    def test_single_stage(self):
        """A lone stage sorts to a one-element list."""
        stages = {
            "a": CompiledStage(
                name="a",
                requires=[],
                inputs=[],
                outputs=["out"],
                node_ids=["n1"],
                output_bindings={"out": "n1"},
            )
        }
        assert _topological_sort_stages(stages) == ["a"]

    def test_linear_chain(self):
        """A linear requires-chain keeps its order."""
        stages = {
            "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                               node_ids=["n1"], output_bindings={"x": "n1"}),
            "b": CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"],
                               node_ids=["n2"], output_bindings={"y": "n2"}),
            "c": CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"],
                               node_ids=["n3"], output_bindings={"z": "n3"}),
        }
        ordered = _topological_sort_stages(stages)
        assert ordered.index("a") < ordered.index("b") < ordered.index("c")

    def test_parallel_stages_same_level(self):
        """Independent stages may come back in either order."""
        stages = {
            "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                               node_ids=["n1"], output_bindings={"x": "n1"}),
            "b": CompiledStage(name="b", requires=[], inputs=[], outputs=["y"],
                               node_ids=["n2"], output_bindings={"y": "n2"}),
        }
        ordered = _topological_sort_stages(stages)
        # Both a and b must appear; relative order is unconstrained.
        assert set(ordered) == {"a", "b"}

    def test_diamond_dependency(self):
        """Diamond pattern: A -> B, A -> C, B+C -> D sorts consistently."""
        stages = {
            "a": CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                               node_ids=["n1"], output_bindings={"x": "n1"}),
            "b": CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"],
                               node_ids=["n2"], output_bindings={"y": "n2"}),
            "c": CompiledStage(name="c", requires=["a"], inputs=["x"], outputs=["z"],
                               node_ids=["n3"], output_bindings={"z": "n3"}),
            "d": CompiledStage(name="d", requires=["b", "c"], inputs=["y", "z"], outputs=["out"],
                               node_ids=["n4"], output_bindings={"out": "n4"}),
        }
        ordered = _topological_sort_stages(stages)
        # a first, d last, and both branches precede d.
        assert ordered[0] == "a"
        assert ordered[-1] == "d"
        assert ordered.index("b") < ordered.index("d")
        assert ordered.index("c") < ordered.index("d")
+""" + +import pytest +import tempfile +from pathlib import Path + +from .parser import parse, serialize +from .compiler import compile_recipe, CompileError +from .planner import ExecutionPlanSexp, StagePlan +from .stage_cache import StageCache, StageCacheEntry, StageOutput +from .scheduler import StagePlanScheduler, StagePlanResult + + +class TestSimpleTwoStageRecipe: + """Test basic two-stage recipe flow.""" + + def test_compile_two_stage_recipe(self): + """Compile a simple two-stage recipe.""" + recipe = ''' + (recipe "test-two-stages" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 2 + assert compiled.stage_order == ["analyze", "output"] + + analyze_stage = compiled.stages[0] + assert analyze_stage.name == "analyze" + assert "beats" in analyze_stage.outputs + + output_stage = compiled.stages[1] + assert output_stage.name == "output" + assert output_stage.requires == ["analyze"] + assert "beats" in output_stage.inputs + + +class TestParallelAnalysisStages: + """Test parallel analysis stages.""" + + def test_compile_parallel_stages(self): + """Two analysis stages can run in parallel.""" + recipe = ''' + (recipe "test-parallel" + (def audio-a (source :path "a.mp3")) + (def audio-b (source :path "b.mp3")) + + (stage :analyze-a + :outputs [beats-a] + (def beats-a (-> audio-a (analyze beats)))) + + (stage :analyze-b + :outputs [beats-b] + (def beats-b (-> audio-b (analyze beats)))) + + (stage :combine + :requires [:analyze-a :analyze-b] + :inputs [beats-a beats-b] + (-> audio-a (segment :times beats-a) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 3 + + # analyze-a and analyze-b should both be at level 0 (parallel) + analyze_a = next(s for s in 
compiled.stages if s.name == "analyze-a") + analyze_b = next(s for s in compiled.stages if s.name == "analyze-b") + combine = next(s for s in compiled.stages if s.name == "combine") + + assert analyze_a.requires == [] + assert analyze_b.requires == [] + assert set(combine.requires) == {"analyze-a", "analyze-b"} + + +class TestDiamondDependency: + """Test diamond dependency pattern: A -> B, A -> C, B+C -> D.""" + + def test_compile_diamond_pattern(self): + """Diamond pattern compiles correctly.""" + recipe = ''' + (recipe "test-diamond" + (def audio (source :path "test.mp3")) + + (stage :source-stage + :outputs [audio-ref] + (def audio-ref audio)) + + (stage :branch-b + :requires [:source-stage] + :inputs [audio-ref] + :outputs [result-b] + (def result-b (-> audio-ref (effect gain :amount 0.5)))) + + (stage :branch-c + :requires [:source-stage] + :inputs [audio-ref] + :outputs [result-c] + (def result-c (-> audio-ref (effect gain :amount 0.8)))) + + (stage :merge + :requires [:branch-b :branch-c] + :inputs [result-b result-c] + (-> result-b (blend result-c :mode "mix")))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 4 + + # Check dependencies + source = next(s for s in compiled.stages if s.name == "source-stage") + branch_b = next(s for s in compiled.stages if s.name == "branch-b") + branch_c = next(s for s in compiled.stages if s.name == "branch-c") + merge = next(s for s in compiled.stages if s.name == "merge") + + assert source.requires == [] + assert branch_b.requires == ["source-stage"] + assert branch_c.requires == ["source-stage"] + assert set(merge.requires) == {"branch-b", "branch-c"} + + # source-stage should come first in order + assert compiled.stage_order.index("source-stage") < compiled.stage_order.index("branch-b") + assert compiled.stage_order.index("source-stage") < compiled.stage_order.index("branch-c") + # merge should come last + assert compiled.stage_order.index("branch-b") < 
compiled.stage_order.index("merge") + assert compiled.stage_order.index("branch-c") < compiled.stage_order.index("merge") + + +class TestStageReuseOnRerun: + """Test that re-running recipe uses cached stages.""" + + def test_stage_reuse(self): + """Re-running recipe uses cached stages.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Simulate first run by caching a stage + entry = StageCacheEntry( + stage_name="analyze", + cache_id="fixed_cache_id", + outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")}, + ) + stage_cache.save_stage(entry) + + # Verify cache exists + assert stage_cache.has_stage("fixed_cache_id") + + # Second run should find cache + loaded = stage_cache.load_stage("fixed_cache_id") + assert loaded is not None + assert loaded.stage_name == "analyze" + + +class TestExplicitDataFlowEndToEnd: + """Test that analysis results flow through :inputs/:outputs.""" + + def test_data_flow_declaration(self): + """Explicit data flow is declared correctly.""" + recipe = ''' + (recipe "test-data-flow" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats tempo] + (def beats (-> audio (analyze beats))) + (def tempo (-> audio (analyze tempo)))) + + (stage :process + :requires [:analyze] + :inputs [beats tempo] + :outputs [result] + (def result (-> audio (segment :times beats) (effect speed :factor tempo))) + (-> result (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + process = next(s for s in compiled.stages if s.name == "process") + + # Analyze outputs + assert set(analyze.outputs) == {"beats", "tempo"} + assert "beats" in analyze.output_bindings + assert "tempo" in analyze.output_bindings + + # Process inputs + assert set(process.inputs) == {"beats", "tempo"} + assert process.requires == ["analyze"] + + +class TestRecipeFixtures: + """Test using recipe fixtures.""" + + @pytest.fixture + def 
test_recipe_two_stages(self): + return ''' + (recipe "test-two-stages" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + + @pytest.fixture + def test_recipe_parallel_stages(self): + return ''' + (recipe "test-parallel" + (def audio-a (source :path "a.mp3")) + (def audio-b (source :path "b.mp3")) + + (stage :analyze-a + :outputs [beats-a] + (def beats-a (-> audio-a (analyze beats)))) + + (stage :analyze-b + :outputs [beats-b] + (def beats-b (-> audio-b (analyze beats)))) + + (stage :combine + :requires [:analyze-a :analyze-b] + :inputs [beats-a beats-b] + (-> audio-a (blend audio-b :mode "mix")))) + ''' + + def test_two_stages_fixture(self, test_recipe_two_stages): + """Two-stage recipe fixture compiles.""" + compiled = compile_recipe(parse(test_recipe_two_stages)) + assert len(compiled.stages) == 2 + + def test_parallel_stages_fixture(self, test_recipe_parallel_stages): + """Parallel stages recipe fixture compiles.""" + compiled = compile_recipe(parse(test_recipe_parallel_stages)) + assert len(compiled.stages) == 3 + + +class TestStageValidationErrors: + """Test error handling for invalid stage recipes.""" + + def test_missing_output_declaration(self): + """Error when stage output not declared.""" + recipe = ''' + (recipe "test-missing-output" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats nonexistent] + (def beats (-> audio (analyze beats))))) + ''' + with pytest.raises(CompileError, match="not defined in the stage body"): + compile_recipe(parse(recipe)) + + def test_input_without_requires(self): + """Error when using input not from required stage.""" + recipe = ''' + (recipe "test-bad-input" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :process + :requires [] + 
:inputs [beats] + (def result audio))) + ''' + with pytest.raises(CompileError, match="not an output of any required stage"): + compile_recipe(parse(recipe)) + + def test_forward_reference(self): + """Error when requiring stage not yet defined (forward reference).""" + recipe = ''' + (recipe "test-forward-ref" + (def audio (source :path "test.mp3")) + + (stage :a + :requires [:b] + :outputs [out-a] + (def out-a audio) + audio) + + (stage :b + :outputs [out-b] + (def out-b audio) + audio)) + ''' + with pytest.raises(CompileError, match="requires undefined stage"): + compile_recipe(parse(recipe)) + + +class TestBeatSyncDemoRecipe: + """Test the beat-sync demo recipe from examples.""" + + BEAT_SYNC_RECIPE = ''' + ;; Simple staged recipe demo + (recipe "beat-sync-demo" + :version "1.0" + :description "Demo of staged beat-sync workflow" + + ;; Pre-stage definitions (available to all stages) + (def audio (source :path "input.mp3")) + + ;; Stage 1: Analysis (expensive, cached) + (stage :analyze + :outputs [beats tempo] + (def beats (-> audio (analyze beats))) + (def tempo (-> audio (analyze tempo)))) + + ;; Stage 2: Processing (uses analysis results) + (stage :process + :requires [:analyze] + :inputs [beats] + :outputs [segments] + (def segments (-> audio (segment :times beats))) + (-> segments (sequence)))) + ''' + + def test_compile_beat_sync_recipe(self): + """Beat-sync demo recipe compiles correctly.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + assert compiled.name == "beat-sync-demo" + assert compiled.version == "1.0" + assert compiled.description == "Demo of staged beat-sync workflow" + + def test_beat_sync_stage_count(self): + """Beat-sync has 2 stages in correct order.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + assert len(compiled.stages) == 2 + assert compiled.stage_order == ["analyze", "process"] + + def test_beat_sync_analyze_stage(self): + """Analyze stage has correct outputs.""" + compiled = 
compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + assert analyze.requires == [] + assert analyze.inputs == [] + assert set(analyze.outputs) == {"beats", "tempo"} + assert "beats" in analyze.output_bindings + assert "tempo" in analyze.output_bindings + + def test_beat_sync_process_stage(self): + """Process stage has correct dependencies and inputs.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + process = next(s for s in compiled.stages if s.name == "process") + assert process.requires == ["analyze"] + assert "beats" in process.inputs + assert "segments" in process.outputs + + def test_beat_sync_node_count(self): + """Beat-sync generates expected number of nodes.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + # 1 SOURCE + 2 ANALYZE + 1 SEGMENT + 1 SEQUENCE = 5 nodes + assert len(compiled.nodes) == 5 + + def test_beat_sync_node_types(self): + """Beat-sync generates correct node types.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + node_types = [n["type"] for n in compiled.nodes] + assert node_types.count("SOURCE") == 1 + assert node_types.count("ANALYZE") == 2 + assert node_types.count("SEGMENT") == 1 + assert node_types.count("SEQUENCE") == 1 + + def test_beat_sync_output_is_sequence(self): + """Beat-sync output node is the sequence node.""" + compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE)) + + output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id) + assert output_node["type"] == "SEQUENCE" + + +class TestAsciiArtStagedRecipe: + """Test the ASCII art staged recipe.""" + + ASCII_ART_STAGED_RECIPE = ''' + ;; ASCII art effect with staged execution + (recipe "ascii_art_staged" + :version "1.0" + :description "ASCII art effect with staged execution" + :encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30) + + ;; Registry + (effect ascii_art :path "sexp_effects/effects/ascii_art.sexp") + 
(analyzer energy :path "../artdag-analyzers/energy/analyzer.py") + + ;; Pre-stage definitions + (def color_mode "color") + (def video (source :path "monday.webm")) + (def audio (source :path "dizzy.mp3")) + + ;; Stage 1: Analysis + (stage :analyze + :outputs [energy-data] + (def audio-clip (-> audio (segment :start 60 :duration 10))) + (def energy-data (-> audio-clip (analyze energy)))) + + ;; Stage 2: Process + (stage :process + :requires [:analyze] + :inputs [energy-data] + :outputs [result audio-clip] + (def clip (-> video (segment :start 0 :duration 10))) + (def audio-clip (-> audio (segment :start 60 :duration 10))) + (def result (-> clip + (effect ascii_art + :char_size (bind energy-data values :range [2 32]) + :color_mode color_mode)))) + + ;; Stage 3: Output + (stage :output + :requires [:process] + :inputs [result audio-clip] + (mux result audio-clip))) + ''' + + def test_compile_ascii_art_staged(self): + """ASCII art staged recipe compiles correctly.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + assert compiled.name == "ascii_art_staged" + assert compiled.version == "1.0" + + def test_ascii_art_stage_count(self): + """ASCII art has 3 stages in correct order.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + assert len(compiled.stages) == 3 + assert compiled.stage_order == ["analyze", "process", "output"] + + def test_ascii_art_analyze_stage(self): + """Analyze stage outputs energy-data.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + analyze = next(s for s in compiled.stages if s.name == "analyze") + assert analyze.requires == [] + assert analyze.inputs == [] + assert "energy-data" in analyze.outputs + + def test_ascii_art_process_stage(self): + """Process stage requires analyze and outputs result.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + process = next(s for s in compiled.stages if s.name == "process") + assert process.requires == ["analyze"] + assert 
"energy-data" in process.inputs + assert "result" in process.outputs + assert "audio-clip" in process.outputs + + def test_ascii_art_output_stage(self): + """Output stage requires process and has mux.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + output = next(s for s in compiled.stages if s.name == "output") + assert output.requires == ["process"] + assert "result" in output.inputs + assert "audio-clip" in output.inputs + + def test_ascii_art_node_count(self): + """ASCII art generates expected nodes.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + # 2 SOURCE + 2 SEGMENT + 1 ANALYZE + 1 EFFECT + 1 MUX = 7+ nodes + assert len(compiled.nodes) >= 7 + + def test_ascii_art_has_mux_output(self): + """ASCII art output is MUX node.""" + compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE)) + + output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id) + assert output_node["type"] == "MUX" + + +class TestMixedStagedAndNonStagedRecipes: + """Test that non-staged recipes still work.""" + + def test_recipe_without_stages(self): + """Non-staged recipe compiles normally.""" + recipe = ''' + (recipe "no-stages" + (-> (source :path "test.mp3") + (effect gain :amount 0.5))) + ''' + compiled = compile_recipe(parse(recipe)) + + assert compiled.stages == [] + assert compiled.stage_order == [] + # Should still have nodes + assert len(compiled.nodes) > 0 + + def test_mixed_pre_stage_and_stages(self): + """Pre-stage definitions work with stages.""" + recipe = ''' + (recipe "mixed" + ;; Pre-stage definitions + (def audio (source :path "test.mp3")) + (def volume 0.8) + + ;; Stage using pre-stage definitions, ending with output expression + (stage :process + :outputs [result] + (def result (-> audio (effect gain :amount volume))) + result)) + ''' + compiled = compile_recipe(parse(recipe)) + + assert len(compiled.stages) == 1 + # audio and volume should be accessible in stage + process = compiled.stages[0] + 
assert process.name == "process" + assert "result" in process.outputs + + +class TestEffectParamsBlock: + """Test :params block parsing in effect definitions.""" + + def test_parse_effect_with_params_block(self): + """Parse effect with new :params syntax.""" + from .effect_loader import load_sexp_effect + + effect_code = ''' + (define-effect test_effect + :params ( + (size :type int :default 10 :range [1 100] :desc "Size parameter") + (color :type string :default "red" :desc "Color parameter") + (enabled :type int :default 1 :range [0 1] :desc "Enable flag") + ) + frame) + ''' + name, process_fn, defaults, param_defs = load_sexp_effect(effect_code) + + assert name == "test_effect" + assert len(param_defs) == 3 + assert defaults["size"] == 10 + assert defaults["color"] == "red" + assert defaults["enabled"] == 1 + + # Check ParamDef objects + size_param = param_defs[0] + assert size_param.name == "size" + assert size_param.param_type == "int" + assert size_param.default == 10 + assert size_param.range_min == 1.0 + assert size_param.range_max == 100.0 + assert size_param.description == "Size parameter" + + color_param = param_defs[1] + assert color_param.name == "color" + assert color_param.param_type == "string" + assert color_param.default == "red" + + def test_parse_effect_with_choices(self): + """Parse effect with choices in :params.""" + from .effect_loader import load_sexp_effect + + effect_code = ''' + (define-effect mode_effect + :params ( + (mode :type string :default "fast" + :choices [fast slow medium] + :desc "Processing mode") + ) + frame) + ''' + name, _, defaults, param_defs = load_sexp_effect(effect_code) + + assert name == "mode_effect" + assert defaults["mode"] == "fast" + + mode_param = param_defs[0] + assert mode_param.choices == ["fast", "slow", "medium"] + + def test_legacy_effect_syntax_rejected(self): + """Legacy effect syntax should be rejected.""" + from .effect_loader import load_sexp_effect + import pytest + + effect_code = ''' + 
(define-effect legacy_effect + ((width 100) + (height 200) + (name "default")) + frame) + ''' + with pytest.raises(ValueError) as exc_info: + load_sexp_effect(effect_code) + + assert "Legacy parameter syntax" in str(exc_info.value) + assert ":params" in str(exc_info.value) + + def test_effect_params_introspection(self): + """Test that effect params are available for introspection.""" + from .effect_loader import load_sexp_effect_file + from pathlib import Path + + # Create a temp effect file + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f: + f.write(''' + (define-effect introspect_test + :params ( + (alpha :type float :default 0.5 :range [0 1] :desc "Alpha value") + ) + frame) + ''') + temp_path = Path(f.name) + + try: + name, _, defaults, param_defs = load_sexp_effect_file(temp_path) + assert name == "introspect_test" + assert len(param_defs) == 1 + assert param_defs[0].name == "alpha" + assert param_defs[0].param_type == "float" + finally: + temp_path.unlink() + + +class TestConstructParamsBlock: + """Test :params block parsing in construct definitions.""" + + def test_parse_construct_params_helper(self): + """Test the _parse_construct_params helper function.""" + from .planner import _parse_construct_params + from .parser import Symbol, Keyword + + params_list = [ + [Symbol("duration"), Keyword("type"), Symbol("float"), + Keyword("default"), 5.0, Keyword("desc"), "Duration in seconds"], + [Symbol("count"), Keyword("type"), Symbol("int"), + Keyword("default"), 10], + ] + + param_names, param_defaults = _parse_construct_params(params_list) + + assert param_names == ["duration", "count"] + assert param_defaults["duration"] == 5.0 + assert param_defaults["count"] == 10 + + def test_construct_params_with_no_defaults(self): + """Test construct params where some have no default.""" + from .planner import _parse_construct_params + from .parser import Symbol, Keyword + + params_list = [ + [Symbol("required_param"), 
Keyword("type"), Symbol("string")], + [Symbol("optional_param"), Keyword("type"), Symbol("int"), + Keyword("default"), 42], + ] + + param_names, param_defaults = _parse_construct_params(params_list) + + assert param_names == ["required_param", "optional_param"] + assert param_defaults["required_param"] is None + assert param_defaults["optional_param"] == 42 + + +class TestParameterValidation: + """Test that unknown parameters are rejected.""" + + def test_effect_rejects_unknown_params(self): + """Effects should reject unknown parameters.""" + from .effect_loader import load_sexp_effect + import numpy as np + import pytest + + effect_code = ''' + (define-effect test_effect + :params ( + (brightness :type int :default 0 :desc "Brightness") + ) + frame) + ''' + name, process_frame, defaults, _ = load_sexp_effect(effect_code) + + # Create a test frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + state = {} + + # Valid param should work + result, _ = process_frame(frame, {"brightness": 10}, state) + assert isinstance(result, np.ndarray) + + # Unknown param should raise + with pytest.raises(ValueError) as exc_info: + process_frame(frame, {"unknown_param": 42}, state) + + assert "Unknown parameter 'unknown_param'" in str(exc_info.value) + assert "brightness" in str(exc_info.value) + + def test_effect_no_params_rejects_all(self): + """Effects with no params should reject any parameter.""" + from .effect_loader import load_sexp_effect + import numpy as np + import pytest + + effect_code = ''' + (define-effect no_params_effect + :params () + frame) + ''' + name, process_frame, defaults, _ = load_sexp_effect(effect_code) + + # Create a test frame + frame = np.zeros((100, 100, 3), dtype=np.uint8) + state = {} + + # Empty params should work + result, _ = process_frame(frame, {}, state) + assert isinstance(result, np.ndarray) + + # Any param should raise + with pytest.raises(ValueError) as exc_info: + process_frame(frame, {"any_param": 42}, state) + + assert "Unknown 
parameter 'any_param'" in str(exc_info.value) + assert "(none)" in str(exc_info.value) diff --git a/artdag/sexp/test_stage_planner.py b/artdag/sexp/test_stage_planner.py new file mode 100644 index 0000000..51d6d33 --- /dev/null +++ b/artdag/sexp/test_stage_planner.py @@ -0,0 +1,228 @@ +""" +Tests for stage-aware planning. + +Tests stage topological sorting, level computation, cache ID computation, +and plan metadata generation. +""" + +import pytest +from pathlib import Path + +from .parser import parse +from .compiler import compile_recipe, CompiledStage +from .planner import ( + create_plan, + StagePlan, + _compute_stage_levels, + _compute_stage_cache_id, +) + + +class TestStagePlanning: + """Test stage-aware plan creation.""" + + def test_stage_topological_sort_in_plan(self): + """Stages sorted by dependencies in plan.""" + recipe = ''' + (recipe "test-sort" + (def audio (source :path "test.mp3")) + + (stage :analyze + :outputs [beats] + (def beats (-> audio (analyze beats)))) + + (stage :output + :requires [:analyze] + :inputs [beats] + (-> audio (segment :times beats) (sequence)))) + ''' + compiled = compile_recipe(parse(recipe)) + # Note: create_plan needs recipe_dir for analysis, we'll test the ordering differently + assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output") + + def test_stage_level_computation(self): + """Independent stages get same level.""" + stages = [ + CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + CompiledStage(name="b", requires=[], inputs=[], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + CompiledStage(name="c", requires=["a", "b"], inputs=["x", "y"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + ] + levels = _compute_stage_levels(stages) + + assert levels["a"] == 0 + assert levels["b"] == 0 + assert levels["c"] == 1 # Depends on a and b + + def test_stage_level_chain(self): + """Chain stages get 
increasing levels.""" + stages = [ + CompiledStage(name="a", requires=[], inputs=[], outputs=["x"], + node_ids=["n1"], output_bindings={"x": "n1"}), + CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"], + node_ids=["n2"], output_bindings={"y": "n2"}), + CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"], + node_ids=["n3"], output_bindings={"z": "n3"}), + ] + levels = _compute_stage_levels(stages) + + assert levels["a"] == 0 + assert levels["b"] == 1 + assert levels["c"] == 2 + + def test_stage_cache_id_deterministic(self): + """Same stage = same cache ID.""" + stage = CompiledStage( + name="analyze", + requires=[], + inputs=[], + outputs=["beats"], + node_ids=["abc123"], + output_bindings={"beats": "abc123"}, + ) + + cache_id_1 = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + cache_id_2 = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + + assert cache_id_1 == cache_id_2 + + def test_stage_cache_id_includes_requires(self): + """Cache ID changes when required stage cache ID changes.""" + stage = CompiledStage( + name="process", + requires=["analyze"], + inputs=["beats"], + outputs=["result"], + node_ids=["def456"], + output_bindings={"result": "def456"}, + ) + + cache_id_1 = _compute_stage_cache_id( + stage, + stage_cache_ids={"analyze": "req_cache_a"}, + node_cache_ids={"def456": "node_def"}, + cluster_key=None, + ) + cache_id_2 = _compute_stage_cache_id( + stage, + stage_cache_ids={"analyze": "req_cache_b"}, + node_cache_ids={"def456": "node_def"}, + cluster_key=None, + ) + + # Different required stage cache IDs should produce different cache IDs + assert cache_id_1 != cache_id_2 + + def test_stage_cache_id_cluster_key(self): + """Cache ID changes with cluster key.""" + stage = CompiledStage( + name="analyze", + requires=[], + inputs=[], + outputs=["beats"], + node_ids=["abc123"], + 
output_bindings={"beats": "abc123"}, + ) + + cache_id_no_key = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key=None, + ) + cache_id_with_key = _compute_stage_cache_id( + stage, + stage_cache_ids={}, + node_cache_ids={"abc123": "nodeabc"}, + cluster_key="cluster123", + ) + + # Cluster key should change the cache ID + assert cache_id_no_key != cache_id_with_key + + +class TestStagePlanMetadata: + """Test stage metadata in execution plans.""" + + def test_plan_without_stages(self): + """Plan without stages has empty stage fields.""" + recipe = ''' + (recipe "no-stages" + (-> (source :path "test.mp3") (effect gain :amount 0.5))) + ''' + compiled = compile_recipe(parse(recipe)) + assert compiled.stages == [] + assert compiled.stage_order == [] + + +class TestStagePlanDataclass: + """Test StagePlan dataclass.""" + + def test_stage_plan_creation(self): + """StagePlan can be created with all fields.""" + from .planner import PlanStep + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={"analyzer": "beats"}, + inputs=["input1"], + cache_id="cache123", + level=0, + stage="analyze", + stage_cache_id="stage_cache_123", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="stage_cache_123", + steps=[step], + requires=[], + output_bindings={"beats": "cache123"}, + level=0, + ) + + assert stage_plan.stage_name == "analyze" + assert stage_plan.cache_id == "stage_cache_123" + assert len(stage_plan.steps) == 1 + assert stage_plan.level == 0 + + +class TestExplicitDataRouting: + """Test that plan includes explicit data routing.""" + + def test_plan_step_includes_stage_info(self): + """PlanStep includes stage and stage_cache_id.""" + from .planner import PlanStep + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache123", + level=0, + stage="analyze", + stage_cache_id="stage_cache_abc", + ) + + sexp = step.to_sexp() + # Convert to string 
to check for stage info + from .parser import serialize + sexp_str = serialize(sexp) + + assert "stage" in sexp_str + assert "analyze" in sexp_str + assert "stage-cache-id" in sexp_str diff --git a/artdag/sexp/test_stage_scheduler.py b/artdag/sexp/test_stage_scheduler.py new file mode 100644 index 0000000..c7bab64 --- /dev/null +++ b/artdag/sexp/test_stage_scheduler.py @@ -0,0 +1,323 @@ +""" +Tests for stage-aware scheduler. + +Tests stage cache hit/miss, stage execution ordering, +and parallel stage support. +""" + +import pytest +import tempfile +from unittest.mock import Mock, MagicMock, patch + +from .scheduler import ( + StagePlanScheduler, + StageResult, + StagePlanResult, + create_stage_scheduler, + schedule_staged_plan, +) +from .planner import ExecutionPlanSexp, PlanStep, StagePlan +from .stage_cache import StageCache, StageCacheEntry, StageOutput + + +class TestStagePlanScheduler: + """Test stage-aware scheduling.""" + + def test_plan_without_stages_uses_regular_scheduling(self): + """Plans without stages fall back to regular scheduling.""" + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[], + output_step_id="output", + stage_plans=[], # No stages + ) + + scheduler = StagePlanScheduler() + # This will use PlanScheduler internally + # Without Celery, it just returns completed status + result = scheduler.schedule(plan) + + assert isinstance(result, StagePlanResult) + + def test_stage_cache_hit_skips_execution(self): + """Cached stage not re-executed.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Pre-populate cache + entry = StageCacheEntry( + stage_name="analyze", + cache_id="stage_cache_123", + outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")}, + ) + stage_cache.save_stage(entry) + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="step_cache", + level=0, + stage="analyze", + 
stage_cache_id="stage_cache_123", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="stage_cache_123", + steps=[step], + requires=[], + output_bindings={"beats": "beats_out"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step], + output_step_id="step1", + stage_plans=[stage_plan], + stage_order=["analyze"], + stage_levels={"analyze": 0}, + stage_cache_ids={"analyze": "stage_cache_123"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + assert result.stages_cached == 1 + assert result.stages_completed == 0 + + def test_stage_inputs_loaded_from_cache(self): + """Stage receives inputs from required stage cache.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + # Pre-populate upstream stage cache + upstream_entry = StageCacheEntry( + stage_name="analyze", + cache_id="upstream_cache", + outputs={"beats": StageOutput(cache_id="beats_data", output_type="analysis")}, + ) + stage_cache.save_stage(upstream_entry) + + # Steps for stages + upstream_step = PlanStep( + step_id="analyze_step", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="analyze_cache", + level=0, + stage="analyze", + stage_cache_id="upstream_cache", + ) + + downstream_step = PlanStep( + step_id="process_step", + node_type="SEGMENT", + config={}, + inputs=["analyze_step"], + cache_id="process_cache", + level=1, + stage="process", + stage_cache_id="downstream_cache", + ) + + upstream_plan = StagePlan( + stage_name="analyze", + cache_id="upstream_cache", + steps=[upstream_step], + requires=[], + output_bindings={"beats": "beats_data"}, + level=0, + ) + + downstream_plan = StagePlan( + stage_name="process", + cache_id="downstream_cache", + steps=[downstream_step], + requires=["analyze"], + output_bindings={"result": "process_cache"}, + level=1, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + 
recipe_id="test_recipe", + recipe_hash="abc123", + steps=[upstream_step, downstream_step], + output_step_id="process_step", + stage_plans=[upstream_plan, downstream_plan], + stage_order=["analyze", "process"], + stage_levels={"analyze": 0, "process": 1}, + stage_cache_ids={"analyze": "upstream_cache", "process": "downstream_cache"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + # Upstream should be cached, downstream executed + assert result.stages_cached == 1 + assert "analyze" in result.stage_results + assert result.stage_results["analyze"].status == "cached" + + def test_parallel_stages_same_level(self): + """Stages at same level can run in parallel.""" + step_a = PlanStep( + step_id="step_a", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache_a", + level=0, + stage="analyze-a", + stage_cache_id="stage_a", + ) + + step_b = PlanStep( + step_id="step_b", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="cache_b", + level=0, + stage="analyze-b", + stage_cache_id="stage_b", + ) + + stage_a = StagePlan( + stage_name="analyze-a", + cache_id="stage_a", + steps=[step_a], + requires=[], + output_bindings={"beats-a": "cache_a"}, + level=0, + ) + + stage_b = StagePlan( + stage_name="analyze-b", + cache_id="stage_b", + steps=[step_b], + requires=[], + output_bindings={"beats-b": "cache_b"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step_a, step_b], + output_step_id="step_b", + stage_plans=[stage_a, stage_b], + stage_order=["analyze-a", "analyze-b"], + stage_levels={"analyze-a": 0, "analyze-b": 0}, + stage_cache_ids={"analyze-a": "stage_a", "analyze-b": "stage_b"}, + ) + + scheduler = StagePlanScheduler() + # Group stages by level + stages_by_level = scheduler._group_stages_by_level(plan.stage_plans) + + # Both stages should be at level 0 + assert len(stages_by_level[0]) == 2 + + def 
test_stage_outputs_cached_after_execution(self): + """Stage outputs written to cache after completion.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + + step = PlanStep( + step_id="step1", + node_type="ANALYZE", + config={}, + inputs=[], + cache_id="step_cache", + level=0, + stage="analyze", + stage_cache_id="new_stage_cache", + ) + + stage_plan = StagePlan( + stage_name="analyze", + cache_id="new_stage_cache", + steps=[step], + requires=[], + output_bindings={"beats": "step_cache"}, + level=0, + ) + + plan = ExecutionPlanSexp( + plan_id="test_plan", + recipe_id="test_recipe", + recipe_hash="abc123", + steps=[step], + output_step_id="step1", + stage_plans=[stage_plan], + stage_order=["analyze"], + stage_levels={"analyze": 0}, + stage_cache_ids={"analyze": "new_stage_cache"}, + ) + + scheduler = StagePlanScheduler(stage_cache=stage_cache) + result = scheduler.schedule(plan) + + # Stage should now be cached + assert stage_cache.has_stage("new_stage_cache") + + +class TestStageResult: + """Test StageResult dataclass.""" + + def test_stage_result_creation(self): + """StageResult can be created with all fields.""" + result = StageResult( + stage_name="test", + cache_id="cache123", + status="completed", + step_results={}, + outputs={"out": "out_cache"}, + ) + + assert result.stage_name == "test" + assert result.status == "completed" + assert result.outputs["out"] == "out_cache" + + +class TestStagePlanResult: + """Test StagePlanResult dataclass.""" + + def test_stage_plan_result_creation(self): + """StagePlanResult can be created with all fields.""" + result = StagePlanResult( + plan_id="plan123", + status="completed", + stages_completed=2, + stages_cached=1, + stages_failed=0, + ) + + assert result.plan_id == "plan123" + assert result.stages_completed == 2 + assert result.stages_cached == 1 + + +class TestSchedulerFactory: + """Test scheduler factory functions.""" + + def test_create_stage_scheduler(self): + 
"""create_stage_scheduler returns StagePlanScheduler.""" + scheduler = create_stage_scheduler() + assert isinstance(scheduler, StagePlanScheduler) + + def test_create_stage_scheduler_with_cache(self): + """create_stage_scheduler accepts stage_cache.""" + with tempfile.TemporaryDirectory() as tmpdir: + stage_cache = StageCache(tmpdir) + scheduler = create_stage_scheduler(stage_cache=stage_cache) + assert scheduler.stage_cache is stage_cache diff --git a/docs/EXECUTION_MODEL.md b/docs/EXECUTION_MODEL.md new file mode 100644 index 0000000..6779721 --- /dev/null +++ b/docs/EXECUTION_MODEL.md @@ -0,0 +1,384 @@ +# Art DAG 3-Phase Execution Model + +## Overview + +The execution model separates DAG processing into three distinct phases: + +``` +Recipe + Inputs → ANALYZE → Analysis Results + ↓ +Analysis + Recipe → PLAN → Execution Plan (with cache IDs) + ↓ +Execution Plan → EXECUTE → Cached Results +``` + +This separation enables: +1. **Incremental development** - Re-run recipes without reprocessing unchanged steps +2. **Parallel execution** - Independent steps run concurrently via Celery +3. **Deterministic caching** - Same inputs always produce same cache IDs +4. **Cost estimation** - Plan phase can estimate work before executing + +## Phase 1: Analysis + +### Purpose +Extract features from input media that inform downstream processing decisions. 
+ +### Inputs +- Recipe YAML with input references +- Input media files (by content hash) + +### Outputs +Analysis results stored as JSON, keyed by input hash: + +```python +@dataclass +class AnalysisResult: + input_hash: str + features: Dict[str, Any] + # Audio features + beats: Optional[List[float]] # Beat times in seconds + downbeats: Optional[List[float]] # Bar-start times + tempo: Optional[float] # BPM + energy: Optional[List[Tuple[float, float]]] # (time, value) envelope + spectrum: Optional[Dict[str, List[Tuple[float, float]]]] # band envelopes + # Video features + duration: float + frame_rate: float + dimensions: Tuple[int, int] + motion_tempo: Optional[float] # Estimated BPM from motion +``` + +### Implementation +```python +class Analyzer: + def analyze(self, input_hash: str, features: List[str]) -> AnalysisResult: + """Extract requested features from input.""" + + def analyze_audio(self, path: Path) -> AudioFeatures: + """Extract all audio features using librosa/essentia.""" + + def analyze_video(self, path: Path) -> VideoFeatures: + """Extract video metadata and motion analysis.""" +``` + +### Caching +Analysis results are cached by: +``` +analysis_cache_id = SHA3-256(input_hash + sorted(feature_names)) +``` + +## Phase 2: Planning + +### Purpose +Convert recipe + analysis into a complete execution plan with pre-computed cache IDs. + +### Inputs +- Recipe YAML (parsed) +- Analysis results for all inputs +- Recipe parameters (user-supplied values) + +### Outputs +An ExecutionPlan containing ordered steps, each with a pre-computed cache ID: + +```python +@dataclass +class ExecutionStep: + step_id: str # Unique identifier + node_type: str # Primitive type (SOURCE, SEQUENCE, etc.) 
+ config: Dict[str, Any] # Node configuration + input_steps: List[str] # IDs of steps this depends on + cache_id: str # Pre-computed: hash(inputs + config) + estimated_duration: float # Optional: for progress reporting + +@dataclass +class ExecutionPlan: + plan_id: str # Hash of entire plan + recipe_id: str # Source recipe + steps: List[ExecutionStep] # Topologically sorted + analysis: Dict[str, AnalysisResult] + output_step: str # Final step ID + + def compute_cache_ids(self): + """Compute all cache IDs in dependency order.""" +``` + +### Cache ID Computation + +Cache IDs are computed in topological order so each step's cache ID +incorporates its inputs' cache IDs: + +```python +def compute_cache_id(step: ExecutionStep, resolved_inputs: Dict[str, str]) -> str: + """ + Cache ID = SHA3-256( + node_type + + canonical_json(config) + + sorted([input_cache_ids]) + ) + """ + components = [ + step.node_type, + json.dumps(step.config, sort_keys=True), + *sorted(resolved_inputs[s] for s in step.input_steps) + ] + return sha3_256('|'.join(components)) +``` + +### Plan Generation + +The planner expands recipe nodes into concrete steps: + +1. **SOURCE nodes** → Direct step with input hash as cache ID +2. **ANALYZE nodes** → Step that references analysis results +3. **TRANSFORM nodes** → Step with static config +4. **TRANSFORM_DYNAMIC nodes** → Expanded to per-frame steps (or use BIND output) +5. **SEQUENCE nodes** → Tree reduction for parallel composition +6. **MAP nodes** → Expanded to N parallel steps + reduction + +### Tree Reduction for Composition + +Instead of sequential pairwise composition: +``` +A → B → C → D (3 sequential steps) +``` + +Use parallel tree reduction: +``` +A ─┬─ AB ─┬─ ABCD +B ─┘ │ +C ─┬─ CD ─┘ +D ─┘ + +Level 0: [A, B, C, D] (4 parallel) +Level 1: [AB, CD] (2 parallel) +Level 2: [ABCD] (1 final) +``` + +This reduces O(N) to O(log N) levels. + +## Phase 3: Execution + +### Purpose +Execute the plan, skipping steps with cached results. 
+ +### Inputs +- ExecutionPlan with pre-computed cache IDs +- Cache state (which IDs already exist) + +### Process + +1. **Claim Check**: For each step, atomically check if result is cached +2. **Task Dispatch**: Uncached steps dispatched to Celery workers +3. **Parallel Execution**: Independent steps run concurrently +4. **Result Storage**: Each step stores result with its cache ID +5. **Progress Tracking**: Real-time status updates + +### Hash-Based Task Claiming + +Prevents duplicate work when multiple workers process the same plan: + +```lua +-- Redis Lua script for atomic claim +local key = KEYS[1] +local data = redis.call('GET', key) +if data then + local status = cjson.decode(data) + if status.status == 'running' or + status.status == 'completed' or + status.status == 'cached' then + return 0 -- Already claimed/done + end +end +local claim_data = ARGV[1] +local ttl = tonumber(ARGV[2]) +redis.call('SETEX', key, ttl, claim_data) +return 1 -- Successfully claimed +``` + +### Celery Task Structure + +```python +@app.task(bind=True) +def execute_step(self, step_json: str, plan_id: str) -> dict: + """Execute a single step with caching.""" + step = ExecutionStep.from_json(step_json) + + # Check cache first + if cache.has(step.cache_id): + return {'status': 'cached', 'cache_id': step.cache_id} + + # Try to claim this work + if not claim_task(step.cache_id, self.request.id): + # Another worker is handling it, wait for result + return wait_for_result(step.cache_id) + + # Do the work + executor = get_executor(step.node_type) + input_paths = [cache.get(s) for s in step.input_steps] + output_path = cache.get_output_path(step.cache_id) + + result_path = executor.execute(step.config, input_paths, output_path) + cache.put(step.cache_id, result_path) + + return {'status': 'completed', 'cache_id': step.cache_id} +``` + +### Execution Orchestration + +```python +class PlanExecutor: + def execute(self, plan: ExecutionPlan) -> ExecutionResult: + """Execute plan with parallel 
Celery tasks.""" + + # Group steps by level (steps at same level can run in parallel) + levels = self.compute_dependency_levels(plan.steps) + + for level_steps in levels: + # Dispatch all steps at this level + tasks = [ + execute_step.delay(step.to_json(), plan.plan_id) + for step in level_steps + if not self.cache.has(step.cache_id) + ] + + # Wait for level completion + results = [task.get() for task in tasks] + + return self.collect_results(plan) +``` + +## Data Flow Example + +### Recipe: beat-cuts +```yaml +nodes: + - id: music + type: SOURCE + config: { input: true } + + - id: beats + type: ANALYZE + config: { feature: beats } + inputs: [music] + + - id: videos + type: SOURCE_LIST + config: { input: true } + + - id: slices + type: MAP + config: { operation: RANDOM_SLICE } + inputs: + items: videos + timing: beats + + - id: final + type: SEQUENCE + inputs: [slices] +``` + +### Phase 1: Analysis +```python +# Input: music file with hash abc123 +analysis = { + 'abc123': AnalysisResult( + beats=[0.0, 0.48, 0.96, 1.44, ...], + tempo=125.0, + duration=180.0 + ) +} +``` + +### Phase 2: Planning +```python +# Expands MAP into concrete steps +plan = ExecutionPlan( + steps=[ + # Source steps + ExecutionStep(id='music', cache_id='abc123', ...), + ExecutionStep(id='video_0', cache_id='def456', ...), + ExecutionStep(id='video_1', cache_id='ghi789', ...), + + # Slice steps (one per beat group) + ExecutionStep(id='slice_0', cache_id='hash(video_0+timing)', ...), + ExecutionStep(id='slice_1', cache_id='hash(video_1+timing)', ...), + ... 
+ + # Tree reduction for sequence + ExecutionStep(id='seq_0_1', inputs=['slice_0', 'slice_1'], ...), + ExecutionStep(id='seq_2_3', inputs=['slice_2', 'slice_3'], ...), + ExecutionStep(id='seq_final', inputs=['seq_0_1', 'seq_2_3'], ...), + ] +) +``` + +### Phase 3: Execution +``` +Level 0: [music, video_0, video_1] → all cached (SOURCE) +Level 1: [slice_0, slice_1, slice_2, slice_3] → 4 parallel tasks +Level 2: [seq_0_1, seq_2_3] → 2 parallel SEQUENCE tasks +Level 3: [seq_final] → 1 final SEQUENCE task +``` + +## File Structure + +``` +artdag/ +├── artdag/ +│ ├── analysis/ +│ │ ├── __init__.py +│ │ ├── analyzer.py # Main Analyzer class +│ │ ├── audio.py # Audio feature extraction +│ │ └── video.py # Video feature extraction +│ ├── planning/ +│ │ ├── __init__.py +│ │ ├── planner.py # RecipePlanner class +│ │ ├── schema.py # ExecutionPlan, ExecutionStep +│ │ └── tree_reduction.py # Parallel composition optimizer +│ └── execution/ +│ ├── __init__.py +│ ├── executor.py # PlanExecutor class +│ └── claiming.py # Hash-based task claiming + +art-celery/ +├── tasks/ +│ ├── __init__.py +│ ├── analyze.py # analyze_inputs task +│ ├── plan.py # generate_plan task +│ ├── execute.py # execute_step task +│ └── orchestrate.py # run_plan (coordinates all) +├── claiming.py # Redis Lua scripts +└── ... +``` + +## CLI Interface + +```bash +# Full pipeline +artdag run-recipe recipes/beat-cuts/recipe.yaml \ + -i music:abc123 \ + -i videos:def456,ghi789 + +# Phase by phase +artdag analyze recipes/beat-cuts/recipe.yaml -i music:abc123 +# → outputs analysis.json + +artdag plan recipes/beat-cuts/recipe.yaml --analysis analysis.json +# → outputs plan.json + +artdag execute plan.json +# → runs with caching, skips completed steps + +# Dry run (show what would execute) +artdag execute plan.json --dry-run +# → shows which steps are cached vs need execution +``` + +## Benefits + +1. **Development Speed**: Change recipe, re-run → only affected steps execute +2. 
**Parallelism**: Independent steps run on multiple Celery workers +3. **Reproducibility**: Same inputs + recipe = same cache IDs = same output +4. **Visibility**: Plan shows exactly what will happen before execution +5. **Cost Control**: Estimate compute before committing resources +6. **Fault Tolerance**: Failed runs resume from last successful step diff --git a/docs/IPFS_PRIMARY_ARCHITECTURE.md b/docs/IPFS_PRIMARY_ARCHITECTURE.md new file mode 100644 index 0000000..2e53aaf --- /dev/null +++ b/docs/IPFS_PRIMARY_ARCHITECTURE.md @@ -0,0 +1,443 @@ +# IPFS-Primary Architecture (Sketch) + +A simplified L1 architecture for large-scale distributed rendering where IPFS is the primary data store. + +## Current vs Simplified + +| Component | Current | Simplified | +|-----------|---------|------------| +| Local cache | Custom, per-worker | IPFS node handles it | +| Redis content_index | content_hash → node_id | Eliminated | +| Redis ipfs_index | content_hash → ipfs_cid | Eliminated | +| Step inputs | File paths | IPFS CIDs | +| Step outputs | File path + CID | Just CID | +| Cache lookup | Local → Redis → IPFS | Just IPFS | + +## Core Principle + +**Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.** + +``` +Step input: [cid1, cid2, ...] +Step output: cid_out +``` + +## Worker Architecture + +Each worker runs: + +``` +┌─────────────────────────────────────┐ +│ Worker Node │ +│ │ +│ ┌───────────┐ ┌──────────────┐ │ +│ │ Celery │────│ IPFS Node │ │ +│ │ Worker │ │ (local) │ │ +│ └───────────┘ └──────────────┘ │ +│ │ │ │ +│ │ ┌─────┴─────┐ │ +│ │ │ Local │ │ +│ │ │ Blockstore│ │ +│ │ └───────────┘ │ +│ │ │ +│ ┌────┴────┐ │ +│ │ /tmp │ (ephemeral workspace) │ +│ └─────────┘ │ +└─────────────────────────────────────┘ + │ + │ IPFS libp2p + ▼ + ┌─────────────┐ + │ Other IPFS │ + │ Nodes │ + └─────────────┘ +``` + +## Execution Flow + +### 1. 
Plan Generation (unchanged) + +```python +plan = planner.plan(recipe, input_hashes) +# plan.steps[].cache_id = deterministic hash +``` + +### 2. Input Registration + +Before execution, register inputs with IPFS: + +```python +input_cids = {} +for name, path in inputs.items(): + cid = ipfs.add(path) + input_cids[name] = cid + +# Plan now carries CIDs +plan.input_cids = input_cids +``` + +### 3. Step Execution + +```python +@celery.task +def execute_step(step_json: str, input_cids: dict[str, str]) -> str: + """Execute step, return output CID.""" + step = ExecutionStep.from_json(step_json) + + # Check if already computed (by cache_id as IPNS key or DHT lookup) + existing_cid = ipfs.resolve(f"/ipns/{step.cache_id}") + if existing_cid: + return existing_cid + + # Fetch inputs from IPFS → local temp files + input_paths = [] + for input_step_id in step.input_steps: + cid = input_cids[input_step_id] + path = ipfs.get(cid, f"/tmp/{cid}") # IPFS node caches automatically + input_paths.append(path) + + # Execute + output_path = f"/tmp/{step.cache_id}.mkv" + executor = get_executor(step.node_type) + executor.execute(step.config, input_paths, output_path) + + # Add output to IPFS + output_cid = ipfs.add(output_path) + + # Publish cache_id → CID mapping (optional, for cache hits) + ipfs.name_publish(step.cache_id, output_cid) + + # Cleanup temp files + cleanup_temp(input_paths + [output_path]) + + return output_cid +``` + +### 4. 
Orchestration + +```python +@celery.task +def run_plan(plan_json: str) -> str: + """Execute plan, return final output CID.""" + plan = ExecutionPlan.from_json(plan_json) + + # CID results accumulate as steps complete + cid_results = dict(plan.input_cids) + + for level in plan.get_steps_by_level(): + # Parallel execution within level + tasks = [] + for step in level: + step_input_cids = { + sid: cid_results[sid] + for sid in step.input_steps + } + tasks.append(execute_step.s(step.to_json(), step_input_cids)) + + # Wait for level to complete + results = group(tasks).apply_async().get() + + # Record output CIDs + for step, cid in zip(level, results): + cid_results[step.step_id] = cid + + return cid_results[plan.output_step] +``` + +## What's Eliminated + +### No more Redis indexes + +```python +# BEFORE: Complex index management +self._set_content_index(content_hash, node_id) # Redis + local +self._set_ipfs_index(content_hash, ipfs_cid) # Redis + local +node_id = self._get_content_index(content_hash) # Check Redis, fallback local + +# AFTER: Just CIDs +output_cid = ipfs.add(output_path) +return output_cid +``` + +### No more local cache management + +```python +# BEFORE: Custom cache with entries, metadata, cleanup +cache.put(node_id, source_path, node_type, execution_time) +cache.get(node_id) +cache.has(node_id) +cache.cleanup_lru() + +# AFTER: IPFS handles it +ipfs.add(path) # Store +ipfs.get(cid) # Retrieve (cached by IPFS node) +ipfs.pin(cid) # Keep permanently +ipfs.gc() # Cleanup unpinned +``` + +### No more content_hash vs node_id confusion + +```python +# BEFORE: Two identifiers +content_hash = sha3_256(file_bytes) # What the file IS +node_id = cache_id # What computation produced it +# Need indexes to map between them + +# AFTER: One identifier +cid = ipfs.add(file) # Content-addressed, includes hash +# CID IS the identifier +``` + +## Cache Hit Detection + +Two options: + +### Option A: IPNS (mutable names) + +```python +# Publish: cache_id → CID 
+ipfs.name_publish(key=cache_id, value=output_cid) + +# Lookup before executing +existing = ipfs.name_resolve(cache_id) +if existing: + return existing # Cache hit +``` + +### Option B: DHT record + +```python +# Store in DHT: cache_id → CID +ipfs.dht_put(cache_id, output_cid) + +# Lookup +existing = ipfs.dht_get(cache_id) +``` + +### Option C: Redis (minimal) + +Keep Redis just for cache_id → CID mapping: + +```python +# Store +redis.hset("artdag:cache", cache_id, output_cid) + +# Lookup +existing = redis.hget("artdag:cache", cache_id) +``` + +This is simpler than current approach - one hash, one mapping, no content_hash/node_id confusion. + +## Claiming (Preventing Duplicate Work) + +Still need Redis for atomic claiming: + +```python +# Claim before executing +claimed = redis.set(f"artdag:claim:{cache_id}", worker_id, nx=True, ex=300) +if not claimed: + # Another worker is doing it - wait for result + return wait_for_result(cache_id) +``` + +Or use IPFS pubsub for coordination. + +## Data Flow Diagram + +``` + ┌─────────────┐ + │ Recipe │ + │ + Inputs │ + └──────┬──────┘ + │ + ▼ + ┌─────────────┐ + │ Planner │ + │ (compute │ + │ cache_ids) │ + └──────┬──────┘ + │ + ▼ + ┌─────────────────────────────────┐ + │ ExecutionPlan │ + │ - steps with cache_ids │ + │ - input_cids (from ipfs.add) │ + └─────────────────┬───────────────┘ + │ + ┌────────────┼────────────┐ + ▼ ▼ ▼ + ┌────────┐ ┌────────┐ ┌────────┐ + │Worker 1│ │Worker 2│ │Worker 3│ + │ │ │ │ │ │ + │ IPFS │◄──│ IPFS │◄──│ IPFS │ + │ Node │──►│ Node │──►│ Node │ + └───┬────┘ └───┬────┘ └───┬────┘ + │ │ │ + └────────────┼────────────┘ + │ + ▼ + ┌─────────────┐ + │ Final CID │ + │ (output) │ + └─────────────┘ +``` + +## Benefits + +1. **Simpler code** - No custom cache, no dual indexes +2. **Automatic distribution** - IPFS handles replication +3. **Content verification** - CIDs are self-verifying +4. **Scalable** - Add workers = add IPFS nodes = more cache capacity +5. 
**Resilient** - Any node can serve any content + +## Tradeoffs + +1. **IPFS dependency** - Every worker needs IPFS node +2. **Initial fetch latency** - First fetch may be slower than local disk +3. **IPNS latency** - Name resolution can be slow (Option C avoids this) + +## Trust Domains (Cluster Key) + +Systems can share work through IPFS, but how do you trust them? + +**Problem:** A malicious system could return wrong CIDs for computed steps. + +**Solution:** Cluster key creates isolated trust domains: + +```bash +export ARTDAG_CLUSTER_KEY="my-secret-shared-key" +``` + +**How it works:** +- The cluster key is mixed into all cache_id computations +- Systems with the same key produce the same cache_ids +- Systems with different keys have separate cache namespaces +- Only share the key with trusted partners + +``` +cache_id = SHA3-256(cluster_key + node_type + config + inputs) +``` + +**Trust model:** +| Scenario | Same Key? | Can Share Work? | +|----------|-----------|-----------------| +| Same organization | Yes | Yes | +| Trusted partner | Yes (shared) | Yes | +| Unknown system | No | No (different cache_ids) | + +**Configuration:** +```yaml +# docker-compose.yml +environment: + - ARTDAG_CLUSTER_KEY=your-secret-key-here +``` + +**Programmatic:** +```python +from artdag.planning.schema import set_cluster_key +set_cluster_key("my-secret-key") +``` + +## Implementation + +The simplified architecture is implemented in `art-celery/`: + +| File | Purpose | +|------|---------| +| `hybrid_state.py` | Hybrid state manager (Redis + IPNS) | +| `tasks/execute_cid.py` | Step execution with CIDs | +| `tasks/analyze_cid.py` | Analysis with CIDs | +| `tasks/orchestrate_cid.py` | Full pipeline orchestration | + +### Key Functions + +**Registration (local → IPFS):** +- `register_input_cid(path)` → `{cid, content_hash}` +- `register_recipe_cid(path)` → `{cid, name, version}` + +**Analysis:** +- `analyze_input_cid(input_cid, input_hash, features)` → `{analysis_cid}` + +**Planning:** 
+- `generate_plan_cid(recipe_cid, input_cids, input_hashes, analysis_cids)` → `{plan_cid}` + +**Execution:** +- `execute_step_cid(step_json, input_cids)` → `{cid}` +- `execute_plan_from_cid(plan_cid, input_cids)` → `{output_cid}` + +**Full Pipeline:** +- `run_recipe_cid(recipe_cid, input_cids, input_hashes)` → `{output_cid, all_cids}` +- `run_from_local(recipe_path, input_paths)` → registers + runs + +### Hybrid State Manager + +For distributed L1 coordination, use the `HybridStateManager` which provides: + +**Fast path (local Redis):** +- `get_cached_cid(cache_id)` / `set_cached_cid(cache_id, cid)` - microsecond lookups +- `try_claim(cache_id, worker_id)` / `release_claim(cache_id)` - atomic claiming +- `get_analysis_cid()` / `set_analysis_cid()` - analysis cache +- `get_plan_cid()` / `set_plan_cid()` - plan cache +- `get_run_cid()` / `set_run_cid()` - run cache + +**Slow path (background IPNS sync):** +- Periodically syncs local state with global IPNS state (default: every 30s) +- Pulls new entries from remote nodes +- Pushes local updates to IPNS + +**Configuration:** +```bash +# Enable IPNS sync +export ARTDAG_IPNS_SYNC=true +export ARTDAG_IPNS_SYNC_INTERVAL=30 # seconds +``` + +**Usage:** +```python +from hybrid_state import get_state_manager + +state = get_state_manager() + +# Fast local lookup +cid = state.get_cached_cid(cache_id) + +# Fast local write (synced in background) +state.set_cached_cid(cache_id, output_cid) + +# Atomic claim +if state.try_claim(cache_id, worker_id): + # We have the lock + ... 
+``` + +**Trade-offs:** +- Local Redis: Fast (microseconds), single node +- IPNS sync: Slow (seconds), eventually consistent across nodes +- Duplicate work: Accepted (idempotent - same inputs → same CID) + +### Redis Usage (minimal) + +| Key | Type | Purpose | +|-----|------|---------| +| `artdag:cid_cache` | Hash | cache_id → output CID | +| `artdag:analysis_cache` | Hash | input_hash:features → analysis CID | +| `artdag:plan_cache` | Hash | plan_id → plan CID | +| `artdag:run_cache` | Hash | run_id → output CID | +| `artdag:claim:{cache_id}` | String | worker_id (TTL 5 min) | + +## Migration Path + +1. Keep current system working ✓ +2. Add CID-based tasks ✓ + - `execute_cid.py` ✓ + - `analyze_cid.py` ✓ + - `orchestrate_cid.py` ✓ +3. Add `--ipfs-primary` flag to CLI ✓ +4. Add hybrid state manager for L1 coordination ✓ +5. Gradually deprecate local cache code +6. Remove old tasks when CID versions are stable + +## See Also + +- [L1_STORAGE.md](L1_STORAGE.md) - Current L1 architecture +- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase model diff --git a/docs/L1_STORAGE.md b/docs/L1_STORAGE.md new file mode 100644 index 0000000..c371329 --- /dev/null +++ b/docs/L1_STORAGE.md @@ -0,0 +1,181 @@ +# L1 Distributed Storage Architecture + +This document describes how data is stored when running artdag on L1 (the distributed rendering layer). + +## Overview + +L1 uses four storage systems working together: + +| System | Purpose | Data Stored | +|--------|---------|-------------| +| **Local Cache** | Hot storage (fast access) | Media files, plans, analysis | +| **IPFS** | Durable content-addressed storage | All media outputs | +| **Redis** | Coordination & indexes | Claims, mappings, run status | +| **PostgreSQL** | Metadata & ownership | User data, provenance | + +## Storage Flow + +When a step executes on L1: + +``` +1. Executor produces output file +2. Store in local cache (fast) +3. Compute content_hash = SHA3-256(file) +4. Upload to IPFS → get ipfs_cid +5. 
Update indexes: + - content_hash → node_id (Redis + local) + - content_hash → ipfs_cid (Redis + local) +``` + +Every intermediate step output (SEGMENT, SEQUENCE, etc.) gets its own IPFS CID. + +## Local Cache + +Hot storage on each worker node: + +``` +cache_dir/ + index.json # Cache metadata + content_index.json # content_hash → node_id + ipfs_index.json # content_hash → ipfs_cid + plans/ + {plan_id}.json # Cached execution plans + analysis/ + {hash}.json # Analysis results + {node_id}/ + output.mkv # Media output + metadata.json # CacheEntry metadata +``` + +## IPFS - Durable Media Storage + +All media files are stored in IPFS for durability and content-addressing. + +**Supported pinning providers:** +- Pinata +- web3.storage +- NFT.Storage +- Infura IPFS +- Filebase (S3-compatible) +- Storj (decentralized) +- Local IPFS node + +**Configuration:** +```bash +IPFS_API=/ip4/127.0.0.1/tcp/5001 # Local IPFS daemon +``` + +## Redis - Coordination + +Redis handles distributed coordination across workers. + +### Key Patterns + +| Key | Type | Purpose | +|-----|------|---------| +| `artdag:run:{run_id}` | String | Run status, timestamps, celery task ID | +| `artdag:content_index` | Hash | content_hash → node_id mapping | +| `artdag:ipfs_index` | Hash | content_hash → ipfs_cid mapping | +| `artdag:claim:{cache_id}` | String | Task claiming (prevents duplicate work) | + +### Task Claiming + +Lua scripts ensure atomic claiming across workers: + +``` +Status flow: PENDING → CLAIMED → RUNNING → COMPLETED/CACHED/FAILED +TTL: 5 minutes for claims, 1 hour for results +``` + +This prevents two workers from executing the same step. + +## PostgreSQL - Metadata + +Stores ownership, provenance, and sharing metadata. 
+ +### Tables + +```sql +-- Core cache (shared) +cache_items (content_hash, ipfs_cid, created_at) + +-- Per-user ownership +item_types (content_hash, actor_id, type, metadata) + +-- Run cache (deterministic identity) +run_cache ( + run_id, -- SHA3-256(sorted_inputs + recipe) + output_hash, + ipfs_cid, + provenance_cid, + recipe, inputs, actor_id +) + +-- Storage backends +storage_backends (actor_id, provider_type, config, capacity_gb) + +-- What's stored where +storage_pins (content_hash, storage_id, ipfs_cid, pin_type) +``` + +## Cache Lookup Flow + +When a worker needs a file: + +``` +1. Check local cache by cache_id (fastest) +2. Check Redis content_index: content_hash → node_id +3. Check PostgreSQL cache_items +4. Retrieve from IPFS by CID +5. Store in local cache for next hit +``` + +## Local vs L1 Comparison + +| Feature | Local Testing | L1 Distributed | +|---------|---------------|----------------| +| Local cache | Yes | Yes | +| IPFS | No | Yes | +| Redis | No | Yes | +| PostgreSQL | No | Yes | +| Multi-worker | No | Yes | +| Task claiming | No | Yes (Lua scripts) | +| Durability | Filesystem only | IPFS + PostgreSQL | + +## Content Addressing + +All storage uses SHA3-256 (quantum-resistant): + +- **Files:** `content_hash = SHA3-256(file_bytes)` +- **Computation:** `cache_id = SHA3-256(type + config + input_hashes)` +- **Run identity:** `run_id = SHA3-256(sorted_inputs + recipe)` +- **Plans:** `plan_id = SHA3-256(recipe + inputs + analysis)` + +This ensures: +- Same inputs → same outputs (reproducibility) +- Automatic deduplication across workers +- Content verification (tamper detection) + +## Configuration + +Default locations: + +```bash +# Local cache +~/.artdag/cache # Default +/data/cache # Docker + +# Redis +redis://localhost:6379/5 + +# PostgreSQL +postgresql://user:pass@host/artdag + +# IPFS +/ip4/127.0.0.1/tcp/5001 +``` + +## See Also + +- [OFFLINE_TESTING.md](OFFLINE_TESTING.md) - Local testing without L1 +- 
[EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model diff --git a/docs/OFFLINE_TESTING.md b/docs/OFFLINE_TESTING.md new file mode 100644 index 0000000..68d1559 --- /dev/null +++ b/docs/OFFLINE_TESTING.md @@ -0,0 +1,211 @@ +# Offline Testing Strategy + +This document describes how to test artdag locally without requiring Redis, IPFS, Celery, or any external distributed infrastructure. + +## Overview + +The artdag system uses a **3-Phase Execution Model** that enables complete offline testing: + +1. **Analysis** - Extract features from input media +2. **Planning** - Generate deterministic execution plan with pre-computed cache IDs +3. **Execution** - Run plan steps, skipping cached results + +This separation allows testing each phase independently and running full pipelines locally. + +## Quick Start + +Run a full offline test with a video file: + +```bash +./examples/test_local.sh ../artdag-art-source/dog.mkv +``` + +This will: +1. Compute the SHA3-256 hash of the input video +2. Run the `simple_sequence` recipe +3. Store all outputs in `test_cache/` + +## Test Scripts + +### `test_local.sh` - Full Pipeline Test + +Location: `./examples/test_local.sh` + +Runs the complete artdag pipeline offline with a real video file. + +**Usage:** +```bash +./examples/test_local.sh +``` + +**Example:** +```bash +./examples/test_local.sh ../artdag-art-source/dog.mkv +``` + +**What it does:** +- Computes content hash of input video +- Runs `artdag run-recipe` with `simple_sequence.yaml` +- Stores outputs in `test_cache/` directory +- No external services required + +### `test_plan.py` - Planning Phase Test + +Location: `./examples/test_plan.py` + +Tests the planning phase without requiring any media files. 
+
+**Usage:**
+```bash
+python3 examples/test_plan.py
+```
+
+**What it tests:**
+- Recipe loading and YAML parsing
+- Execution plan generation
+- Cache ID computation (deterministic)
+- Multi-level parallel step organization
+- Human-readable step names
+- Multi-output support
+
+**Output:**
+- Prints plan structure to console
+- Saves full plan to `test_plan_output.json`
+
+### `simple_sequence.yaml` - Sample Recipe
+
+Location: `./examples/simple_sequence.yaml`
+
+A simple recipe for testing that:
+- Takes a video input
+- Extracts two segments (0-2s and 5-7s)
+- Concatenates them with SEQUENCE
+
+## Test Outputs
+
+All test outputs are stored locally and git-ignored:
+
+| Output | Description |
+|--------|-------------|
+| `test_cache/` | Cached execution results (media files, analysis, plans) |
+| `test_cache/plans/` | Cached execution plans by plan_id |
+| `test_cache/analysis/` | Cached analysis results by input hash |
+| `test_plan_output.json` | Generated execution plan from `test_plan.py` |
+
+## Unit Tests
+
+The project includes a comprehensive pytest test suite in `tests/`:
+
+```bash
+# Run all unit tests
+pytest
+
+# Run specific test file
+pytest tests/test_dag.py
+pytest tests/test_engine.py
+pytest tests/test_cache.py
+```
+
+## Testing Each Phase
+
+### Phase 1: Analysis Only
+
+Extract features without full execution:
+
+```bash
+python3 -m artdag.cli analyze -i <name>:<hash>@<path> --features beats,energy
+```
+
+### Phase 2: Planning Only
+
+Generate an execution plan (no media needed):
+
+```bash
+python3 -m artdag.cli plan -i <name>:<hash>
+```
+
+Or use the test script:
+
+```bash
+python3 examples/test_plan.py
+```
+
+### Phase 3: Execution Only
+
+Execute a pre-generated plan:
+
+```bash
+python3 -m artdag.cli execute plan.json
+```
+
+With dry-run to see what would execute:
+
+```bash
+python3 -m artdag.cli execute plan.json --dry-run
+```
+
+## Key Testing Features
+
+### Content Addressing
+
+All nodes have deterministic IDs computed as:
+```
+SHA3-256(type + 
config + sorted(input_IDs)) +``` + +Same inputs always produce same cache IDs, enabling: +- Reproducibility across runs +- Automatic deduplication +- Incremental execution (only changed steps run) + +### Local Caching + +The `test_cache/` directory stores: +- `plans/{plan_id}.json` - Execution plans (deterministic hash of recipe + inputs + analysis) +- `analysis/{hash}.json` - Analysis results (audio beats, tempo, energy) +- `{cache_id}/output.mkv` - Media outputs from each step + +Subsequent test runs automatically skip cached steps. Plans are cached by their `plan_id`, which is a SHA3-256 hash of the recipe, input hashes, and analysis results - so the same recipe with the same inputs always produces the same plan. + +### No External Dependencies + +Offline testing requires: +- Python 3.9+ +- ffmpeg (for media processing) +- No Redis, IPFS, Celery, or network access + +## Debugging Tips + +1. **Check cache contents:** + ```bash + ls -la test_cache/ + ls -la test_cache/plans/ + ``` + +2. **View cached plan:** + ```bash + cat test_cache/plans/*.json | python3 -m json.tool | head -50 + ``` + +3. **View execution plan structure:** + ```bash + cat test_plan_output.json | python3 -m json.tool + ``` + +4. **Run with verbose output:** + ```bash + python3 -m artdag.cli run-recipe examples/simple_sequence.yaml \ + -i "video:HASH@path" \ + --cache-dir test_cache \ + -v + ``` + +5. **Dry-run to see what would execute:** + ```bash + python3 -m artdag.cli execute plan.json --dry-run + ``` + +## See Also + +- [L1_STORAGE.md](L1_STORAGE.md) - Distributed storage on L1 (IPFS, Redis, PostgreSQL) +- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model diff --git a/effects/identity/README.md b/effects/identity/README.md new file mode 100644 index 0000000..afb6cb0 --- /dev/null +++ b/effects/identity/README.md @@ -0,0 +1,35 @@ +# Identity Effect + +The identity effect returns its input unchanged. It serves as the foundational primitive in the effects registry. 
+ +## Purpose + +- **Testing**: Verify the effects pipeline is working correctly +- **No-op placeholder**: Use when an effect slot requires a value but no transformation is needed +- **Composition base**: The neutral element for effect composition + +## Signature + +``` +identity(input) → input +``` + +## Properties + +- **Idempotent**: `identity(identity(x)) = identity(x)` +- **Neutral**: For any effect `f`, `identity ∘ f = f ∘ identity = f` + +## Implementation + +```python +def identity(input): + return input +``` + +## Content Hash + +The identity effect is content-addressed by its behavior: given any input, the output hash equals the input hash. + +## Owner + +Registered by `@giles@artdag.rose-ash.com` diff --git a/effects/identity/requirements.txt b/effects/identity/requirements.txt new file mode 100644 index 0000000..805e561 --- /dev/null +++ b/effects/identity/requirements.txt @@ -0,0 +1,2 @@ +# Identity effect has no dependencies +# It's a pure function: identity(x) = x diff --git a/examples/simple_sequence.yaml b/examples/simple_sequence.yaml new file mode 100644 index 0000000..d4ce009 --- /dev/null +++ b/examples/simple_sequence.yaml @@ -0,0 +1,42 @@ +# Simple sequence recipe - concatenates segments from a single input video +name: simple_sequence +version: "1.0" +description: "Split input into segments and concatenate them" +owner: test@local + +dag: + nodes: + # Input source - variable (provided at runtime) + - id: video + type: SOURCE + config: + input: true + name: "Input Video" + description: "The video to process" + + # Extract first 2 seconds + - id: seg1 + type: SEGMENT + config: + start: 0.0 + end: 2.0 + inputs: + - video + + # Extract seconds 5-7 + - id: seg2 + type: SEGMENT + config: + start: 5.0 + end: 7.0 + inputs: + - video + + # Concatenate the segments + - id: output + type: SEQUENCE + inputs: + - seg1 + - seg2 + + output: output diff --git a/examples/test_local.sh b/examples/test_local.sh new file mode 100755 index 0000000..083f718 --- 
/dev/null
+++ b/examples/test_local.sh
+#!/bin/bash
+# Local testing script for artdag
+# Tests the 3-phase execution without Redis/IPFS
+#
+# Usage: ./test_local.sh <video_path>
+# Runs the full offline pipeline: hash the input, then run the
+# simple_sequence recipe with a local cache directory only.
+
+set -e
+
+# Resolve paths relative to this script so it works from any CWD.
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+ARTDAG_DIR="$(dirname "$SCRIPT_DIR")"
+CACHE_DIR="${ARTDAG_DIR}/test_cache"
+RECIPE="${SCRIPT_DIR}/simple_sequence.yaml"
+
+# Check for input video
+# NOTE(review): the "<video_path>" placeholder appears to have been lost
+# from the usage string during extraction; it is a runtime string, left as-is.
+if [ -z "$1" ]; then
+    echo "Usage: $0 "
+    echo ""
+    echo "Example:"
+    echo "  $0 /path/to/test_video.mp4"
+    exit 1
+fi
+
+VIDEO_PATH="$1"
+if [ ! -f "$VIDEO_PATH" ]; then
+    echo "Error: Video file not found: $VIDEO_PATH"
+    exit 1
+fi
+
+# Compute content hash of input (SHA3-256, matching artdag's content addressing)
+# NOTE(review): $VIDEO_PATH is interpolated directly into the inline Python
+# source — a path containing a single quote will break the snippet; and
+# f.read() loads the whole video into memory. Fine for test assets; confirm
+# before using on large/untrusted inputs.
+echo "=== Computing input hash ==="
+VIDEO_HASH=$(python3 -c "
+import hashlib
+with open('$VIDEO_PATH', 'rb') as f:
+    print(hashlib.sha3_256(f.read()).hexdigest())
+")
+echo "Input hash: ${VIDEO_HASH:0:16}..."
+
+# Change to artdag directory
+cd "$ARTDAG_DIR"
+
+# Run the full pipeline
+echo ""
+echo "=== Running artdag run-recipe ==="
+echo "Recipe: $RECIPE"
+echo "Input: video:${VIDEO_HASH:0:16}...@$VIDEO_PATH"
+echo "Cache: $CACHE_DIR"
+echo ""
+
+# Input is passed as name:content_hash@local_path; outputs land in CACHE_DIR.
+python3 -m artdag.cli run-recipe "$RECIPE" \
+    -i "video:${VIDEO_HASH}@${VIDEO_PATH}" \
+    --cache-dir "$CACHE_DIR"
+
+echo ""
+echo "=== Done ==="
+echo "Cache directory: $CACHE_DIR"
+echo "Use 'ls -la $CACHE_DIR' to see cached outputs"
diff --git a/examples/test_plan.py b/examples/test_plan.py
new file mode 100755
index 0000000..9b3a257
--- /dev/null
+++ b/examples/test_plan.py
+#!/usr/bin/env python3
+"""
+Test the planning phase locally.
+
+This tests the new human-readable names and multi-output support
+without requiring actual video files or execution.
+""" + +import hashlib +import json +import sys +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from artdag.planning import RecipePlanner, Recipe, ExecutionPlan + + +def main(): + # Load recipe + recipe_path = Path(__file__).parent / "simple_sequence.yaml" + if not recipe_path.exists(): + print(f"Recipe not found: {recipe_path}") + return 1 + + recipe = Recipe.from_file(recipe_path) + print(f"Recipe: {recipe.name} v{recipe.version}") + print(f"Nodes: {len(recipe.nodes)}") + print() + + # Fake input hash (would be real content hash in production) + fake_input_hash = hashlib.sha3_256(b"fake video content").hexdigest() + input_hashes = {"video": fake_input_hash} + + print(f"Input: video -> {fake_input_hash[:16]}...") + print() + + # Generate plan + planner = RecipePlanner(use_tree_reduction=True) + plan = planner.plan( + recipe=recipe, + input_hashes=input_hashes, + seed=42, # Optional seed for reproducibility + ) + + print("=== Generated Plan ===") + print(f"Plan ID: {plan.plan_id[:24]}...") + print(f"Plan Name: {plan.name}") + print(f"Recipe Name: {plan.recipe_name}") + print(f"Output: {plan.output_name}") + print(f"Steps: {len(plan.steps)}") + print() + + # Show steps by level + steps_by_level = plan.get_steps_by_level() + for level in sorted(steps_by_level.keys()): + steps = steps_by_level[level] + print(f"Level {level}: {len(steps)} step(s)") + for step in steps: + # Show human-readable name + name = step.name or step.step_id[:20] + print(f" - {name}") + print(f" Type: {step.node_type}") + print(f" Cache ID: {step.cache_id[:16]}...") + if step.outputs: + print(f" Outputs: {len(step.outputs)}") + for out in step.outputs: + print(f" - {out.name} ({out.media_type})") + if step.inputs: + print(f" Inputs: {[inp.name for inp in step.inputs]}") + print() + + # Save plan for inspection + plan_path = Path(__file__).parent.parent / "test_plan_output.json" + with open(plan_path, "w") as f: + f.write(plan.to_json()) 
+ print(f"Plan saved to: {plan_path}") + + # Show plan JSON structure + print() + print("=== Plan JSON Preview ===") + plan_dict = json.loads(plan.to_json()) + # Show first step as example + if plan_dict.get("steps"): + first_step = plan_dict["steps"][0] + print(json.dumps(first_step, indent=2)[:500] + "...") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9ac24c7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,62 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "artdag" +version = "0.1.0" +description = "Content-addressed DAG execution engine with ActivityPub ownership" +readme = "README.md" +license = {text = "MIT"} +requires-python = ">=3.10" +authors = [ + {name = "Giles", email = "giles@rose-ash.com"} +] +keywords = ["dag", "content-addressed", "activitypub", "video", "processing"] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "cryptography>=41.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", +] +analysis = [ + "librosa>=0.10.0", + "numpy>=1.24.0", + "pyyaml>=6.0", +] +cv = [ + "opencv-python>=4.8.0", +] +all = [ + "librosa>=0.10.0", + "numpy>=1.24.0", + "pyyaml>=6.0", + "opencv-python>=4.8.0", +] + +[project.scripts] +artdag = "artdag.cli:main" + +[project.urls] +Homepage = "https://artdag.rose-ash.com" +Repository = "https://github.com/giles/artdag" + +[tool.setuptools.packages.find] +where = ["."] +include = ["artdag*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] diff --git a/scripts/compute_repo_hash.py 
b/scripts/compute_repo_hash.py
new file mode 100644
index 0000000..8e841e1
--- /dev/null
+++ b/scripts/compute_repo_hash.py
+#!/usr/bin/env python3
+"""
+Compute content hash of a git repository.
+
+Hashes all tracked files (respects .gitignore) in sorted order.
+"""
+
+import hashlib
+import subprocess
+import sys
+from pathlib import Path
+
+
+def repo_hash(repo_path: Path) -> str:
+    """
+    Compute SHA3-256 hash of all tracked files in a repo.
+
+    Uses git ls-files to respect .gitignore.
+    Files are hashed in sorted order for determinism.
+    Each file contributes: relative_path + file_contents
+
+    Raises subprocess.CalledProcessError (via check=True) if repo_path
+    is not inside a git repository.
+    """
+    # Get list of tracked files
+    # NOTE(review): git ls-files emits '/'-separated paths on every platform,
+    # so the digest should not depend on os.sep — confirm on Windows.
+    result = subprocess.run(
+        ["git", "ls-files"],
+        cwd=repo_path,
+        capture_output=True,
+        text=True,
+        check=True,
+    )
+
+    files = sorted(result.stdout.strip().split("\n"))
+
+    hasher = hashlib.sha3_256()
+
+    for rel_path in files:
+        # Guard the empty string produced by "".split("\n") when the
+        # repository has no tracked files at all.
+        if not rel_path:
+            continue
+
+        file_path = repo_path / rel_path
+        # Skip anything that is not a regular file on disk: paths that are
+        # tracked but deleted from the working tree, and submodule entries.
+        if not file_path.is_file():
+            continue
+
+        # Include path in hash (ties content to repo structure)
+        hasher.update(rel_path.encode())
+
+        # Include contents, streamed in 64 KiB chunks to bound memory use
+        with open(file_path, "rb") as f:
+            for chunk in iter(lambda: f.read(65536), b""):
+                hasher.update(chunk)
+
+    return hasher.hexdigest()
+
+
+def main():
+    # Default to the current working directory when no path is given.
+    if len(sys.argv) > 1:
+        repo_path = Path(sys.argv[1])
+    else:
+        repo_path = Path.cwd()
+
+    h = repo_hash(repo_path)
+    print(f"Repository: {repo_path}")
+    print(f"Hash: {h}")
+    # Returned as well as printed, for interactive / import use.
+    return h
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/install-ffglitch.sh b/scripts/install-ffglitch.sh
new file mode 100755
index 0000000..d7301f2
--- /dev/null
+++ b/scripts/install-ffglitch.sh
+#!/bin/bash
+# Install ffglitch for datamosh effects
+# Usage: ./install-ffglitch.sh [install_dir]
+
+set -e
+
+FFGLITCH_VERSION="0.10.2"
+INSTALL_DIR="${1:-/usr/local/bin}"
+
+# Detect architecture
+ARCH=$(uname -m)
+case "$ARCH" in
+    x86_64)
+        URL="https://ffglitch.org/pub/bin/linux64/ffglitch-${FFGLITCH_VERSION}-linux-x86_64.zip"
ARCHIVE="ffglitch.zip" + ;; + aarch64) + URL="https://ffglitch.org/pub/bin/linux-aarch64/ffglitch-${FFGLITCH_VERSION}-linux-aarch64.7z" + ARCHIVE="ffglitch.7z" + ;; + *) + echo "Unsupported architecture: $ARCH" + exit 1 + ;; +esac + +echo "Installing ffglitch ${FFGLITCH_VERSION} for ${ARCH}..." + +# Create temp directory +TMPDIR=$(mktemp -d) +cd "$TMPDIR" + +# Download +echo "Downloading from ${URL}..." +curl -L -o "$ARCHIVE" "$URL" + +# Extract +echo "Extracting..." +if [[ "$ARCHIVE" == *.zip ]]; then + unzip -q "$ARCHIVE" +elif [[ "$ARCHIVE" == *.7z ]]; then + # Requires p7zip + if ! command -v 7z &> /dev/null; then + echo "7z not found. Install with: apt install p7zip-full" + exit 1 + fi + 7z x "$ARCHIVE" > /dev/null +fi + +# Find and install binaries +echo "Installing to ${INSTALL_DIR}..." +find . -name "ffgac" -o -name "ffedit" | while read bin; do + chmod +x "$bin" + if [ -w "$INSTALL_DIR" ]; then + cp "$bin" "$INSTALL_DIR/" + else + sudo cp "$bin" "$INSTALL_DIR/" + fi + echo " Installed: $(basename $bin)" +done + +# Cleanup +cd / +rm -rf "$TMPDIR" + +# Verify +echo "" +echo "Verifying installation..." +if command -v ffgac &> /dev/null; then + echo "ffgac: $(which ffgac)" +else + echo "Warning: ffgac not in PATH. Add ${INSTALL_DIR} to PATH." +fi + +if command -v ffedit &> /dev/null; then + echo "ffedit: $(which ffedit)" +else + echo "Warning: ffedit not in PATH. Add ${INSTALL_DIR} to PATH." +fi + +echo "" +echo "Done! ffglitch installed." diff --git a/scripts/register_identity_effect.py b/scripts/register_identity_effect.py new file mode 100644 index 0000000..0194698 --- /dev/null +++ b/scripts/register_identity_effect.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Register the identity effect owned by giles. 
+"""
+
+import hashlib
+from pathlib import Path
+import sys
+
+# Add parent to path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from artdag.activitypub.ownership import OwnershipManager
+
+
+def folder_hash(folder: Path) -> str:
+    """
+    Compute SHA3-256 hash of an entire folder.
+
+    Hashes all files in sorted order for deterministic results.
+    Each file contributes: relative_path + file_contents
+    """
+    hasher = hashlib.sha3_256()
+
+    # Get all files sorted by relative path
+    # (rglob("*") walks recursively; directories are skipped below)
+    files = sorted(folder.rglob("*"))
+
+    for file_path in files:
+        if file_path.is_file():
+            # Include relative path in hash for structure
+            # NOTE(review): str(rel_path) uses the platform separator
+            # (os.sep), so the same folder would hash differently on
+            # Windows vs POSIX. rel_path.as_posix() would be portable —
+            # confirm before changing, since already-registered hashes
+            # would stop matching.
+            rel_path = file_path.relative_to(folder)
+            hasher.update(str(rel_path).encode())
+
+            # Include file contents, streamed in 64 KiB chunks
+            with open(file_path, "rb") as f:
+                for chunk in iter(lambda: f.read(65536), b""):
+                    hasher.update(chunk)
+
+    return hasher.hexdigest()
+
+
+def main():
+    # Use .cache as the ownership data directory
+    base_dir = Path(__file__).parent.parent / ".cache" / "ownership"
+    manager = OwnershipManager(base_dir)
+
+    # Create or get giles actor (idempotent: reuses an existing actor)
+    actor = manager.get_actor("giles")
+    if not actor:
+        actor = manager.create_actor("giles", "Giles Bradshaw")
+        print(f"Created actor: {actor.handle}")
+    else:
+        print(f"Using existing actor: {actor.handle}")
+
+    # Register the identity effect folder, content-addressed by folder_hash
+    effect_path = Path(__file__).parent.parent / "effects" / "identity"
+    cid = folder_hash(effect_path)
+
+    asset, activity = manager.register_asset(
+        actor=actor,
+        name="effect:identity",
+        cid=cid,
+        local_path=effect_path,
+        tags=["effect", "primitive", "identity"],
+        metadata={
+            "type": "effect",
+            "description": "The identity effect - returns input unchanged",
+            "signature": "identity(input) → input",
+        },
+    )
+
+    print(f"\nRegistered: {asset.name}")
+    print(f"  Hash: {asset.cid}")
+    print(f"  Path: {asset.local_path}")
+    print(f"  Activity: {activity.activity_id}")
+    print(f"  Owner: {actor.handle}")
+
+    # Verify ownership
+    verified = 
manager.verify_ownership(asset.name, actor) + print(f" Ownership verified: {verified}") + +if __name__ == "__main__": + main() diff --git a/scripts/setup_actor.py b/scripts/setup_actor.py new file mode 100644 index 0000000..b1c80cf --- /dev/null +++ b/scripts/setup_actor.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Set up actor with keypair stored securely. + +Private key: ~/.artdag/keys/{username}.pem +Public key: exported for registry +""" + +import json +import os +import sys +from datetime import datetime, timezone +from pathlib import Path + +# Add artdag to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.backends import default_backend + + +def create_keypair(): + """Generate RSA-2048 keypair.""" + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend(), + ) + return private_key + + +def save_private_key(private_key, path: Path): + """Save private key to PEM file.""" + pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(pem) + os.chmod(path, 0o600) # Owner read/write only + return pem.decode() + + +def get_public_key_pem(private_key) -> str: + """Extract public key as PEM string.""" + public_key = private_key.public_key() + pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + return pem.decode() + + +def create_actor_json(username: str, display_name: str, public_key_pem: str, domain: str = "artdag.rose-ash.com"): + """Create ActivityPub actor JSON.""" + return { + "@context": [ + "https://www.w3.org/ns/activitystreams", + "https://w3id.org/security/v1" + ], + "type": "Person", + 
"id": f"https://{domain}/users/{username}", + "preferredUsername": username, + "name": display_name, + "inbox": f"https://{domain}/users/{username}/inbox", + "outbox": f"https://{domain}/users/{username}/outbox", + "publicKey": { + "id": f"https://{domain}/users/{username}#main-key", + "owner": f"https://{domain}/users/{username}", + "publicKeyPem": public_key_pem + } + } + + +def main(): + username = "giles" + display_name = "Giles Bradshaw" + domain = "artdag.rose-ash.com" + + keys_dir = Path.home() / ".artdag" / "keys" + private_key_path = keys_dir / f"{username}.pem" + + # Check if key already exists + if private_key_path.exists(): + print(f"Private key already exists: {private_key_path}") + print("Delete it first if you want to regenerate.") + sys.exit(1) + + # Create new keypair + print(f"Creating new keypair for @{username}@{domain}...") + private_key = create_keypair() + + # Save private key + save_private_key(private_key, private_key_path) + print(f"Private key saved: {private_key_path}") + print(f" Mode: 600 (owner read/write only)") + print(f" BACK THIS UP!") + + # Get public key + public_key_pem = get_public_key_pem(private_key) + + # Create actor JSON + actor = create_actor_json(username, display_name, public_key_pem, domain) + + # Output actor JSON + actor_json = json.dumps(actor, indent=2) + print(f"\nActor JSON (for registry/actors/{username}.json):") + print(actor_json) + + # Save to registry + registry_path = Path.home() / "artdag-registry" / "actors" / f"{username}.json" + registry_path.parent.mkdir(parents=True, exist_ok=True) + registry_path.write_text(actor_json) + print(f"\nSaved to: {registry_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/sign_assets.py b/scripts/sign_assets.py new file mode 100644 index 0000000..8021f78 --- /dev/null +++ b/scripts/sign_assets.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +""" +Sign assets in the registry with giles's private key. 
+ +Creates ActivityPub Create activities with RSA signatures. +""" + +import base64 +import hashlib +import json +import sys +import uuid +from datetime import datetime, timezone +from pathlib import Path + +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.backends import default_backend + + +def load_private_key(path: Path): + """Load private key from PEM file.""" + pem_data = path.read_bytes() + return serialization.load_pem_private_key(pem_data, password=None, backend=default_backend()) + + +def sign_data(private_key, data: str) -> str: + """Sign data with RSA private key, return base64 signature.""" + signature = private_key.sign( + data.encode(), + padding.PKCS1v15(), + hashes.SHA256(), + ) + return base64.b64encode(signature).decode() + + +def create_activity(actor_id: str, asset_name: str, cid: str, asset_type: str, domain: str = "artdag.rose-ash.com"): + """Create a Create activity for an asset.""" + now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + + return { + "activity_id": str(uuid.uuid4()), + "activity_type": "Create", + "actor_id": actor_id, + "object_data": { + "type": asset_type_to_ap(asset_type), + "name": asset_name, + "id": f"https://{domain}/objects/{cid}", + "contentHash": { + "algorithm": "sha3-256", + "value": cid + }, + "attributedTo": actor_id + }, + "published": now, + } + + +def asset_type_to_ap(asset_type: str) -> str: + """Convert asset type to ActivityPub type.""" + type_map = { + "image": "Image", + "video": "Video", + "audio": "Audio", + "effect": "Application", + "infrastructure": "Application", + } + return type_map.get(asset_type, "Document") + + +def sign_activity(activity: dict, private_key, actor_id: str, domain: str = "artdag.rose-ash.com") -> dict: + """Add signature to activity.""" + # Create canonical string to sign + to_sign = json.dumps(activity["object_data"], sort_keys=True, separators=(",", ":")) + + 
signature_value = sign_data(private_key, to_sign) + + activity["signature"] = { + "type": "RsaSignature2017", + "creator": f"{actor_id}#main-key", + "created": activity["published"], + "signatureValue": signature_value + } + + return activity + + +def main(): + username = "giles" + domain = "artdag.rose-ash.com" + actor_id = f"https://{domain}/users/{username}" + + # Load private key + private_key_path = Path.home() / ".artdag" / "keys" / f"{username}.pem" + if not private_key_path.exists(): + print(f"Private key not found: {private_key_path}") + print("Run setup_actor.py first.") + sys.exit(1) + + private_key = load_private_key(private_key_path) + print(f"Loaded private key: {private_key_path}") + + # Load registry + registry_path = Path.home() / "artdag-registry" / "registry.json" + with open(registry_path) as f: + registry = json.load(f) + + # Create signed activities for each asset + activities = [] + + for asset_name, asset_data in registry["assets"].items(): + print(f"\nSigning: {asset_name}") + print(f" Hash: {asset_data['cid'][:16]}...") + + activity = create_activity( + actor_id=actor_id, + asset_name=asset_name, + cid=asset_data["cid"], + asset_type=asset_data["asset_type"], + domain=domain, + ) + + signed_activity = sign_activity(activity, private_key, actor_id, domain) + activities.append(signed_activity) + + print(f" Activity ID: {signed_activity['activity_id']}") + print(f" Signature: {signed_activity['signature']['signatureValue'][:32]}...") + + # Save activities + activities_path = Path.home() / "artdag-registry" / "activities.json" + activities_data = { + "version": "1.0", + "activities": activities + } + + with open(activities_path, "w") as f: + json.dump(activities_data, f, indent=2) + + print(f"\nSaved {len(activities)} signed activities to: {activities_path}") + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..f6aed20 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ 
+# Tests for new standalone primitive engine diff --git a/tests/test_activities.py b/tests/test_activities.py new file mode 100644 index 0000000..36ba61d --- /dev/null +++ b/tests/test_activities.py @@ -0,0 +1,613 @@ +# tests/test_activities.py +"""Tests for the activity tracking and cache deletion system.""" + +import tempfile +import time +from pathlib import Path + +import pytest + +from artdag import Cache, DAG, Node, NodeType +from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn + + +class MockActivityPubStore: + """Mock ActivityPub store for testing is_shared functionality.""" + + def __init__(self): + self._shared_hashes = set() + + def mark_shared(self, cid: str): + """Mark a content hash as shared (published).""" + self._shared_hashes.add(cid) + + def find_by_object_hash(self, cid: str): + """Return mock activities for shared hashes.""" + if cid in self._shared_hashes: + return [MockActivity("Create")] + return [] + + +class MockActivity: + """Mock ActivityPub activity.""" + def __init__(self, activity_type: str): + self.activity_type = activity_type + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def cache(temp_dir): + """Create a cache instance.""" + return Cache(temp_dir / "cache") + + +@pytest.fixture +def activity_store(temp_dir): + """Create an activity store instance.""" + return ActivityStore(temp_dir / "activities") + + +@pytest.fixture +def ap_store(): + """Create a mock ActivityPub store.""" + return MockActivityPubStore() + + +@pytest.fixture +def manager(cache, activity_store, ap_store): + """Create an ActivityManager instance.""" + return ActivityManager( + cache=cache, + activity_store=activity_store, + is_shared_fn=make_is_shared_fn(ap_store), + ) + + +def create_test_file(path: Path, content: str = "test content") -> Path: + """Create a test file with content.""" + 
path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content) + return path + + +class TestCacheEntryContentHash: + """Tests for cid in CacheEntry.""" + + def test_put_computes_cid(self, cache, temp_dir): + """put() should compute and store cid.""" + test_file = create_test_file(temp_dir / "input.txt", "hello world") + + cache.put("node1", test_file, "test") + entry = cache.get_entry("node1") + + assert entry is not None + assert entry.cid != "" + assert len(entry.cid) == 64 # SHA-3-256 hex + + def test_same_content_same_hash(self, cache, temp_dir): + """Same file content should produce same hash.""" + file1 = create_test_file(temp_dir / "file1.txt", "identical content") + file2 = create_test_file(temp_dir / "file2.txt", "identical content") + + cache.put("node1", file1, "test") + cache.put("node2", file2, "test") + + entry1 = cache.get_entry("node1") + entry2 = cache.get_entry("node2") + + assert entry1.cid == entry2.cid + + def test_different_content_different_hash(self, cache, temp_dir): + """Different file content should produce different hash.""" + file1 = create_test_file(temp_dir / "file1.txt", "content A") + file2 = create_test_file(temp_dir / "file2.txt", "content B") + + cache.put("node1", file1, "test") + cache.put("node2", file2, "test") + + entry1 = cache.get_entry("node1") + entry2 = cache.get_entry("node2") + + assert entry1.cid != entry2.cid + + def test_find_by_cid(self, cache, temp_dir): + """Should find entry by content hash.""" + test_file = create_test_file(temp_dir / "input.txt", "unique content") + cache.put("node1", test_file, "test") + + entry = cache.get_entry("node1") + found = cache.find_by_cid(entry.cid) + + assert found is not None + assert found.node_id == "node1" + + def test_cid_persists(self, temp_dir): + """cid should persist across cache reloads.""" + cache1 = Cache(temp_dir / "cache") + test_file = create_test_file(temp_dir / "input.txt", "persistent") + cache1.put("node1", test_file, "test") + original_hash = 
cache1.get_entry("node1").cid + + # Create new cache instance (reload from disk) + cache2 = Cache(temp_dir / "cache") + entry = cache2.get_entry("node1") + + assert entry.cid == original_hash + + +class TestActivity: + """Tests for Activity dataclass.""" + + def test_activity_from_dag(self): + """Activity.from_dag() should classify nodes correctly.""" + # Build a simple DAG: source -> transform -> output + dag = DAG() + source = Node(NodeType.SOURCE, {"path": "/test.mp4"}) + transform = Node(NodeType.TRANSFORM, {"effect": "blur"}, inputs=[source.node_id]) + output = Node(NodeType.RESIZE, {"width": 100}, inputs=[transform.node_id]) + + dag.add_node(source) + dag.add_node(transform) + dag.add_node(output) + dag.set_output(output.node_id) + + activity = Activity.from_dag(dag) + + assert source.node_id in activity.input_ids + assert activity.output_id == output.node_id + assert transform.node_id in activity.intermediate_ids + + def test_activity_with_multiple_inputs(self): + """Activity should handle DAGs with multiple source nodes.""" + dag = DAG() + source1 = Node(NodeType.SOURCE, {"path": "/a.mp4"}) + source2 = Node(NodeType.SOURCE, {"path": "/b.mp4"}) + sequence = Node(NodeType.SEQUENCE, {}, inputs=[source1.node_id, source2.node_id]) + + dag.add_node(source1) + dag.add_node(source2) + dag.add_node(sequence) + dag.set_output(sequence.node_id) + + activity = Activity.from_dag(dag) + + assert len(activity.input_ids) == 2 + assert source1.node_id in activity.input_ids + assert source2.node_id in activity.input_ids + assert activity.output_id == sequence.node_id + assert len(activity.intermediate_ids) == 0 + + def test_activity_serialization(self): + """Activity should serialize and deserialize correctly.""" + dag = DAG() + source = Node(NodeType.SOURCE, {"path": "/test.mp4"}) + dag.add_node(source) + dag.set_output(source.node_id) + + activity = Activity.from_dag(dag) + data = activity.to_dict() + restored = Activity.from_dict(data) + + assert restored.activity_id == 
activity.activity_id + assert restored.input_ids == activity.input_ids + assert restored.output_id == activity.output_id + assert restored.intermediate_ids == activity.intermediate_ids + + def test_all_node_ids(self): + """all_node_ids should return all nodes.""" + activity = Activity( + activity_id="test", + input_ids=["a", "b"], + output_id="c", + intermediate_ids=["d", "e"], + created_at=time.time(), + ) + + all_ids = activity.all_node_ids + assert set(all_ids) == {"a", "b", "c", "d", "e"} + + +class TestActivityStore: + """Tests for ActivityStore persistence.""" + + def test_add_and_get(self, activity_store): + """Should add and retrieve activities.""" + activity = Activity( + activity_id="test1", + input_ids=["input1"], + output_id="output1", + intermediate_ids=["inter1"], + created_at=time.time(), + ) + + activity_store.add(activity) + retrieved = activity_store.get("test1") + + assert retrieved is not None + assert retrieved.activity_id == "test1" + + def test_persistence(self, temp_dir): + """Activities should persist across store reloads.""" + store1 = ActivityStore(temp_dir / "activities") + activity = Activity( + activity_id="persist", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + store1.add(activity) + + # Reload + store2 = ActivityStore(temp_dir / "activities") + retrieved = store2.get("persist") + + assert retrieved is not None + assert retrieved.activity_id == "persist" + + def test_find_by_input_ids(self, activity_store): + """Should find activities with matching inputs.""" + activity1 = Activity( + activity_id="a1", + input_ids=["x", "y"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity2 = Activity( + activity_id="a2", + input_ids=["y", "x"], # Same inputs, different order + output_id="o2", + intermediate_ids=[], + created_at=time.time(), + ) + activity3 = Activity( + activity_id="a3", + input_ids=["z"], # Different inputs + output_id="o3", + intermediate_ids=[], + 
created_at=time.time(), + ) + + activity_store.add(activity1) + activity_store.add(activity2) + activity_store.add(activity3) + + found = activity_store.find_by_input_ids(["x", "y"]) + assert len(found) == 2 + assert {a.activity_id for a in found} == {"a1", "a2"} + + def test_find_using_node(self, activity_store): + """Should find activities referencing a node.""" + activity = Activity( + activity_id="a1", + input_ids=["input1"], + output_id="output1", + intermediate_ids=["inter1"], + created_at=time.time(), + ) + activity_store.add(activity) + + # Should find by input + found = activity_store.find_using_node("input1") + assert len(found) == 1 + + # Should find by intermediate + found = activity_store.find_using_node("inter1") + assert len(found) == 1 + + # Should find by output + found = activity_store.find_using_node("output1") + assert len(found) == 1 + + # Should not find unknown + found = activity_store.find_using_node("unknown") + assert len(found) == 0 + + def test_remove(self, activity_store): + """Should remove activities.""" + activity = Activity( + activity_id="to_remove", + input_ids=["i"], + output_id="o", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + assert activity_store.get("to_remove") is not None + + result = activity_store.remove("to_remove") + assert result is True + assert activity_store.get("to_remove") is None + + +class TestActivityManager: + """Tests for ActivityManager deletion rules.""" + + def test_can_delete_orphaned_entry(self, manager, cache, temp_dir): + """Orphaned entries (not in any activity) can be deleted.""" + test_file = create_test_file(temp_dir / "orphan.txt", "orphan") + cache.put("orphan_node", test_file, "test") + + assert manager.can_delete_cache_entry("orphan_node") is True + + def test_cannot_delete_shared_entry(self, manager, cache, temp_dir, ap_store): + """Shared entries (ActivityPub published) cannot be deleted.""" + test_file = create_test_file(temp_dir / "shared.txt", 
"shared content") + cache.put("shared_node", test_file, "test") + + # Mark as shared + entry = cache.get_entry("shared_node") + ap_store.mark_shared(entry.cid) + + assert manager.can_delete_cache_entry("shared_node") is False + + def test_cannot_delete_activity_input(self, manager, cache, activity_store, temp_dir): + """Activity inputs cannot be deleted.""" + test_file = create_test_file(temp_dir / "input.txt", "input") + cache.put("input_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("input_node") is False + + def test_cannot_delete_activity_output(self, manager, cache, activity_store, temp_dir): + """Activity outputs cannot be deleted.""" + test_file = create_test_file(temp_dir / "output.txt", "output") + cache.put("output_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("output_node") is False + + def test_can_delete_intermediate(self, manager, cache, activity_store, temp_dir): + """Intermediate entries can be deleted (they're reconstructible).""" + test_file = create_test_file(temp_dir / "inter.txt", "intermediate") + cache.put("inter_node", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input_node"], + output_id="output_node", + intermediate_ids=["inter_node"], + created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_delete_cache_entry("inter_node") is True + + def test_can_discard_activity_no_shared(self, manager, activity_store): + """Activity can be discarded if nothing is shared.""" + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=["m1"], + 
created_at=time.time(), + ) + activity_store.add(activity) + + assert manager.can_discard_activity("a1") is True + + def test_cannot_discard_activity_with_shared_output(self, manager, cache, activity_store, temp_dir, ap_store): + """Activity cannot be discarded if output is shared.""" + test_file = create_test_file(temp_dir / "output.txt", "output content") + cache.put("o1", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + # Mark output as shared + entry = cache.get_entry("o1") + ap_store.mark_shared(entry.cid) + + assert manager.can_discard_activity("a1") is False + + def test_cannot_discard_activity_with_shared_input(self, manager, cache, activity_store, temp_dir, ap_store): + """Activity cannot be discarded if input is shared.""" + test_file = create_test_file(temp_dir / "input.txt", "input content") + cache.put("i1", test_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + entry = cache.get_entry("i1") + ap_store.mark_shared(entry.cid) + + assert manager.can_discard_activity("a1") is False + + def test_discard_activity_deletes_intermediates(self, manager, cache, activity_store, temp_dir): + """Discarding activity should delete intermediate cache entries.""" + # Create cache entries + input_file = create_test_file(temp_dir / "input.txt", "input") + inter_file = create_test_file(temp_dir / "inter.txt", "intermediate") + output_file = create_test_file(temp_dir / "output.txt", "output") + + cache.put("i1", input_file, "test") + cache.put("m1", inter_file, "test") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=["i1"], + output_id="o1", + intermediate_ids=["m1"], + created_at=time.time(), + ) + activity_store.add(activity) + + # Discard + result = 
manager.discard_activity("a1") + + assert result is True + assert cache.has("m1") is False # Intermediate deleted + assert activity_store.get("a1") is None # Activity removed + + def test_discard_activity_deletes_orphaned_output(self, manager, cache, activity_store, temp_dir): + """Discarding activity should delete output if orphaned.""" + output_file = create_test_file(temp_dir / "output.txt", "output") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=[], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + manager.discard_activity("a1") + + assert cache.has("o1") is False # Orphaned output deleted + + def test_discard_activity_keeps_shared_output(self, manager, cache, activity_store, temp_dir, ap_store): + """Discarding should fail if output is shared.""" + output_file = create_test_file(temp_dir / "output.txt", "shared output") + cache.put("o1", output_file, "test") + + activity = Activity( + activity_id="a1", + input_ids=[], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity_store.add(activity) + + entry = cache.get_entry("o1") + ap_store.mark_shared(entry.cid) + + result = manager.discard_activity("a1") + + assert result is False # Cannot discard + assert cache.has("o1") is True # Output preserved + assert activity_store.get("a1") is not None # Activity preserved + + def test_discard_keeps_input_used_elsewhere(self, manager, cache, activity_store, temp_dir): + """Input used by another activity should not be deleted.""" + input_file = create_test_file(temp_dir / "input.txt", "shared input") + cache.put("shared_input", input_file, "test") + + activity1 = Activity( + activity_id="a1", + input_ids=["shared_input"], + output_id="o1", + intermediate_ids=[], + created_at=time.time(), + ) + activity2 = Activity( + activity_id="a2", + input_ids=["shared_input"], + output_id="o2", + intermediate_ids=[], + created_at=time.time(), + ) + 
activity_store.add(activity1) + activity_store.add(activity2) + + manager.discard_activity("a1") + + # Input still used by a2 + assert cache.has("shared_input") is True + + def test_get_deletable_entries(self, manager, cache, activity_store, temp_dir): + """Should list all deletable entries.""" + # Orphan (deletable) + orphan = create_test_file(temp_dir / "orphan.txt", "orphan") + cache.put("orphan", orphan, "test") + + # Intermediate (deletable) + inter = create_test_file(temp_dir / "inter.txt", "inter") + cache.put("inter", inter, "test") + + # Input (not deletable) + inp = create_test_file(temp_dir / "input.txt", "input") + cache.put("input", inp, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input"], + output_id="output", + intermediate_ids=["inter"], + created_at=time.time(), + ) + activity_store.add(activity) + + deletable = manager.get_deletable_entries() + deletable_ids = {e.node_id for e in deletable} + + assert "orphan" in deletable_ids + assert "inter" in deletable_ids + assert "input" not in deletable_ids + + def test_cleanup_intermediates(self, manager, cache, activity_store, temp_dir): + """cleanup_intermediates() should delete all intermediate entries.""" + inter1 = create_test_file(temp_dir / "i1.txt", "inter1") + inter2 = create_test_file(temp_dir / "i2.txt", "inter2") + cache.put("inter1", inter1, "test") + cache.put("inter2", inter2, "test") + + activity = Activity( + activity_id="a1", + input_ids=["input"], + output_id="output", + intermediate_ids=["inter1", "inter2"], + created_at=time.time(), + ) + activity_store.add(activity) + + deleted = manager.cleanup_intermediates() + + assert deleted == 2 + assert cache.has("inter1") is False + assert cache.has("inter2") is False + + +class TestMakeIsSharedFn: + """Tests for make_is_shared_fn factory.""" + + def test_returns_true_for_shared(self, ap_store): + """Should return True for shared content.""" + is_shared = make_is_shared_fn(ap_store) + ap_store.mark_shared("hash123") + + 
assert is_shared("hash123") is True + + def test_returns_false_for_not_shared(self, ap_store): + """Should return False for non-shared content.""" + is_shared = make_is_shared_fn(ap_store) + + assert is_shared("unknown_hash") is False diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..2aac235 --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,163 @@ +# tests/test_primitive_new/test_cache.py +"""Tests for primitive cache module.""" + +import pytest +import tempfile +from pathlib import Path + +from artdag.cache import Cache, CacheStats + + +@pytest.fixture +def cache_dir(): + """Create temporary cache directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def cache(cache_dir): + """Create cache instance.""" + return Cache(cache_dir) + + +@pytest.fixture +def sample_file(cache_dir): + """Create a sample file to cache.""" + file_path = cache_dir / "sample.txt" + file_path.write_text("test content") + return file_path + + +class TestCache: + """Test Cache class.""" + + def test_cache_creation(self, cache_dir): + """Test cache directory is created.""" + cache = Cache(cache_dir / "new_cache") + assert cache.cache_dir.exists() + + def test_cache_put_and_get(self, cache, sample_file): + """Test putting and getting from cache.""" + node_id = "abc123" + cached_path = cache.put(node_id, sample_file, "TEST") + + assert cached_path.exists() + assert cache.has(node_id) + + retrieved = cache.get(node_id) + assert retrieved == cached_path + + def test_cache_miss(self, cache): + """Test cache miss returns None.""" + result = cache.get("nonexistent") + assert result is None + + def test_cache_stats_hit_miss(self, cache, sample_file): + """Test cache hit/miss stats.""" + cache.put("abc123", sample_file, "TEST") + + # Miss + cache.get("nonexistent") + assert cache.stats.misses == 1 + + # Hit + cache.get("abc123") + assert cache.stats.hits == 1 + + assert cache.stats.hit_rate == 0.5 + + def 
test_cache_remove(self, cache, sample_file): + """Test removing from cache.""" + node_id = "abc123" + cache.put(node_id, sample_file, "TEST") + assert cache.has(node_id) + + cache.remove(node_id) + assert not cache.has(node_id) + + def test_cache_clear(self, cache, sample_file): + """Test clearing cache.""" + cache.put("node1", sample_file, "TEST") + cache.put("node2", sample_file, "TEST") + + assert cache.stats.total_entries == 2 + + cache.clear() + + assert cache.stats.total_entries == 0 + assert not cache.has("node1") + assert not cache.has("node2") + + def test_cache_preserves_extension(self, cache, cache_dir): + """Test that cache preserves file extension.""" + mp4_file = cache_dir / "video.mp4" + mp4_file.write_text("fake video") + + cached = cache.put("video_node", mp4_file, "SOURCE") + assert cached.suffix == ".mp4" + + def test_cache_list_entries(self, cache, sample_file): + """Test listing cache entries.""" + cache.put("node1", sample_file, "TYPE1") + cache.put("node2", sample_file, "TYPE2") + + entries = cache.list_entries() + assert len(entries) == 2 + + node_ids = {e.node_id for e in entries} + assert "node1" in node_ids + assert "node2" in node_ids + + def test_cache_persistence(self, cache_dir, sample_file): + """Test cache persists across instances.""" + # First instance + cache1 = Cache(cache_dir) + cache1.put("abc123", sample_file, "TEST") + + # Second instance loads from disk + cache2 = Cache(cache_dir) + assert cache2.has("abc123") + + def test_cache_prune_by_age(self, cache, sample_file): + """Test pruning by age.""" + import time + + cache.put("old_node", sample_file, "TEST") + + # Manually set old creation time + entry = cache._entries["old_node"] + entry.created_at = time.time() - 3600 # 1 hour ago + + removed = cache.prune(max_age_seconds=1800) # 30 minutes + + assert removed == 1 + assert not cache.has("old_node") + + def test_cache_output_path(self, cache): + """Test getting output path for node.""" + path = 
cache.get_output_path("abc123", ".mp4") + assert path.suffix == ".mp4" + assert "abc123" in str(path) + assert path.parent.exists() + + +class TestCacheStats: + """Test CacheStats class.""" + + def test_hit_rate_calculation(self): + """Test hit rate calculation.""" + stats = CacheStats() + + stats.record_hit() + stats.record_hit() + stats.record_miss() + + assert stats.hits == 2 + assert stats.misses == 1 + assert abs(stats.hit_rate - 0.666) < 0.01 + + def test_initial_hit_rate(self): + """Test hit rate with no requests.""" + stats = CacheStats() + assert stats.hit_rate == 0.0 diff --git a/tests/test_dag.py b/tests/test_dag.py new file mode 100644 index 0000000..48250c6 --- /dev/null +++ b/tests/test_dag.py @@ -0,0 +1,271 @@ +# tests/test_primitive_new/test_dag.py +"""Tests for primitive DAG data structures.""" + +import pytest +from artdag.dag import Node, NodeType, DAG, DAGBuilder + + +class TestNode: + """Test Node class.""" + + def test_node_creation(self): + """Test basic node creation.""" + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + assert node.node_type == NodeType.SOURCE + assert node.config == {"path": "/test.mp4"} + assert node.node_id is not None + + def test_node_id_is_content_addressed(self): + """Same content produces same node_id.""" + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + assert node1.node_id == node2.node_id + + def test_different_config_different_id(self): + """Different config produces different node_id.""" + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test1.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test2.mp4"}) + assert node1.node_id != node2.node_id + + def test_node_with_inputs(self): + """Node with inputs includes them in ID.""" + node1 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["abc123"]) + node2 = Node(node_type=NodeType.SEGMENT, 
config={"duration": 5}, inputs=["abc123"]) + node3 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["def456"]) + + assert node1.node_id == node2.node_id + assert node1.node_id != node3.node_id + + def test_node_serialization(self): + """Test node to_dict and from_dict.""" + original = Node( + node_type=NodeType.SEGMENT, + config={"duration": 5.0, "offset": 10.0}, + inputs=["abc123"], + name="my_segment", + ) + data = original.to_dict() + restored = Node.from_dict(data) + + assert restored.node_type == original.node_type + assert restored.config == original.config + assert restored.inputs == original.inputs + assert restored.name == original.name + assert restored.node_id == original.node_id + + def test_custom_node_type(self): + """Test node with custom string type.""" + node = Node(node_type="CUSTOM_TYPE", config={"custom": True}) + assert node.node_type == "CUSTOM_TYPE" + assert node.node_id is not None + + +class TestDAG: + """Test DAG class.""" + + def test_dag_creation(self): + """Test basic DAG creation.""" + dag = DAG() + assert len(dag.nodes) == 0 + assert dag.output_id is None + + def test_add_node(self): + """Test adding nodes to DAG.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + + assert node_id in dag.nodes + assert dag.nodes[node_id] == node + + def test_node_deduplication(self): + """Same node added twice returns same ID.""" + dag = DAG() + node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + + id1 = dag.add_node(node1) + id2 = dag.add_node(node2) + + assert id1 == id2 + assert len(dag.nodes) == 1 + + def test_set_output(self): + """Test setting output node.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + assert dag.output_id == node_id + + def 
test_set_output_invalid(self): + """Setting invalid output raises error.""" + dag = DAG() + with pytest.raises(ValueError): + dag.set_output("nonexistent") + + def test_topological_order(self): + """Test topological ordering.""" + dag = DAG() + + # Create simple chain: source -> segment -> output + source = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + source_id = dag.add_node(source) + + segment = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=[source_id]) + segment_id = dag.add_node(segment) + + dag.set_output(segment_id) + order = dag.topological_order() + + # Source must come before segment + assert order.index(source_id) < order.index(segment_id) + + def test_validate_valid_dag(self): + """Test validation of valid DAG.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + errors = dag.validate() + assert len(errors) == 0 + + def test_validate_no_output(self): + """DAG without output is invalid.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + dag.add_node(node) + + errors = dag.validate() + assert len(errors) > 0 + assert any("output" in e.lower() for e in errors) + + def test_validate_missing_input(self): + """DAG with missing input reference is invalid.""" + dag = DAG() + node = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["nonexistent"]) + node_id = dag.add_node(node) + dag.set_output(node_id) + + errors = dag.validate() + assert len(errors) > 0 + assert any("missing" in e.lower() for e in errors) + + def test_dag_serialization(self): + """Test DAG to_dict and from_dict.""" + dag = DAG(metadata={"name": "test_dag"}) + source = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + source_id = dag.add_node(source) + dag.set_output(source_id) + + data = dag.to_dict() + restored = DAG.from_dict(data) + + assert len(restored.nodes) == len(dag.nodes) + assert 
restored.output_id == dag.output_id + assert restored.metadata == dag.metadata + + def test_dag_json(self): + """Test DAG JSON serialization.""" + dag = DAG() + node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + json_str = dag.to_json() + restored = DAG.from_json(json_str) + + assert len(restored.nodes) == 1 + assert restored.output_id == node_id + + +class TestDAGBuilder: + """Test DAGBuilder class.""" + + def test_builder_source(self): + """Test building source node.""" + builder = DAGBuilder() + source_id = builder.source("/test.mp4") + + assert source_id in builder.dag.nodes + node = builder.dag.nodes[source_id] + assert node.node_type == NodeType.SOURCE + assert node.config["path"] == "/test.mp4" + + def test_builder_segment(self): + """Test building segment node.""" + builder = DAGBuilder() + source_id = builder.source("/test.mp4") + segment_id = builder.segment(source_id, duration=5.0, offset=10.0) + + node = builder.dag.nodes[segment_id] + assert node.node_type == NodeType.SEGMENT + assert node.config["duration"] == 5.0 + assert node.config["offset"] == 10.0 + assert source_id in node.inputs + + def test_builder_chain(self): + """Test building a chain of nodes.""" + builder = DAGBuilder() + source = builder.source("/test.mp4") + segment = builder.segment(source, duration=5.0) + resized = builder.resize(segment, width=1920, height=1080) + builder.set_output(resized) + + dag = builder.build() + + assert len(dag.nodes) == 3 + assert dag.output_id == resized + errors = dag.validate() + assert len(errors) == 0 + + def test_builder_sequence(self): + """Test building sequence node.""" + builder = DAGBuilder() + s1 = builder.source("/clip1.mp4") + s2 = builder.source("/clip2.mp4") + seq = builder.sequence([s1, s2], transition={"type": "crossfade", "duration": 0.5}) + builder.set_output(seq) + + dag = builder.build() + node = dag.nodes[seq] + assert node.node_type == NodeType.SEQUENCE + 
assert s1 in node.inputs + assert s2 in node.inputs + + def test_builder_mux(self): + """Test building mux node.""" + builder = DAGBuilder() + video = builder.source("/video.mp4") + audio = builder.source("/audio.mp3") + muxed = builder.mux(video, audio) + builder.set_output(muxed) + + dag = builder.build() + node = dag.nodes[muxed] + assert node.node_type == NodeType.MUX + assert video in node.inputs + assert audio in node.inputs + + def test_builder_transform(self): + """Test building transform node.""" + builder = DAGBuilder() + source = builder.source("/test.mp4") + transformed = builder.transform(source, effects={"saturation": 1.5, "contrast": 1.2}) + builder.set_output(transformed) + + dag = builder.build() + node = dag.nodes[transformed] + assert node.node_type == NodeType.TRANSFORM + assert node.config["effects"]["saturation"] == 1.5 + + def test_builder_validation_fails(self): + """Builder raises error for invalid DAG.""" + builder = DAGBuilder() + builder.source("/test.mp4") + # No output set + + with pytest.raises(ValueError): + builder.build() diff --git a/tests/test_engine.py b/tests/test_engine.py new file mode 100644 index 0000000..b6e5a95 --- /dev/null +++ b/tests/test_engine.py @@ -0,0 +1,464 @@ +# tests/test_primitive_new/test_engine.py +"""Tests for primitive engine execution.""" + +import pytest +import subprocess +import tempfile +from pathlib import Path + +from artdag.dag import DAG, DAGBuilder, Node, NodeType +from artdag.engine import Engine +from artdag import nodes # Register executors + + +@pytest.fixture +def cache_dir(): + """Create temporary cache directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def engine(cache_dir): + """Create engine instance.""" + return Engine(cache_dir) + + +@pytest.fixture +def test_video(cache_dir): + """Create a test video file.""" + video_path = cache_dir / "test_video.mp4" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", 
"testsrc=duration=5:size=320x240:rate=30", + "-f", "lavfi", "-i", "sine=frequency=440:duration=5", + "-c:v", "libx264", "-preset", "ultrafast", + "-c:a", "aac", + str(video_path) + ] + subprocess.run(cmd, capture_output=True, check=True) + return video_path + + +@pytest.fixture +def test_audio(cache_dir): + """Create a test audio file.""" + audio_path = cache_dir / "test_audio.mp3" + cmd = [ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=5", + "-c:a", "libmp3lame", + str(audio_path) + ] + subprocess.run(cmd, capture_output=True, check=True) + return audio_path + + +class TestEngineBasic: + """Test basic engine functionality.""" + + def test_engine_creation(self, cache_dir): + """Test engine creation.""" + engine = Engine(cache_dir) + assert engine.cache is not None + + def test_invalid_dag(self, engine): + """Test executing invalid DAG.""" + dag = DAG() # No nodes, no output + result = engine.execute(dag) + + assert not result.success + assert "Invalid DAG" in result.error + + def test_missing_executor(self, engine): + """Test executing node with missing executor.""" + dag = DAG() + node = Node(node_type="UNKNOWN_TYPE", config={}) + node_id = dag.add_node(node) + dag.set_output(node_id) + + result = engine.execute(dag) + + assert not result.success + assert "No executor" in result.error + + +class TestSourceExecutor: + """Test SOURCE node executor.""" + + def test_source_creates_symlink(self, engine, test_video): + """Test source node creates symlink.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + assert result.output_path.is_symlink() + + def test_source_missing_file(self, engine): + """Test source with missing file.""" + builder = DAGBuilder() + source = builder.source("/nonexistent/file.mp4") + builder.set_output(source) + dag = builder.build() + + result = 
engine.execute(dag) + + assert not result.success + assert "not found" in result.error.lower() + + +class TestSegmentExecutor: + """Test SEGMENT node executor.""" + + def test_segment_duration(self, engine, test_video): + """Test segment extracts correct duration.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + segment = builder.segment(source, duration=2.0) + builder.set_output(segment) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify duration + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + assert abs(duration - 2.0) < 0.1 + + def test_segment_with_offset(self, engine, test_video): + """Test segment with offset.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + segment = builder.segment(source, offset=1.0, duration=2.0) + builder.set_output(segment) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + + +class TestResizeExecutor: + """Test RESIZE node executor.""" + + def test_resize_dimensions(self, engine, test_video): + """Test resize to specific dimensions.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + resized = builder.resize(source, width=640, height=480, mode="fit") + builder.set_output(resized) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify dimensions + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "stream=width,height", + "-of", "csv=p=0:s=x", + str(result.output_path) + ], capture_output=True, text=True) + dimensions = probe.stdout.strip().split("\n")[0] + assert "640x480" in dimensions + + +class TestTransformExecutor: + """Test TRANSFORM node executor.""" + + def test_transform_saturation(self, engine, test_video): + """Test transform with saturation effect.""" 
+ builder = DAGBuilder() + source = builder.source(str(test_video)) + transformed = builder.transform(source, effects={"saturation": 1.5}) + builder.set_output(transformed) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + assert result.output_path.exists() + + def test_transform_multiple_effects(self, engine, test_video): + """Test transform with multiple effects.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + transformed = builder.transform(source, effects={ + "saturation": 1.2, + "contrast": 1.1, + "brightness": 0.05, + }) + builder.set_output(transformed) + dag = builder.build() + + result = engine.execute(dag) + assert result.success + + +class TestSequenceExecutor: + """Test SEQUENCE node executor.""" + + def test_sequence_cut(self, engine, test_video): + """Test sequence with cut transition.""" + builder = DAGBuilder() + s1 = builder.source(str(test_video)) + seg1 = builder.segment(s1, duration=2.0) + seg2 = builder.segment(s1, offset=2.0, duration=2.0) + seq = builder.sequence([seg1, seg2], transition={"type": "cut"}) + builder.set_output(seq) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Verify combined duration + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + assert abs(duration - 4.0) < 0.2 + + def test_sequence_crossfade(self, engine, test_video): + """Test sequence with crossfade transition.""" + builder = DAGBuilder() + s1 = builder.source(str(test_video)) + seg1 = builder.segment(s1, duration=3.0) + seg2 = builder.segment(s1, offset=1.0, duration=3.0) + seq = builder.sequence([seg1, seg2], transition={"type": "crossfade", "duration": 0.5}) + builder.set_output(seq) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + + # Duration should be sum minus 
crossfade + probe = subprocess.run([ + "ffprobe", "-v", "error", + "-show_entries", "format=duration", + "-of", "csv=p=0", + str(result.output_path) + ], capture_output=True, text=True) + duration = float(probe.stdout.strip()) + # 3 + 3 - 0.5 = 5.5 + assert abs(duration - 5.5) < 0.3 + + +class TestMuxExecutor: + """Test MUX node executor.""" + + def test_mux_video_audio(self, engine, test_video, test_audio): + """Test muxing video and audio.""" + builder = DAGBuilder() + video = builder.source(str(test_video)) + audio = builder.source(str(test_audio)) + muxed = builder.mux(video, audio) + builder.set_output(muxed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + +class TestAudioMixExecutor: + """Test AUDIO_MIX node executor.""" + + def test_audio_mix_simple(self, engine, cache_dir): + """Test simple audio mixing.""" + # Create two test audio files with different frequencies + audio1_path = cache_dir / "audio1.mp3" + audio2_path = cache_dir / "audio2.mp3" + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=440:duration=3", + "-c:a", "libmp3lame", + str(audio1_path) + ], capture_output=True, check=True) + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=3", + "-c:a", "libmp3lame", + str(audio2_path) + ], capture_output=True, check=True) + + builder = DAGBuilder() + a1 = builder.source(str(audio1_path)) + a2 = builder.source(str(audio2_path)) + mixed = builder.audio_mix([a1, a2]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + def test_audio_mix_with_gains(self, engine, cache_dir): + """Test audio mixing with custom gains.""" + audio1_path = cache_dir / "audio1.mp3" + audio2_path = cache_dir / "audio2.mp3" + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=440:duration=3", + "-c:a", "libmp3lame", + 
str(audio1_path) + ], capture_output=True, check=True) + + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", "sine=frequency=880:duration=3", + "-c:a", "libmp3lame", + str(audio2_path) + ], capture_output=True, check=True) + + builder = DAGBuilder() + a1 = builder.source(str(audio1_path)) + a2 = builder.source(str(audio2_path)) + mixed = builder.audio_mix([a1, a2], gains=[1.0, 0.3]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + def test_audio_mix_three_inputs(self, engine, cache_dir): + """Test mixing three audio sources.""" + audio_paths = [] + for i, freq in enumerate([440, 660, 880]): + path = cache_dir / f"audio{i}.mp3" + subprocess.run([ + "ffmpeg", "-y", + "-f", "lavfi", "-i", f"sine=frequency={freq}:duration=2", + "-c:a", "libmp3lame", + str(path) + ], capture_output=True, check=True) + audio_paths.append(path) + + builder = DAGBuilder() + sources = [builder.source(str(p)) for p in audio_paths] + mixed = builder.audio_mix(sources, gains=[1.0, 0.5, 0.3]) + builder.set_output(mixed) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + + +class TestCaching: + """Test engine caching behavior.""" + + def test_cache_reuse(self, engine, test_video): + """Test that cached results are reused.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + # First execution + result1 = engine.execute(dag) + assert result1.success + assert result1.nodes_cached == 0 + assert result1.nodes_executed == 1 + + # Second execution should use cache + result2 = engine.execute(dag) + assert result2.success + assert result2.nodes_cached == 1 + assert result2.nodes_executed == 0 + + def test_clear_cache(self, engine, test_video): + """Test clearing cache.""" + builder = DAGBuilder() + source = builder.source(str(test_video)) + 
builder.set_output(source) + dag = builder.build() + + engine.execute(dag) + assert engine.cache.stats.total_entries == 1 + + engine.clear_cache() + assert engine.cache.stats.total_entries == 0 + + +class TestProgressCallback: + """Test progress callback functionality.""" + + def test_progress_callback(self, engine, test_video): + """Test that progress callback is called.""" + progress_updates = [] + + def callback(progress): + progress_updates.append((progress.node_id, progress.status)) + + engine.set_progress_callback(callback) + + builder = DAGBuilder() + source = builder.source(str(test_video)) + builder.set_output(source) + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert len(progress_updates) > 0 + # Should have pending, running, completed + statuses = [p[1] for p in progress_updates] + assert "pending" in statuses + assert "completed" in statuses + + +class TestFullWorkflow: + """Test complete workflow.""" + + def test_full_pipeline(self, engine, test_video, test_audio): + """Test complete video processing pipeline.""" + builder = DAGBuilder() + + # Load sources + video = builder.source(str(test_video)) + audio = builder.source(str(test_audio)) + + # Extract segment + segment = builder.segment(video, duration=3.0) + + # Resize + resized = builder.resize(segment, width=640, height=480) + + # Apply effects + transformed = builder.transform(resized, effects={"saturation": 1.3}) + + # Mux with audio + final = builder.mux(transformed, audio) + builder.set_output(final) + + dag = builder.build() + + result = engine.execute(dag) + + assert result.success + assert result.output_path.exists() + assert result.nodes_executed == 6 # source, source, segment, resize, transform, mux diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..5149554 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,110 @@ +# tests/test_primitive_new/test_executor.py +"""Tests for primitive executor module.""" 
+ +import pytest +from pathlib import Path +from typing import Any, Dict, List + +from artdag.dag import NodeType +from artdag.executor import ( + Executor, + register_executor, + get_executor, + list_executors, + clear_executors, +) + + +class TestExecutorRegistry: + """Test executor registration.""" + + def setup_method(self): + """Clear registry before each test.""" + clear_executors() + + def teardown_method(self): + """Clear registry after each test.""" + clear_executors() + + def test_register_executor(self): + """Test registering an executor.""" + @register_executor(NodeType.SOURCE) + class TestSourceExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = get_executor(NodeType.SOURCE) + assert executor is not None + assert isinstance(executor, TestSourceExecutor) + + def test_register_custom_type(self): + """Test registering executor for custom type.""" + @register_executor("CUSTOM_NODE") + class CustomExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = get_executor("CUSTOM_NODE") + assert executor is not None + + def test_get_unregistered(self): + """Test getting unregistered executor.""" + executor = get_executor(NodeType.ANALYZE) + assert executor is None + + def test_list_executors(self): + """Test listing registered executors.""" + @register_executor(NodeType.SOURCE) + class SourceExec(Executor): + def execute(self, config, inputs, output_path): + return output_path + + @register_executor(NodeType.SEGMENT) + class SegmentExec(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executors = list_executors() + assert "SOURCE" in executors + assert "SEGMENT" in executors + + def test_overwrite_warning(self, caplog): + """Test warning when overwriting executor.""" + @register_executor(NodeType.SOURCE) + class FirstExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + # Register 
again - should warn + @register_executor(NodeType.SOURCE) + class SecondExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + # Second should be registered + executor = get_executor(NodeType.SOURCE) + assert isinstance(executor, SecondExecutor) + + +class TestExecutorBase: + """Test Executor base class.""" + + def test_validate_config_default(self): + """Test default validate_config returns empty list.""" + class TestExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = TestExecutor() + errors = executor.validate_config({"any": "config"}) + assert errors == [] + + def test_estimate_output_size(self): + """Test default output size estimation.""" + class TestExecutor(Executor): + def execute(self, config, inputs, output_path): + return output_path + + executor = TestExecutor() + size = executor.estimate_output_size({}, [100, 200, 300]) + assert size == 600 diff --git a/tests/test_ipfs_access.py b/tests/test_ipfs_access.py new file mode 100644 index 0000000..33795cb --- /dev/null +++ b/tests/test_ipfs_access.py @@ -0,0 +1,301 @@ +""" +Tests for IPFS access consistency. + +All IPFS access should use IPFS_API (multiaddr format) for consistency +with art-celery's ipfs_client.py. This ensures Docker deployments work +correctly since IPFS_API is set to /dns/ipfs/tcp/5001. +""" + +import os +import re +from pathlib import Path +from typing import Optional +from unittest.mock import patch, MagicMock + +import pytest + + +def multiaddr_to_url(multiaddr: str) -> str: + """ + Convert IPFS multiaddr to HTTP URL. + + This is the canonical conversion used by ipfs_client.py. 
+ """ + # Handle /dns/hostname/tcp/port format + dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr) + if dns_match: + return f"http://{dns_match.group(1)}:{dns_match.group(2)}" + + # Handle /ip4/address/tcp/port format + ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr) + if ip4_match: + return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}" + + # Fallback: assume it's already a URL or use default + if multiaddr.startswith("http"): + return multiaddr + return "http://127.0.0.1:5001" + + +class TestMultiaddrConversion: + """Tests for multiaddr to URL conversion.""" + + def test_dns_format(self) -> None: + """Docker DNS format should convert correctly.""" + result = multiaddr_to_url("/dns/ipfs/tcp/5001") + assert result == "http://ipfs:5001" + + def test_dns4_format(self) -> None: + """dns4 format should work.""" + result = multiaddr_to_url("/dns4/ipfs.example.com/tcp/5001") + assert result == "http://ipfs.example.com:5001" + + def test_ip4_format(self) -> None: + """IPv4 format should convert correctly.""" + result = multiaddr_to_url("/ip4/127.0.0.1/tcp/5001") + assert result == "http://127.0.0.1:5001" + + def test_already_url(self) -> None: + """HTTP URLs should pass through.""" + result = multiaddr_to_url("http://localhost:5001") + assert result == "http://localhost:5001" + + def test_fallback(self) -> None: + """Unknown format should fallback to localhost.""" + result = multiaddr_to_url("garbage") + assert result == "http://127.0.0.1:5001" + + +class TestIPFSConfigConsistency: + """ + Tests to ensure IPFS configuration is consistent. + + The effect executor should use IPFS_API (like ipfs_client.py) + rather than a separate IPFS_GATEWAY variable. + """ + + def test_effect_module_should_not_use_gateway_var(self) -> None: + """ + Regression test: Effect module should use IPFS_API, not IPFS_GATEWAY. + + Bug found 2026-01-12: artdag/nodes/effect.py used IPFS_GATEWAY which + defaulted to http://127.0.0.1:8080. 
This doesn't work in Docker where + the IPFS node is a separate container. The ipfs_client.py uses IPFS_API + which is correctly set in docker-compose. + """ + from artdag.nodes import effect + + # Check if the module still has the old IPFS_GATEWAY variable + # After the fix, this should use IPFS_API instead + has_gateway_var = hasattr(effect, 'IPFS_GATEWAY') + has_api_var = hasattr(effect, 'IPFS_API') or hasattr(effect, '_get_ipfs_base_url') + + # This test documents the current buggy state + # After fix: has_gateway_var should be False, has_api_var should be True + if has_gateway_var and not has_api_var: + pytest.fail( + "Effect module uses IPFS_GATEWAY instead of IPFS_API. " + "This breaks Docker deployments where IPFS_API=/dns/ipfs/tcp/5001 " + "but IPFS_GATEWAY defaults to localhost." + ) + + def test_ipfs_api_default_is_localhost(self) -> None: + """IPFS_API should default to localhost for local development.""" + default_api = "/ip4/127.0.0.1/tcp/5001" + url = multiaddr_to_url(default_api) + assert "127.0.0.1" in url + assert "5001" in url + + def test_docker_ipfs_api_uses_service_name(self) -> None: + """In Docker, IPFS_API should use the service name.""" + docker_api = "/dns/ipfs/tcp/5001" + url = multiaddr_to_url(docker_api) + assert url == "http://ipfs:5001" + assert "127.0.0.1" not in url + + +class TestEffectFetchURL: + """Tests for the URL used to fetch effects from IPFS.""" + + def test_fetch_should_use_api_cat_endpoint(self) -> None: + """ + Effect fetch should use /api/v0/cat endpoint (like ipfs_client.py). + + The IPFS API's cat endpoint works reliably in Docker. + The gateway endpoint (port 8080) requires separate configuration. 
+ """ + # The correct way to fetch via API + base_url = "http://ipfs:5001" + cid = "QmTestCid123" + correct_url = f"{base_url}/api/v0/cat?arg={cid}" + + assert "/api/v0/cat" in correct_url + assert "arg=" in correct_url + + def test_gateway_url_is_different_from_api(self) -> None: + """ + Document the difference between gateway and API URLs. + + Gateway: http://ipfs:8080/ipfs/{cid} (requires IPFS_GATEWAY config) + API: http://ipfs:5001/api/v0/cat?arg={cid} (uses IPFS_API config) + + Using the API is more reliable since IPFS_API is already configured + correctly in docker-compose.yml. + """ + cid = "QmTestCid123" + + # Gateway style (the old broken way) + gateway_url = f"http://ipfs:8080/ipfs/{cid}" + + # API style (the correct way) + api_url = f"http://ipfs:5001/api/v0/cat?arg={cid}" + + # These are different approaches + assert gateway_url != api_url + assert ":8080" in gateway_url + assert ":5001" in api_url + + +class TestEffectDependencies: + """Tests for effect dependency handling.""" + + def test_parse_pep723_dependencies(self) -> None: + """Should parse PEP 723 dependencies from effect source.""" + source = ''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["numpy", "opencv-python"] +# /// +""" +@effect test_effect +""" + +def process_frame(frame, params, state): + return frame, state +''' + # Import the function after the fix is applied + from artdag.nodes.effect import _parse_pep723_dependencies + + deps = _parse_pep723_dependencies(source) + + assert deps == ["numpy", "opencv-python"] + + def test_parse_pep723_no_dependencies(self) -> None: + """Should return empty list if no dependencies block.""" + source = ''' +""" +@effect simple_effect +""" + +def process_frame(frame, params, state): + return frame, state +''' + from artdag.nodes.effect import _parse_pep723_dependencies + + deps = _parse_pep723_dependencies(source) + + assert deps == [] + + def test_ensure_dependencies_already_installed(self) -> None: + """Should return True if 
dependencies are already installed.""" + from artdag.nodes.effect import _ensure_dependencies + + # os is always available + result = _ensure_dependencies(["os"], "QmTest123") + + assert result is True + + def test_effect_with_missing_dependency_gives_clear_error(self, tmp_path: Path) -> None: + """ + Regression test: Missing dependencies should give clear error message. + + Bug found 2026-01-12: Effect with numpy dependency failed with + "No module named 'numpy'" but this was swallowed and reported as + "Unknown effect: invert" - very confusing. + """ + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffectWithDeps" + + # Create effect that imports a non-existent module + effect_dir = effects_dir / effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +# /// script +# requires-python = ">=3.10" +# dependencies = ["some_nonexistent_package_xyz"] +# /// +""" +@effect test_effect +""" +import some_nonexistent_package_xyz + +def process_frame(frame, params, state): + return frame, state +''') + + # The effect file exists + effect_path = effects_dir / effect_cid / "effect.py" + assert effect_path.exists() + + # When loading fails due to missing import, error should mention the dependency + with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}): + from artdag.nodes.effect import _load_cached_effect + + # This should return None but log a clear error about the missing module + result = _load_cached_effect(effect_cid) + + # Currently returns None, which causes "Unknown effect" error + # The real issue is the dependency isn't installed + assert result is None + + +class TestEffectCacheAndFetch: + """Integration tests for effect caching and fetching.""" + + def test_effect_loads_from_cache_without_ipfs(self, tmp_path: Path) -> None: + """When effect is in cache, IPFS should not be contacted.""" + effects_dir = tmp_path / "_effects" + effect_cid = "QmTestEffect123" + + # Create cached effect + effect_dir = effects_dir / 
effect_cid + effect_dir.mkdir(parents=True) + (effect_dir / "effect.py").write_text(''' +def process_frame(frame, params, state): + return frame, state +''') + + # Patch environment and verify effect can be loaded + with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}): + from artdag.nodes.effect import _load_cached_effect + + # Should load without hitting IPFS + effect_fn = _load_cached_effect(effect_cid) + assert effect_fn is not None + + def test_effect_fetch_uses_correct_endpoint(self, tmp_path: Path) -> None: + """When fetching from IPFS, should use API endpoint.""" + effects_dir = tmp_path / "_effects" + effects_dir.mkdir(parents=True) + effect_cid = "QmNonExistentEffect" + + with patch.dict(os.environ, { + "CACHE_DIR": str(tmp_path), + "IPFS_API": "/dns/ipfs/tcp/5001" + }): + with patch('requests.post') as mock_post: + # Set up mock to return effect source + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'def process_frame(f, p, s): return f, s' + mock_post.return_value = mock_response + + from artdag.nodes.effect import _load_cached_effect + + # Try to load - should attempt IPFS fetch + _load_cached_effect(effect_cid) + + # After fix, this should use the API endpoint + # Check if requests.post was called (API style) + # or requests.get was called (gateway style) + # The fix should make it use POST to /api/v0/cat