Import core (art-dag) as core/
This commit is contained in:
47
core/.gitignore
vendored
Normal file
47
core/.gitignore
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
|
||||
# Local cache (may hold ActivityPub private keys / secrets)
|
||||
.cache/
|
||||
|
||||
# Test outputs
|
||||
test_cache/
|
||||
test_plan_output.json
|
||||
analysis.json
|
||||
plan.json
|
||||
plan_with_analysis.json
|
||||
110
core/README.md
Normal file
110
core/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
# artdag
|
||||
|
||||
Content-addressed DAG execution engine with ActivityPub ownership.
|
||||
|
||||
## Features
|
||||
|
||||
- **Content-addressed nodes**: `node_id = SHA3-256(type + config + inputs)` for automatic deduplication
|
||||
- **Quantum-resistant hashing**: SHA-3 throughout for future-proof integrity
|
||||
- **ActivityPub ownership**: Cryptographically signed ownership claims
|
||||
- **Federated identity**: `@user@artdag.rose-ash.com` style identities
|
||||
- **Pluggable executors**: Register custom node types
|
||||
- **Built-in video primitives**: SOURCE, SEGMENT, RESIZE, TRANSFORM, SEQUENCE, MUX, BLEND
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Optional: External Effect Tools
|
||||
|
||||
Some effects can use external tools for better performance:
|
||||
|
||||
**Pixelsort** (glitch art pixel sorting):
|
||||
```bash
|
||||
# Rust CLI (recommended - fast)
|
||||
cargo install --git https://github.com/Void-ux/pixelsort.git pixelsort
|
||||
|
||||
# Or Python CLI
|
||||
pip install git+https://github.com/Blotz/pixelsort-cli
|
||||
```
|
||||
|
||||
**Datamosh** (video glitch/corruption):
|
||||
```bash
|
||||
# FFglitch (recommended)
|
||||
./scripts/install-ffglitch.sh
|
||||
|
||||
# Or Python CLI
|
||||
pip install git+https://github.com/tiberiuiancu/datamoshing
|
||||
```
|
||||
|
||||
Check available tools:
|
||||
```bash
|
||||
python -m artdag.sexp.external_tools
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
from artdag import Engine, DAGBuilder, Registry
|
||||
from artdag.activitypub import OwnershipManager
|
||||
|
||||
# Create ownership manager
|
||||
manager = OwnershipManager("./my_registry")
|
||||
|
||||
# Create your identity
|
||||
actor = manager.create_actor("alice", "Alice")
|
||||
print(f"Created: {actor.handle}") # @alice@artdag.rose-ash.com
|
||||
|
||||
# Register an asset with ownership
|
||||
asset, activity = manager.register_asset(
|
||||
actor=actor,
|
||||
name="my_image",
|
||||
path="/path/to/image.jpg",
|
||||
tags=["photo", "art"],
|
||||
)
|
||||
print(f"Owned: {asset.name} (hash: {asset.content_hash})")
|
||||
|
||||
# Build and execute a DAG
|
||||
engine = Engine("./cache")
|
||||
builder = DAGBuilder()
|
||||
|
||||
source = builder.source(str(asset.path))
|
||||
resized = builder.resize(source, width=1920, height=1080)
|
||||
builder.set_output(resized)
|
||||
|
||||
result = engine.execute(builder.build())
|
||||
print(f"Output: {result.output_path}")
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
artdag/
|
||||
├── dag.py # Node, DAG, DAGBuilder
|
||||
├── cache.py # Content-addressed file cache
|
||||
├── executor.py # Base executor + registry
|
||||
├── engine.py # DAG execution engine
|
||||
├── activitypub/ # Identity + ownership
|
||||
│ ├── actor.py # Actor identity with RSA keys
|
||||
│ ├── activity.py # Create, Announce activities
|
||||
│ ├── signatures.py # RSA signing/verification
|
||||
│ └── ownership.py # Links actors to assets
|
||||
├── nodes/ # Built-in executors
|
||||
│ ├── source.py # SOURCE
|
||||
│ ├── transform.py # SEGMENT, RESIZE, TRANSFORM
|
||||
│ ├── compose.py # SEQUENCE, LAYER, MUX, BLEND
|
||||
│ └── effect.py # EFFECT (identity, etc.)
|
||||
└── effects/ # Effect implementations
|
||||
└── identity/ # The foundational identity effect
|
||||
```
|
||||
|
||||
## Related Repos
|
||||
|
||||
- **Registry**: https://git.rose-ash.com/art-dag/registry - Asset registry with ownership proofs
|
||||
- **Recipes**: https://git.rose-ash.com/art-dag/recipes - DAG recipes using effects
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
61
core/artdag/__init__.py
Normal file
61
core/artdag/__init__.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# artdag - Content-addressed DAG execution engine with ActivityPub ownership
#
# A standalone execution engine that processes directed acyclic graphs (DAGs)
# where each node represents an operation. Nodes are content-addressed for
# automatic caching and deduplication.
#
# Core concepts:
# - Node: An operation with type, config, and inputs
# - DAG: A graph of nodes with a designated output node
# - Executor: Implements the actual operation for a node type
# - Engine: Executes DAGs by resolving dependencies and running executors

from .dag import Node, DAG, DAGBuilder, NodeType
from .cache import Cache, CacheEntry
from .executor import Executor, register_executor, get_executor
from .engine import Engine
from .registry import Registry, Asset
from .activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn

# Analysis and planning modules (optional, require extra dependencies)
# NOTE(review): a failed import silently degrades these names to None, so
# callers must check for None before use — confirm this is intended.
try:
    from .analysis import Analyzer, AnalysisResult
except ImportError:
    Analyzer = None
    AnalysisResult = None

try:
    from .planning import RecipePlanner, ExecutionPlan, ExecutionStep
except ImportError:
    RecipePlanner = None
    ExecutionPlan = None
    ExecutionStep = None

# Public API. The optional names below may be None when their extra
# dependencies are not installed (see the try/except fallbacks above).
__all__ = [
    # Core
    "Node",
    "DAG",
    "DAGBuilder",
    "NodeType",
    "Cache",
    "CacheEntry",
    "Executor",
    "register_executor",
    "get_executor",
    "Engine",
    "Registry",
    "Asset",
    "Activity",
    "ActivityStore",
    "ActivityManager",
    "make_is_shared_fn",
    # Analysis (optional)
    "Analyzer",
    "AnalysisResult",
    # Planning (optional)
    "RecipePlanner",
    "ExecutionPlan",
    "ExecutionStep",
]

__version__ = "0.1.0"
|
||||
371
core/artdag/activities.py
Normal file
371
core/artdag/activities.py
Normal file
@@ -0,0 +1,371 @@
|
||||
# artdag/activities.py
|
||||
"""
|
||||
Persistent activity (job) tracking for cache management.
|
||||
|
||||
Activities represent executions of DAGs. They track:
|
||||
- Input node IDs (sources)
|
||||
- Output node ID (terminal node)
|
||||
- Intermediate node IDs (everything in between)
|
||||
|
||||
This enables deletion rules:
|
||||
- Shared items (ActivityPub published) cannot be deleted
|
||||
- Inputs/outputs of activities cannot be deleted
|
||||
- Intermediates can be deleted (reconstructible)
|
||||
- Activities can only be discarded if no items are shared
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from .cache import Cache, CacheEntry
|
||||
from .dag import DAG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_is_shared_fn(activitypub_store: "ActivityStore") -> Callable[[str], bool]:
    """
    Build an ``is_shared`` predicate backed by an ActivityPub ActivityStore.

    Args:
        activitypub_store: The ActivityPub activity store
            (from artdag.activitypub.activity)

    Returns:
        A function mapping a content id (cid) to True when at least one
        "Create" activity has been published for that cid.
    """
    def is_shared(cid: str) -> bool:
        for activity in activitypub_store.find_by_object_hash(cid):
            if activity.activity_type == "Create":
                return True
        return False

    return is_shared
|
||||
|
||||
|
||||
@dataclass
class Activity:
    """
    A recorded execution of a DAG.

    Tracks which cache entries are inputs, outputs, and intermediates
    to enforce deletion rules (see ActivityManager).
    """
    activity_id: str
    input_ids: List[str]          # Source node cache IDs
    output_id: str                # Terminal node cache ID
    intermediate_ids: List[str]   # Everything in between
    created_at: float             # Unix timestamp (time.time())
    status: str = "completed"     # pending|running|completed|failed
    dag_snapshot: Optional[Dict[str, Any]] = None  # Serialized DAG for reconstruction

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-safe dict for storage."""
        return {
            "activity_id": self.activity_id,
            "input_ids": self.input_ids,
            "output_id": self.output_id,
            "intermediate_ids": self.intermediate_ids,
            "created_at": self.created_at,
            "status": self.status,
            "dag_snapshot": self.dag_snapshot,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Activity":
        """Deserialize from storage; missing optional fields get defaults."""
        return cls(
            activity_id=data["activity_id"],
            input_ids=data["input_ids"],
            output_id=data["output_id"],
            intermediate_ids=data["intermediate_ids"],
            created_at=data["created_at"],
            status=data.get("status", "completed"),
            dag_snapshot=data.get("dag_snapshot"),
        )

    @classmethod
    def from_dag(cls, dag: DAG, activity_id: Optional[str] = None) -> "Activity":
        """
        Create an Activity from a DAG.

        Classifies nodes as inputs (no incoming edges), output (the DAG's
        designated terminal node), or intermediates (everything else).

        Args:
            dag: The executed DAG.
            activity_id: Optional explicit ID; a random UUID4 when None.
                (Fixed annotation: was ``str = None``.)
        """
        if activity_id is None:
            activity_id = str(uuid.uuid4())

        # Inputs are source nodes: nodes with no inputs of their own.
        input_ids = [nid for nid, node in dag.nodes.items() if not node.inputs]

        # Output is the terminal node designated by the DAG.
        output_id = dag.output_id

        # Intermediates are everything that is neither a source nor the
        # output. Use a set for O(1) membership instead of list scans.
        non_intermediate = set(input_ids)
        non_intermediate.add(output_id)
        intermediate_ids = [nid for nid in dag.nodes if nid not in non_intermediate]

        return cls(
            activity_id=activity_id,
            input_ids=sorted(input_ids),
            output_id=output_id,
            intermediate_ids=sorted(intermediate_ids),
            created_at=time.time(),
            status="completed",
            dag_snapshot=dag.to_dict(),
        )

    @property
    def all_node_ids(self) -> List[str]:
        """All node IDs involved in this activity (inputs, output, intermediates)."""
        return self.input_ids + [self.output_id] + self.intermediate_ids
|
||||
|
||||
|
||||
class ActivityStore:
    """
    Persistent, JSON-backed storage for activities.

    The full set of activities is kept in memory and mirrored to
    ``activities.json`` inside the store directory on every mutation.
    Lookup helpers support the deletion-eligibility checks.
    """

    def __init__(self, store_dir: Path | str):
        self.store_dir = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)
        self._activities: Dict[str, Activity] = {}
        self._load()

    def _index_path(self) -> Path:
        return self.store_dir / "activities.json"

    def _load(self):
        """Populate the in-memory map from the on-disk index, if any."""
        path = self._index_path()
        if not path.exists():
            return
        try:
            with open(path) as fh:
                raw = json.load(fh)
            parsed = (Activity.from_dict(item) for item in raw.get("activities", []))
            self._activities = {record.activity_id: record for record in parsed}
        except (json.JSONDecodeError, KeyError) as e:
            # A corrupt index is not fatal: log and start empty.
            logger.warning(f"Failed to load activities: {e}")
            self._activities = {}

    def _save(self):
        """Write the complete activity index back to disk."""
        payload = {
            "version": "1.0",
            "activities": [record.to_dict() for record in self._activities.values()],
        }
        with open(self._index_path(), "w") as fh:
            json.dump(payload, fh, indent=2)

    def add(self, activity: Activity) -> None:
        """Add (or replace) an activity and persist immediately."""
        self._activities[activity.activity_id] = activity
        self._save()

    def get(self, activity_id: str) -> Optional[Activity]:
        """Return the activity with the given ID, or None."""
        return self._activities.get(activity_id)

    def remove(self, activity_id: str) -> bool:
        """Remove an activity record (does not delete cache entries)."""
        if self._activities.pop(activity_id, None) is None:
            return False
        self._save()
        return True

    def list(self) -> List[Activity]:
        """Return all stored activities."""
        return list(self._activities.values())

    def find_by_input_ids(self, input_ids: List[str]) -> List[Activity]:
        """Find activities sharing exactly these inputs (for UI grouping)."""
        wanted = sorted(input_ids)
        return [
            record
            for record in self._activities.values()
            if sorted(record.input_ids) == wanted
        ]

    def find_using_node(self, node_id: str) -> List[Activity]:
        """Find every activity that references the given node ID."""
        return [
            record
            for record in self._activities.values()
            if node_id in record.all_node_ids
        ]

    def __len__(self) -> int:
        return len(self._activities)
|
||||
|
||||
|
||||
class ActivityManager:
    """
    Manages activities and cache deletion with sharing rules.

    Deletion rules:
    1. Shared items (ActivityPub published) cannot be deleted
    2. Inputs/outputs of activities cannot be deleted
    3. Intermediates can be deleted (reconstructible)
    4. Activities can only be discarded if no items are shared
    """

    def __init__(
        self,
        cache: Cache,
        activity_store: ActivityStore,
        is_shared_fn: Callable[[str], bool],
    ):
        """
        Args:
            cache: The L1 cache
            activity_store: Activity persistence
            is_shared_fn: Function that checks if a cid is shared
                (published via ActivityPub)
        """
        self.cache = cache
        self.activities = activity_store
        self._is_shared = is_shared_fn

    def record_activity(self, dag: DAG) -> Activity:
        """Record a completed DAG execution as an activity."""
        activity = Activity.from_dag(dag)
        self.activities.add(activity)
        return activity

    def is_shared(self, node_id: str) -> bool:
        """Check if a cache entry is shared (published via ActivityPub)."""
        entry = self.cache.get_entry(node_id)
        # Entries that are missing or have no content id cannot have
        # been published.
        if not entry or not entry.cid:
            return False
        return self._is_shared(entry.cid)

    def can_delete_cache_entry(self, node_id: str) -> bool:
        """
        Check if a cache entry can be deleted.

        Returns False if:
        - Entry is shared (rule 1)
        - Entry is an input or output of any activity (rule 2)
        """
        if self.is_shared(node_id):
            return False

        for activity in self.activities.list():
            if node_id in activity.input_ids or node_id == activity.output_id:
                return False

        # It's either an intermediate or orphaned - can delete.
        return True

    def can_discard_activity(self, activity_id: str) -> bool:
        """
        Check if an activity can be discarded.

        Returns False if the activity does not exist, or if any cache
        entry (input, output, or intermediate) is shared via ActivityPub
        (rule 4).
        """
        activity = self.activities.get(activity_id)
        if not activity:
            return False
        return not any(self.is_shared(node_id) for node_id in activity.all_node_ids)

    def discard_activity(self, activity_id: str) -> bool:
        """
        Discard an activity and delete its intermediate cache entries.

        Returns False if the activity cannot be discarded (missing, or
        has shared items).

        When discarded:
        - Intermediate cache entries are deleted
        - The activity record is removed
        - Inputs remain (may be used by other activities)
        - Output is deleted if orphaned (not shared, not used elsewhere)
        """
        if not self.can_discard_activity(activity_id):
            return False

        # can_discard_activity already verified the activity exists.
        activity = self.activities.get(activity_id)
        output_id = activity.output_id
        intermediate_ids = list(activity.intermediate_ids)

        # Remove the record first so the orphan checks below do not count
        # this activity as a remaining reference.
        self.activities.remove(activity_id)

        # Intermediates are always reconstructible - delete them.
        for node_id in intermediate_ids:
            self.cache.remove(node_id)
            logger.debug("Deleted intermediate: %s", node_id)

        # Delete the output only when nothing else references it.
        if self._is_orphaned(output_id) and not self.is_shared(output_id):
            self.cache.remove(output_id)
            logger.debug("Deleted orphaned output: %s", output_id)

        # Inputs remain while other activities use them; reap any that
        # are now orphaned and unshared.
        for input_id in activity.input_ids:
            if self._is_orphaned(input_id) and not self.is_shared(input_id):
                self.cache.remove(input_id)
                logger.debug("Deleted orphaned input: %s", input_id)

        return True

    def _is_orphaned(self, node_id: str) -> bool:
        """Check if a node is not referenced by any activity."""
        return not any(
            node_id in activity.all_node_ids
            for activity in self.activities.list()
        )

    def get_deletable_entries(self) -> List[CacheEntry]:
        """Get all cache entries that can be deleted."""
        return [
            entry
            for entry in self.cache.list_entries()
            if self.can_delete_cache_entry(entry.node_id)
        ]

    def get_discardable_activities(self) -> List[Activity]:
        """Get all activities that can be discarded."""
        return [
            a for a in self.activities.list()
            if self.can_discard_activity(a.activity_id)
        ]

    def cleanup_intermediates(self) -> int:
        """
        Delete all intermediate cache entries that are safe to remove.

        Intermediates can be reconstructed from inputs using the DAG, so
        they are normally safe to delete.

        Bug fix: previously this deleted intermediates unconditionally,
        which violated rule 1 (shared items must never be deleted) for
        intermediates that had been published. Shared entries are now
        skipped.

        Returns:
            Number of entries deleted
        """
        deleted = 0
        for activity in self.activities.list():
            for node_id in activity.intermediate_ids:
                if self.cache.has(node_id) and not self.is_shared(node_id):
                    self.cache.remove(node_id)
                    deleted += 1
        return deleted
|
||||
33
core/artdag/activitypub/__init__.py
Normal file
33
core/artdag/activitypub/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# artdag/activitypub/__init__.py
"""
ActivityPub implementation for Art DAG.

Provides decentralized identity and ownership for assets.
Domain: artdag.rose-ash.com

Core concepts:
- Actor: A user identity with cryptographic keys
- Object: An asset (image, video, etc.)
- Activity: An action (Create, Announce, Like, etc.)
- Signature: Cryptographic proof of authorship
"""

from .actor import Actor, ActorStore
from .activity import Activity, CreateActivity, ActivityStore
from .signatures import sign_activity, verify_signature, verify_activity_ownership
from .ownership import OwnershipManager, OwnershipRecord

# Public API of the activitypub subpackage.
__all__ = [
    "Actor",
    "ActorStore",
    "Activity",
    "CreateActivity",
    "ActivityStore",
    "sign_activity",
    "verify_signature",
    "verify_activity_ownership",
    "OwnershipManager",
    "OwnershipRecord",
]

# Federation domain used to build actor/object URLs.
# NOTE(review): also declared in actor.py — keep the two in sync.
DOMAIN = "artdag.rose-ash.com"
|
||||
203
core/artdag/activitypub/activity.py
Normal file
203
core/artdag/activitypub/activity.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# artdag/activitypub/activity.py
|
||||
"""
|
||||
ActivityPub Activity types.
|
||||
|
||||
Activities represent actions taken by actors on objects.
|
||||
Key activity types for Art DAG:
|
||||
- Create: Actor creates/claims ownership of an object
|
||||
- Announce: Actor shares/boosts an object
|
||||
- Like: Actor endorses an object
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .actor import Actor, DOMAIN
|
||||
|
||||
|
||||
def _generate_id() -> str:
|
||||
"""Generate unique activity ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@dataclass
class Activity:
    """
    Base ActivityPub Activity.

    Attributes:
        activity_id: Unique identifier
        activity_type: Type (Create, Announce, Like, etc.)
        actor_id: ID of the actor performing the activity
        object_data: The object of the activity
        published: ISO timestamp
        signature: Cryptographic signature (added after signing)
    """
    activity_id: str
    activity_type: str
    actor_id: str
    object_data: Dict[str, Any]
    published: str = field(default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()))
    signature: Optional[Dict[str, Any]] = None

    def to_activitypub(self) -> Dict[str, Any]:
        """Return the ActivityPub JSON-LD representation."""
        document = {
            "@context": "https://www.w3.org/ns/activitystreams",
            "type": self.activity_type,
            "id": f"https://{DOMAIN}/activities/{self.activity_id}",
            "actor": self.actor_id,
            "object": self.object_data,
            "published": self.published,
        }
        # A signature is attached only once the activity has been signed.
        if self.signature:
            document["signature"] = self.signature
        return document

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for storage (field name -> value, stable order)."""
        names = (
            "activity_id",
            "activity_type",
            "actor_id",
            "object_data",
            "published",
            "signature",
        )
        return {name: getattr(self, name) for name in names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Activity":
        """Deserialize from storage; optional fields get safe defaults."""
        required = {
            key: data[key]
            for key in ("activity_id", "activity_type", "actor_id", "object_data")
        }
        return cls(
            published=data.get("published", ""),
            signature=data.get("signature"),
            **required,
        )
|
||||
|
||||
|
||||
@dataclass
class CreateActivity(Activity):
    """
    Create activity - establishes ownership of an object.

    Used when an actor creates or claims an asset.
    """
    # Fixed for this subclass; excluded from __init__ so callers cannot
    # override it.
    activity_type: str = field(default="Create", init=False)

    @classmethod
    def for_asset(
        cls,
        actor: Actor,
        asset_name: str,
        cid: str,
        asset_type: str = "Image",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> "CreateActivity":
        """
        Create a Create activity for an asset.

        Args:
            actor: The actor claiming ownership
            asset_name: Name of the asset
            cid: SHA-3 hash of the asset content
            asset_type: ActivityPub object type (Image, Video, Audio, etc.)
            metadata: Additional metadata; omitted from the object when
                None or empty. (Fixed annotation: was ``Dict[str, Any] = None``.)

        Returns:
            CreateActivity establishing ownership
        """
        object_data = {
            "type": asset_type,
            "name": asset_name,
            "id": f"https://{DOMAIN}/objects/{cid}",
            "contentHash": {
                "algorithm": "sha3-256",
                "value": cid,
            },
            "attributedTo": actor.id,
        }
        if metadata:
            object_data["metadata"] = metadata

        return cls(
            activity_id=_generate_id(),
            actor_id=actor.id,
            object_data=object_data,
        )
|
||||
|
||||
|
||||
class ActivityStore:
    """
    Persistent storage for activities.

    Activities are stored as an append-only log for auditability.
    """

    def __init__(self, store_dir: Path | str):
        self.store_dir = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)
        self._activities: List[Activity] = []
        self._load()

    def _log_path(self) -> Path:
        return self.store_dir / "activities.json"

    def _load(self):
        """Load the activity log from disk, if one exists."""
        path = self._log_path()
        if not path.exists():
            return
        with open(path) as fh:
            raw = json.load(fh)
        self._activities = [Activity.from_dict(item) for item in raw.get("activities", [])]

    def _save(self):
        """Rewrite the full activity log on disk."""
        payload = {
            "version": "1.0",
            "activities": [item.to_dict() for item in self._activities],
        }
        with open(self._log_path(), "w") as fh:
            json.dump(payload, fh, indent=2)

    def add(self, activity: Activity) -> None:
        """Append an activity to the log and persist immediately."""
        self._activities.append(activity)
        self._save()

    def get(self, activity_id: str) -> Optional[Activity]:
        """Return the first activity with the given ID, or None."""
        return next(
            (item for item in self._activities if item.activity_id == activity_id),
            None,
        )

    def list(self) -> List[Activity]:
        """Return all activities in log order."""
        return list(self._activities)

    def find_by_actor(self, actor_id: str) -> List[Activity]:
        """Return every activity performed by the given actor."""
        return [item for item in self._activities if item.actor_id == actor_id]

    def find_by_object_hash(self, cid: str) -> List[Activity]:
        """Return activities referencing an object by content hash.

        Supports both the structured form ({"algorithm": ..., "value": ...})
        and a bare-string contentHash.
        """
        matches = []
        for item in self._activities:
            content_hash = item.object_data.get("contentHash", {})
            if isinstance(content_hash, dict):
                if content_hash.get("value") == cid:
                    matches.append(item)
            elif content_hash == cid:
                matches.append(item)
        return matches

    def __len__(self) -> int:
        return len(self._activities)
|
||||
206
core/artdag/activitypub/actor.py
Normal file
206
core/artdag/activitypub/actor.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# artdag/activitypub/actor.py
|
||||
"""
|
||||
ActivityPub Actor management.
|
||||
|
||||
An Actor is an identity with:
|
||||
- Username and display name
|
||||
- RSA key pair for signing
|
||||
- ActivityPub-compliant JSON-LD representation
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
|
||||
DOMAIN = "artdag.rose-ash.com"


def _generate_keypair() -> tuple[bytes, bytes]:
    """Generate an RSA-2048 key pair for signing.

    Returns:
        (private_pem, public_pem) as PEM-encoded bytes. The private key
        is serialized unencrypted in PKCS#8 format.
    """
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
    )
    private_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    public_pem = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return private_pem, public_pem


@dataclass
class Actor:
    """
    An ActivityPub Actor (identity).

    Attributes:
        username: Unique username (e.g., "giles")
        display_name: Human-readable name
        public_key: PEM-encoded public key
        private_key: PEM-encoded private key (kept secret)
        created_at: Timestamp of creation (time.time())
        domain: Federation domain used to build this actor's URLs
    """
    username: str
    display_name: str
    public_key: bytes
    private_key: bytes
    created_at: float = field(default_factory=time.time)
    domain: str = DOMAIN

    @property
    def id(self) -> str:
        """ActivityPub actor ID (URL)."""
        return f"https://{self.domain}/users/{self.username}"

    @property
    def handle(self) -> str:
        """Fediverse handle, e.g. ``@giles@artdag.rose-ash.com``."""
        return f"@{self.username}@{self.domain}"

    @property
    def inbox(self) -> str:
        """ActivityPub inbox URL."""
        return f"{self.id}/inbox"

    @property
    def outbox(self) -> str:
        """ActivityPub outbox URL."""
        return f"{self.id}/outbox"

    @property
    def key_id(self) -> str:
        """Key ID for HTTP Signatures."""
        return f"{self.id}#main-key"

    def to_activitypub(self) -> Dict[str, Any]:
        """Return ActivityPub JSON-LD representation (public data only)."""
        return {
            "@context": [
                "https://www.w3.org/ns/activitystreams",
                "https://w3id.org/security/v1",
            ],
            "type": "Person",
            "id": self.id,
            "preferredUsername": self.username,
            "name": self.display_name,
            "inbox": self.inbox,
            "outbox": self.outbox,
            "publicKey": {
                "id": self.key_id,
                "owner": self.id,
                "publicKeyPem": self.public_key.decode("utf-8"),
            },
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for storage.

        SECURITY NOTE(review): the private key is serialized in
        plaintext here and ends up in actors.json — the store directory
        must be access-protected. Consider encrypting keys at rest.
        """
        return {
            "username": self.username,
            "display_name": self.display_name,
            "public_key": self.public_key.decode("utf-8"),
            "private_key": self.private_key.decode("utf-8"),
            "created_at": self.created_at,
            "domain": self.domain,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Actor":
        """Deserialize from storage; missing optional fields get defaults."""
        return cls(
            username=data["username"],
            display_name=data["display_name"],
            public_key=data["public_key"].encode("utf-8"),
            private_key=data["private_key"].encode("utf-8"),
            created_at=data.get("created_at", time.time()),
            domain=data.get("domain", DOMAIN),
        )

    @classmethod
    def create(cls, username: str, display_name: Optional[str] = None) -> "Actor":
        """Create a new actor with freshly generated RSA keys.

        Args:
            username: Unique username.
            display_name: Optional human-readable name; defaults to the
                username. (Fixed annotation: was ``str = None``.)
        """
        private_pem, public_pem = _generate_keypair()
        return cls(
            username=username,
            display_name=display_name or username,
            public_key=public_pem,
            private_key=private_pem,
        )
|
||||
|
||||
|
||||
class ActorStore:
    """
    Persistent storage for actors.

    Structure:
        store_dir/
          actors.json          # Index of all actors
          keys/
            <username>.private.pem
            <username>.public.pem
    """

    def __init__(self, store_dir: Path | str):
        self.store_dir = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)
        self._actors: Dict[str, Actor] = {}
        self._load()

    def _index_path(self) -> Path:
        return self.store_dir / "actors.json"

    def _load(self):
        """Rebuild the in-memory actor map from the on-disk index."""
        path = self._index_path()
        if not path.exists():
            return
        with open(path) as fh:
            raw = json.load(fh)
        self._actors = {
            name: Actor.from_dict(payload)
            for name, payload in raw.get("actors", {}).items()
        }

    def _save(self):
        """Write every actor (including key material) to the index file."""
        payload = {
            "version": "1.0",
            "domain": DOMAIN,
            "actors": {
                name: actor.to_dict() for name, actor in self._actors.items()
            },
        }
        with open(self._index_path(), "w") as fh:
            json.dump(payload, fh, indent=2)

    def create(self, username: str, display_name: str = None) -> Actor:
        """Create, store, and return a new actor with generated keys."""
        if username in self._actors:
            raise ValueError(f"Actor {username} already exists")

        new_actor = Actor.create(username, display_name)
        self._actors[username] = new_actor
        self._save()
        return new_actor

    def get(self, username: str) -> Optional[Actor]:
        """Return the actor with the given username, or None."""
        return self._actors.get(username)

    def list(self) -> list[Actor]:
        """Return all stored actors."""
        return list(self._actors.values())

    def __contains__(self, username: str) -> bool:
        return username in self._actors

    def __len__(self) -> int:
        return len(self._actors)
|
||||
226
core/artdag/activitypub/ownership.py
Normal file
226
core/artdag/activitypub/ownership.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# primitive/activitypub/ownership.py
|
||||
"""
|
||||
Ownership integration between ActivityPub and Registry.
|
||||
|
||||
Connects actors, activities, and assets to establish provable ownership.
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .actor import Actor, ActorStore
|
||||
from .activity import Activity, CreateActivity, ActivityStore
|
||||
from .signatures import sign_activity, verify_activity_ownership
|
||||
from ..registry import Registry, Asset
|
||||
|
||||
|
||||
@dataclass
class OwnershipRecord:
    """A verified ownership record linking an actor to an asset."""

    actor_handle: str       # the actor's fediverse handle
    asset_name: str         # name of the owned asset
    cid: str                # SHA-3 hash of the asset
    activity_id: str        # ID of the Create activity establishing ownership
    verified: bool = False  # whether the signature has been verified
|
||||
|
||||
|
||||
class OwnershipManager:
    """
    Manages ownership relationships between actors and assets.

    Integrates:
    - ActorStore: Identity management
    - Registry: Asset storage
    - ActivityStore: Ownership activities

    Ownership is established by a signed ActivityPub "Create" activity
    whose object hash matches the asset's content id (cid).
    """

    def __init__(self, base_dir: Path | str):
        # All three stores live under one base directory so a manager
        # instance is fully described by a single filesystem location.
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Initialize stores
        self.actors = ActorStore(self.base_dir / "actors")
        self.activities = ActivityStore(self.base_dir / "activities")
        self.registry = Registry(self.base_dir / "registry")

    def create_actor(self, username: str, display_name: str = None) -> Actor:
        """Create a new actor identity (delegates to the ActorStore)."""
        return self.actors.create(username, display_name)

    def get_actor(self, username: str) -> Optional[Actor]:
        """Get an actor by username, or None if unknown."""
        return self.actors.get(username)

    def register_asset(
        self,
        actor: Actor,
        name: str,
        cid: str,
        url: str = None,
        local_path: Path | str = None,
        tags: List[str] = None,
        metadata: Dict[str, Any] = None,
    ) -> tuple[Asset, Activity]:
        """
        Register an asset and establish ownership.

        Creates the asset in the registry and a signed Create activity
        proving the actor's ownership.

        Args:
            actor: The actor claiming ownership
            name: Name for the asset
            cid: SHA-3-256 hash of the content
            url: Public URL (canonical location)
            local_path: Optional local path
            tags: Optional tags
            metadata: Optional metadata

        Returns:
            Tuple of (Asset, signed CreateActivity)

        NOTE(review): the registry write happens before signing; if
        signing fails the asset remains registered without an ownership
        activity — confirm this is the intended failure mode.
        """
        # Add to registry
        asset = self.registry.add(
            name=name,
            cid=cid,
            url=url,
            local_path=local_path,
            tags=tags,
            metadata=metadata,
        )

        # Create ownership activity (uses the registry's cid, which may
        # be normalized relative to the caller-supplied one)
        activity = CreateActivity.for_asset(
            actor=actor,
            asset_name=name,
            cid=asset.cid,
            asset_type=self._asset_type_to_ap(asset.asset_type),
            metadata=metadata,
        )

        # Sign the activity
        signed_activity = sign_activity(activity, actor)

        # Store the activity
        self.activities.add(signed_activity)

        return asset, signed_activity

    def _asset_type_to_ap(self, asset_type: str) -> str:
        """Convert registry asset type to ActivityPub type."""
        # Anything unrecognized degrades to the generic "Document" type.
        type_map = {
            "image": "Image",
            "video": "Video",
            "audio": "Audio",
            "unknown": "Document",
        }
        return type_map.get(asset_type, "Document")

    def get_owner(self, asset_name: str) -> Optional[Actor]:
        """
        Get the owner of an asset.

        Finds the earliest Create activity for the asset and returns
        the actor if the signature is valid.
        """
        asset = self.registry.get(asset_name)
        if not asset:
            return None

        # Find Create activities for this asset
        activities = self.activities.find_by_object_hash(asset.cid)
        create_activities = [a for a in activities if a.activity_type == "Create"]

        if not create_activities:
            return None

        # Get the earliest (first owner) — first-claim-wins semantics
        earliest = min(create_activities, key=lambda a: a.published)

        # Extract username from actor_id
        # Format: https://artdag.rose-ash.com/users/{username}
        actor_id = earliest.actor_id
        if "/users/" in actor_id:
            username = actor_id.split("/users/")[-1]
            actor = self.actors.get(username)
            # Only a locally-known actor with a valid signature counts;
            # remote actors (not in ActorStore) resolve to None here.
            if actor and verify_activity_ownership(earliest, actor):
                return actor

        return None

    def verify_ownership(self, asset_name: str, actor: Actor) -> bool:
        """
        Verify that an actor owns an asset.

        Checks for a valid signed Create activity linking the actor
        to the asset.
        """
        asset = self.registry.get(asset_name)
        if not asset:
            return False

        activities = self.activities.find_by_object_hash(asset.cid)
        for activity in activities:
            # Any matching, validly-signed Create activity suffices.
            if activity.activity_type == "Create" and activity.actor_id == actor.id:
                if verify_activity_ownership(activity, actor):
                    return True

        return False

    def list_owned_assets(self, actor: Actor) -> List[Asset]:
        """List all assets owned by an actor."""
        activities = self.activities.find_by_actor(actor.id)
        owned = []

        for activity in activities:
            if activity.activity_type == "Create":
                # Find asset by hash; contentHash may be stored either as
                # a {"value": ...} dict or as a bare string.
                obj_hash = activity.object_data.get("contentHash", {})
                if isinstance(obj_hash, dict):
                    hash_value = obj_hash.get("value")
                else:
                    hash_value = obj_hash

                if hash_value:
                    asset = self.registry.find_by_hash(hash_value)
                    if asset:
                        owned.append(asset)

        return owned

    def get_ownership_records(self) -> List[OwnershipRecord]:
        """Get all ownership records (one per stored Create activity)."""
        records = []

        for activity in self.activities.list():
            if activity.activity_type != "Create":
                continue

            # Extract info
            actor_id = activity.actor_id
            username = actor_id.split("/users/")[-1] if "/users/" in actor_id else "unknown"
            actor = self.actors.get(username)

            obj_hash = activity.object_data.get("contentHash", {})
            hash_value = obj_hash.get("value") if isinstance(obj_hash, dict) else obj_hash

            # Unknown actors still yield a record, just unverifiable.
            records.append(OwnershipRecord(
                actor_handle=actor.handle if actor else f"@{username}@unknown",
                asset_name=activity.object_data.get("name", "unknown"),
                cid=hash_value or "unknown",
                activity_id=activity.activity_id,
                verified=verify_activity_ownership(activity, actor) if actor else False,
            ))

        return records
|
||||
163
core/artdag/activitypub/signatures.py
Normal file
163
core/artdag/activitypub/signatures.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# primitive/activitypub/signatures.py
|
||||
"""
|
||||
Cryptographic signatures for ActivityPub.
|
||||
|
||||
Uses RSA-SHA256 signatures compatible with HTTP Signatures spec
|
||||
and Linked Data Signatures for ActivityPub.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding, rsa
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
|
||||
from .actor import Actor
|
||||
from .activity import Activity
|
||||
|
||||
|
||||
def _canonicalize(data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Canonicalize JSON for signing.
|
||||
|
||||
Uses JCS (JSON Canonicalization Scheme) - sorted keys, no whitespace.
|
||||
"""
|
||||
return json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _hash_sha256(data: str) -> bytes:
|
||||
"""Hash string with SHA-256."""
|
||||
return hashlib.sha256(data.encode()).digest()
|
||||
|
||||
|
||||
def sign_activity(activity: Activity, actor: Actor) -> Activity:
    """
    Sign an activity with the actor's private key.

    Uses Linked Data Signatures with RsaSignature2017.

    Mutates and returns the same `activity` object (the signature dict
    is attached in place). The signed payload is the SHA-256 hash of the
    canonicalized options concatenated with the SHA-256 hash of the
    canonicalized activity — this exact construction must be mirrored
    by `verify_signature`.

    Args:
        activity: The activity to sign
        actor: The actor whose key signs the activity

    Returns:
        Activity with signature attached
    """
    # Load private key (PEM bytes stored on the actor, unencrypted)
    private_key = serialization.load_pem_private_key(
        actor.private_key,
        password=None,
    )

    # Create signature options — timestamp is UTC, ISO-8601 with "Z"
    created = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    # Canonicalize the activity (without signature), so re-signing an
    # already-signed activity does not hash the old signature
    activity_data = activity.to_activitypub()
    activity_data.pop("signature", None)
    canonical = _canonicalize(activity_data)

    # Create the data to sign: hash of options + hash of document
    options = {
        "@context": "https://w3id.org/security/v1",
        "type": "RsaSignature2017",
        "creator": actor.key_id,
        "created": created,
    }
    options_hash = _hash_sha256(_canonicalize(options))
    document_hash = _hash_sha256(canonical)
    to_sign = options_hash + document_hash

    # Sign with RSA-SHA256 (PKCS#1 v1.5 padding, as RsaSignature2017 uses)
    signature_bytes = private_key.sign(
        to_sign,
        padding.PKCS1v15(),
        hashes.SHA256(),
    )
    signature_value = base64.b64encode(signature_bytes).decode("utf-8")

    # Attach signature to activity; "@context" is intentionally omitted
    # here but re-added by verify_signature when rebuilding the options
    activity.signature = {
        "type": "RsaSignature2017",
        "creator": actor.key_id,
        "created": created,
        "signatureValue": signature_value,
    }

    return activity
|
||||
|
||||
|
||||
def verify_signature(activity: Activity, public_key_pem: bytes) -> bool:
    """
    Verify an activity's signature.

    Reconstructs the exact payload built by `sign_activity`
    (options hash + document hash) and checks the RSA-SHA256 signature
    against it.

    Args:
        activity: The activity with signature
        public_key_pem: PEM-encoded public key

    Returns:
        True if signature is valid; False for a missing, malformed,
        or cryptographically invalid signature
    """
    if not activity.signature:
        return False

    try:
        # Load public key
        public_key = serialization.load_pem_public_key(public_key_pem)

        # Reconstruct signature options — field order does not matter
        # because _canonicalize sorts keys
        options = {
            "@context": "https://w3id.org/security/v1",
            "type": activity.signature["type"],
            "creator": activity.signature["creator"],
            "created": activity.signature["created"],
        }

        # Canonicalize activity without signature (mirrors sign_activity)
        activity_data = activity.to_activitypub()
        activity_data.pop("signature", None)
        canonical = _canonicalize(activity_data)

        # Recreate signed data
        options_hash = _hash_sha256(_canonicalize(options))
        document_hash = _hash_sha256(canonical)
        signed_data = options_hash + document_hash

        # Decode and verify signature; verify() raises InvalidSignature
        # on mismatch rather than returning a value
        signature_bytes = base64.b64decode(activity.signature["signatureValue"])
        public_key.verify(
            signature_bytes,
            signed_data,
            padding.PKCS1v15(),
            hashes.SHA256(),
        )
        return True

    # KeyError: missing signature fields; ValueError: bad PEM/base64
    except (InvalidSignature, KeyError, ValueError):
        return False
|
||||
|
||||
|
||||
def verify_activity_ownership(activity: Activity, actor: Actor) -> bool:
    """
    Verify that an activity was signed by the claimed actor.

    Args:
        activity: The activity to verify
        actor: The claimed actor

    Returns:
        True if the activity was signed by this actor
    """
    sig = activity.signature
    if not sig:
        return False

    # The signing key must belong to the claimed actor before we bother
    # with the cryptographic check.
    if sig.get("creator") != actor.key_id:
        return False

    return verify_signature(activity, actor.public_key)
|
||||
26
core/artdag/analysis/__init__.py
Normal file
26
core/artdag/analysis/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# artdag/analysis - Audio and video feature extraction
|
||||
#
|
||||
# Provides the Analysis phase of the 3-phase execution model:
|
||||
# 1. ANALYZE - Extract features from inputs
|
||||
# 2. PLAN - Generate execution plan with cache IDs
|
||||
# 3. EXECUTE - Run steps with caching
|
||||
|
||||
from .schema import (
|
||||
AnalysisResult,
|
||||
AudioFeatures,
|
||||
VideoFeatures,
|
||||
BeatInfo,
|
||||
EnergyEnvelope,
|
||||
SpectrumBands,
|
||||
)
|
||||
from .analyzer import Analyzer
|
||||
|
||||
__all__ = [
|
||||
"Analyzer",
|
||||
"AnalysisResult",
|
||||
"AudioFeatures",
|
||||
"VideoFeatures",
|
||||
"BeatInfo",
|
||||
"EnergyEnvelope",
|
||||
"SpectrumBands",
|
||||
]
|
||||
282
core/artdag/analysis/analyzer.py
Normal file
282
core/artdag/analysis/analyzer.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# artdag/analysis/analyzer.py
|
||||
"""
|
||||
Main Analyzer class for the Analysis phase.
|
||||
|
||||
Coordinates audio and video feature extraction with caching.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .schema import AnalysisResult, AudioFeatures, VideoFeatures
|
||||
from .audio import analyze_audio, FEATURE_ALL as AUDIO_ALL
|
||||
from .video import analyze_video, FEATURE_ALL as VIDEO_ALL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalysisCache:
    """
    Simple file-based cache for analysis results.

    Each result is persisted as ``<cache_dir>/<cache_id>.json``.
    """

    def __init__(self, cache_dir: Path):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _path_for(self, cache_id: str) -> Path:
        """Cache file path for a cache_id."""
        return self.cache_dir / f"{cache_id}.json"

    def get(self, cache_id: str) -> Optional[AnalysisResult]:
        """Return the cached result, or None if absent or unreadable."""
        entry = self._path_for(cache_id)
        if not entry.exists():
            return None

        try:
            with open(entry, "r") as fh:
                payload = json.load(fh)
            return AnalysisResult.from_dict(payload)
        except (json.JSONDecodeError, KeyError) as exc:
            # A corrupt entry is treated as a miss, not an error.
            logger.warning(f"Failed to load analysis cache {cache_id}: {exc}")
            return None

    def put(self, result: AnalysisResult) -> None:
        """Persist an analysis result under its cache_id."""
        with open(self._path_for(result.cache_id), "w") as fh:
            json.dump(result.to_dict(), fh, indent=2)

    def has(self, cache_id: str) -> bool:
        """True if a cache file exists for this cache_id."""
        return self._path_for(cache_id).exists()

    def remove(self, cache_id: str) -> bool:
        """Delete a cached result; return True if something was removed."""
        target = self._path_for(cache_id)
        if not target.exists():
            return False
        target.unlink()
        return True
|
||||
|
||||
|
||||
class Analyzer:
    """
    Analyzes media inputs to extract features.

    The Analyzer is the first phase of the 3-phase execution model.
    It extracts features from inputs that inform downstream processing.

    Example:
        analyzer = Analyzer(cache_dir=Path("./analysis_cache"))

        # Analyze a music file for beats
        result = analyzer.analyze(
            input_path=Path("/path/to/music.mp3"),
            input_hash="abc123...",
            features=["beats", "energy"]
        )

        print(f"Tempo: {result.tempo} BPM")
        print(f"Beats: {result.beat_times}")
    """

    def __init__(
        self,
        cache_dir: Optional[Path] = None,
        content_cache: Optional["Cache"] = None,  # artdag.Cache for input lookup
    ):
        """
        Initialize the Analyzer.

        Args:
            cache_dir: Directory for analysis cache. If None, no caching.
            content_cache: artdag Cache for looking up inputs by hash
        """
        self.cache = AnalysisCache(cache_dir) if cache_dir else None
        self.content_cache = content_cache

    def get_input_path(self, input_hash: str, input_path: Optional[Path] = None) -> Path:
        """
        Resolve input to a file path.

        An explicit, existing path wins; otherwise the content cache is
        consulted by hash.

        Args:
            input_hash: Content hash of the input
            input_path: Optional direct path to file

        Returns:
            Path to the input file

        Raises:
            ValueError: If input cannot be resolved
        """
        if input_path and input_path.exists():
            return input_path

        if self.content_cache:
            entry = self.content_cache.get(input_hash)
            if entry:
                return Path(entry.output_path)

        raise ValueError(f"Cannot resolve input {input_hash}: no path provided and not in cache")

    def analyze(
        self,
        input_hash: str,
        features: List[str],
        input_path: Optional[Path] = None,
        media_type: Optional[str] = None,
    ) -> AnalysisResult:
        """
        Analyze an input file and extract features.

        Args:
            input_hash: Content hash of the input (for cache key)
            features: List of features to extract:
                Audio: "beats", "tempo", "energy", "spectrum", "onsets"
                Video: "metadata", "motion_tempo", "scene_changes"
                Meta: "all" (extracts all relevant features)
            input_path: Optional direct path to file
            media_type: Optional hint ("audio", "video", or None for auto-detect)

        Returns:
            AnalysisResult with extracted features
        """
        # Compute cache ID by building a throwaway result; assumes
        # AnalysisResult.cache_id depends only on input_hash and the
        # sorted feature list — TODO confirm against schema.py
        temp_result = AnalysisResult(
            input_hash=input_hash,
            features_requested=sorted(features),
        )
        cache_id = temp_result.cache_id

        # Check cache
        if self.cache and self.cache.has(cache_id):
            cached = self.cache.get(cache_id)
            if cached:
                logger.info(f"Analysis cache hit: {cache_id[:16]}...")
                return cached

        # Resolve input path
        path = self.get_input_path(input_hash, input_path)
        logger.info(f"Analyzing {path} for features: {features}")

        # Detect media type if not specified
        if media_type is None:
            media_type = self._detect_media_type(path)

        # Extract features
        audio_features = None
        video_features = None

        # Normalize features: "all" expands to the full feature set for
        # each modality; otherwise split requests between audio and video
        if "all" in features:
            audio_features_list = [AUDIO_ALL]
            video_features_list = [VIDEO_ALL]
        else:
            audio_features_list = [f for f in features if f in ("beats", "tempo", "energy", "spectrum", "onsets")]
            video_features_list = [f for f in features if f in ("metadata", "motion_tempo", "scene_changes")]

        # Video files can carry an audio track, so audio analysis runs
        # for both; failures are logged but non-fatal
        if media_type in ("audio", "video") and audio_features_list:
            try:
                audio_features = analyze_audio(path, features=audio_features_list)
            except Exception as e:
                logger.warning(f"Audio analysis failed: {e}")

        if media_type == "video" and video_features_list:
            try:
                video_features = analyze_video(path, features=video_features_list)
            except Exception as e:
                logger.warning(f"Video analysis failed: {e}")

        result = AnalysisResult(
            input_hash=input_hash,
            features_requested=sorted(features),
            audio=audio_features,
            video=video_features,
            analyzed_at=datetime.now(timezone.utc).isoformat(),
        )

        # Cache result
        if self.cache:
            self.cache.put(result)

        return result

    def analyze_multiple(
        self,
        inputs: Dict[str, Path],
        features: List[str],
    ) -> Dict[str, AnalysisResult]:
        """
        Analyze multiple inputs.

        Args:
            inputs: Dict mapping input_hash to file path
            features: Features to extract from all inputs

        Returns:
            Dict mapping input_hash to AnalysisResult

        Raises:
            Exception: Re-raises the first analysis failure (fail-fast;
            earlier successful results are discarded).
        """
        results = {}
        for input_hash, input_path in inputs.items():
            try:
                results[input_hash] = self.analyze(
                    input_hash=input_hash,
                    features=features,
                    input_path=input_path,
                )
            except Exception as e:
                logger.error(f"Analysis failed for {input_hash}: {e}")
                raise

        return results

    def _detect_media_type(self, path: Path) -> str:
        """
        Detect if file is audio or video.

        Uses ffprobe stream inspection first, falling back to the file
        extension when ffprobe is unavailable or its output is unusable.

        Args:
            path: Path to media file

        Returns:
            "audio" or "video" (or "unknown" when neither is detected)
        """
        import subprocess
        import json

        cmd = [
            "ffprobe", "-v", "quiet",
            "-print_format", "json",
            "-show_streams",
            str(path)
        ]

        try:
            result = subprocess.run(cmd, capture_output=True, text=True, check=True)
            data = json.loads(result.stdout)
            streams = data.get("streams", [])

            has_video = any(s.get("codec_type") == "video" for s in streams)
            has_audio = any(s.get("codec_type") == "audio" for s in streams)

            # A file with any video stream counts as video, even if it
            # also has audio
            if has_video:
                return "video"
            elif has_audio:
                return "audio"
            else:
                return "unknown"

        except (subprocess.CalledProcessError, json.JSONDecodeError):
            # Fall back to extension-based detection
            ext = path.suffix.lower()
            if ext in (".mp4", ".mov", ".avi", ".mkv", ".webm"):
                return "video"
            elif ext in (".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac"):
                return "audio"
            return "unknown"
|
||||
336
core/artdag/analysis/audio.py
Normal file
336
core/artdag/analysis/audio.py
Normal file
@@ -0,0 +1,336 @@
|
||||
# artdag/analysis/audio.py
|
||||
"""
|
||||
Audio feature extraction.
|
||||
|
||||
Uses librosa for beat detection, energy analysis, and spectral features.
|
||||
Falls back to basic ffprobe if librosa is not available.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from .schema import AudioFeatures, BeatInfo, EnergyEnvelope, SpectrumBands
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Feature names for requesting specific analysis
|
||||
FEATURE_BEATS = "beats"
|
||||
FEATURE_TEMPO = "tempo"
|
||||
FEATURE_ENERGY = "energy"
|
||||
FEATURE_SPECTRUM = "spectrum"
|
||||
FEATURE_ONSETS = "onsets"
|
||||
FEATURE_ALL = "all"
|
||||
|
||||
|
||||
def _get_audio_info_ffprobe(path: Path) -> Tuple[float, int, int]:
    """Get basic audio info (duration, sample rate, channels) via ffprobe.

    Raises ValueError when ffprobe fails or no audio stream is present.
    """
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_streams",
        "-select_streams", "a:0",
        str(path)
    ]
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, check=True)
        info = json.loads(proc.stdout)
        streams = info.get("streams")
        if not streams:
            raise ValueError("No audio stream found")

        first = streams[0]
        return (
            float(first.get("duration", 0)),
            int(first.get("sample_rate", 44100)),
            int(first.get("channels", 2)),
        )
    except (subprocess.CalledProcessError, json.JSONDecodeError, KeyError) as e:
        logger.warning(f"ffprobe failed: {e}")
        raise ValueError(f"Could not read audio info: {e}")
|
||||
|
||||
|
||||
def _extract_audio_to_wav(path: Path, duration: Optional[float] = None) -> Path:
    """Extract audio to a temporary WAV file for librosa processing.

    Args:
        path: Source media file.
        duration: If given, only the first *duration* seconds are extracted.

    Returns:
        Path to a mono, 16-bit PCM WAV resampled to 22050 Hz. The caller
        is responsible for deleting the file when done.

    Raises:
        ValueError: If ffmpeg fails to extract the audio.
    """
    import os
    import tempfile

    # tempfile.mktemp() is deprecated and racy (another process can grab
    # the name between generation and use); mkstemp() creates the file
    # atomically. ffmpeg runs with -y so it overwrites the placeholder.
    fd, tmp_name = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    wav_path = Path(tmp_name)

    cmd = ["ffmpeg", "-y", "-i", str(path)]
    if duration:
        cmd.extend(["-t", str(duration)])
    cmd.extend([
        "-vn",  # No video
        "-acodec", "pcm_s16le",
        "-ar", "22050",  # Resample to 22050 Hz for librosa
        "-ac", "1",  # Mono
        str(wav_path)
    ])

    try:
        subprocess.run(cmd, capture_output=True, check=True)
        return wav_path
    except subprocess.CalledProcessError as e:
        # Don't leak the temp file when extraction fails.
        wav_path.unlink(missing_ok=True)
        logger.error(f"Audio extraction failed: {e.stderr}")
        raise ValueError(f"Could not extract audio: {e}")
|
||||
|
||||
|
||||
def analyze_beats(path: Path, sample_rate: int = 22050) -> BeatInfo:
    """
    Detect beats and tempo using librosa.

    Args:
        path: Path to audio file (or pre-extracted WAV)
        sample_rate: Sample rate for analysis

    Returns:
        BeatInfo with beat times, tempo, and confidence

    Raises:
        ImportError: If librosa is not installed.
    """
    try:
        import librosa
    except ImportError:
        raise ImportError("librosa required for beat detection. Install with: pip install librosa")

    # Load audio (mono, resampled to `sample_rate`)
    y, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    # Detect tempo and beats
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)

    # Convert frames to times
    beat_times = librosa.frames_to_time(beat_frames, sr=sr).tolist()

    # Estimate confidence from onset strength consistency: mean onset
    # strength at detected beats relative to the global peak; defaults
    # to 0.5 when no beats were found or the signal is silent
    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    beat_strength = onset_env[beat_frames] if len(beat_frames) > 0 else []
    confidence = float(beat_strength.mean() / onset_env.max()) if len(beat_strength) > 0 and onset_env.max() > 0 else 0.5

    # Detect downbeats (first beat of each bar)
    # Use beat phase to estimate bar positions
    downbeat_times = None
    if len(beat_times) >= 4:
        # Assume 4/4 time signature, downbeats every 4 beats
        downbeat_times = [beat_times[i] for i in range(0, len(beat_times), 4)]

    return BeatInfo(
        beat_times=beat_times,
        # NOTE(review): librosa returns tempo as a scalar or a 1-element
        # ndarray depending on version; this expression handles both but
        # the `len(tempo) > 0` branch presumes a sized object — confirm
        # against the pinned librosa version
        tempo=float(tempo) if hasattr(tempo, '__float__') else float(tempo[0]) if len(tempo) > 0 else 120.0,
        # Clamp into [0, 1]
        confidence=min(1.0, max(0.0, confidence)),
        downbeat_times=downbeat_times,
        time_signature=4,
    )
|
||||
|
||||
|
||||
def analyze_energy(path: Path, window_ms: float = 50.0, sample_rate: int = 22050) -> EnergyEnvelope:
    """
    Extract the energy (loudness) envelope of an audio file.

    Args:
        path: Path to audio file
        window_ms: Analysis window size in milliseconds
        sample_rate: Sample rate for analysis

    Returns:
        EnergyEnvelope with times and peak-normalized values
    """
    try:
        import librosa
        import numpy as np
    except ImportError:
        raise ImportError("librosa and numpy required. Install with: pip install librosa numpy")

    y, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    # Window size in ms -> hop length in samples
    hop_length = int(sr * window_ms / 1000)

    # RMS energy per frame
    rms = librosa.feature.rms(y=y, hop_length=hop_length)[0]

    # Peak-normalize to [0, 1]; a silent signal is left as-is
    peak = rms.max()
    normalized = rms / peak if peak > 0 else rms

    # Time point per frame
    frame_times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop_length)

    return EnergyEnvelope(
        times=frame_times.tolist(),
        values=normalized.tolist(),
        window_ms=window_ms,
    )
|
||||
|
||||
|
||||
def analyze_spectrum(
    path: Path,
    band_ranges: Optional[dict] = None,
    window_ms: float = 50.0,
    sample_rate: int = 22050
) -> SpectrumBands:
    """
    Extract frequency band envelopes.

    Args:
        path: Path to audio file
        band_ranges: Dict mapping band name to (low_hz, high_hz)
        window_ms: Analysis window size
        sample_rate: Sample rate

    Returns:
        SpectrumBands with bass, mid, high envelopes
    """
    try:
        import librosa
        import numpy as np
    except ImportError:
        raise ImportError("librosa and numpy required")

    # Default band split: bass / mid / high
    if band_ranges is None:
        band_ranges = {
            "bass": (20, 200),
            "mid": (200, 2000),
            "high": (2000, 20000),
        }

    y, sr = librosa.load(str(path), sr=sample_rate, mono=True)
    hop_length = int(sr * window_ms / 1000)

    # Magnitude spectrogram
    n_fft = 2048
    stft = np.abs(librosa.stft(y, n_fft=n_fft, hop_length=hop_length))
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    frame_count = stft.shape[1]

    def band_energy(low_hz: float, high_hz: float) -> List[float]:
        """Summed, peak-normalized magnitude within [low_hz, high_hz]."""
        selected = (freqs >= low_hz) & (freqs <= high_hz)
        if not selected.any():
            return [0.0] * frame_count
        envelope = stft[selected, :].sum(axis=0)
        peak = envelope.max()
        if peak > 0:
            envelope = envelope / peak
        return envelope.tolist()

    frame_times = librosa.frames_to_time(np.arange(frame_count), sr=sr, hop_length=hop_length)

    return SpectrumBands(
        bass=band_energy(*band_ranges["bass"]),
        mid=band_energy(*band_ranges["mid"]),
        high=band_energy(*band_ranges["high"]),
        times=frame_times.tolist(),
        band_ranges=band_ranges,
    )
|
||||
|
||||
|
||||
def analyze_onsets(path: Path, sample_rate: int = 22050) -> List[float]:
    """
    Detect onset times (note/sound starts).

    Args:
        path: Path to audio file
        sample_rate: Sample rate

    Returns:
        List of onset times in seconds
    """
    try:
        import librosa
    except ImportError:
        raise ImportError("librosa required")

    y, sr = librosa.load(str(path), sr=sample_rate, mono=True)

    # Frame indices of detected onsets, converted to seconds
    frames = librosa.onset.onset_detect(y=y, sr=sr)
    return librosa.frames_to_time(frames, sr=sr).tolist()
|
||||
|
||||
|
||||
def analyze_audio(
    path: Path,
    features: Optional[List[str]] = None,
) -> AudioFeatures:
    """
    Extract audio features from file.

    Args:
        path: Path to audio/video file
        features: List of features to extract. Options:
            - "beats": Beat detection (tempo, beat times)
            - "energy": Loudness envelope
            - "spectrum": Frequency band envelopes
            - "onsets": Note onset times
            - "all": All features

    Returns:
        AudioFeatures with requested analysis
    """
    if features is None:
        features = [FEATURE_ALL]

    # Normalize features: "all" expands to every concrete feature.
    if FEATURE_ALL in features:
        features = [FEATURE_BEATS, FEATURE_ENERGY, FEATURE_SPECTRUM, FEATURE_ONSETS]

    # Get basic info via ffprobe (works even without librosa installed)
    duration, sample_rate, channels = _get_audio_info_ffprobe(path)

    result = AudioFeatures(
        duration=duration,
        sample_rate=sample_rate,
        channels=channels,
    )

    # Check if librosa is available for advanced features
    try:
        import librosa  # noqa: F401
        has_librosa = True
    except ImportError:
        has_librosa = False
        if any(f in features for f in [FEATURE_BEATS, FEATURE_ENERGY, FEATURE_SPECTRUM, FEATURE_ONSETS]):
            logger.warning("librosa not available, skipping advanced audio features")

    # Without librosa only the basic ffprobe metadata is returned.
    if not has_librosa:
        return result

    # Extract audio to WAV for librosa
    wav_path = None
    try:
        wav_path = _extract_audio_to_wav(path)

        # Each extractor below is best-effort: a failure is logged and the
        # remaining features are still attempted.
        if FEATURE_BEATS in features or FEATURE_TEMPO in features:
            try:
                result.beats = analyze_beats(wav_path)
            except Exception as e:
                logger.warning(f"Beat detection failed: {e}")

        if FEATURE_ENERGY in features:
            try:
                result.energy = analyze_energy(wav_path)
            except Exception as e:
                logger.warning(f"Energy analysis failed: {e}")

        if FEATURE_SPECTRUM in features:
            try:
                result.spectrum = analyze_spectrum(wav_path)
            except Exception as e:
                logger.warning(f"Spectrum analysis failed: {e}")

        if FEATURE_ONSETS in features:
            try:
                result.onsets = analyze_onsets(wav_path)
            except Exception as e:
                logger.warning(f"Onset detection failed: {e}")

    finally:
        # Clean up temporary WAV file (wav_path stays None if extraction failed)
        if wav_path and wav_path.exists():
            wav_path.unlink()

    return result
|
||||
352
core/artdag/analysis/schema.py
Normal file
352
core/artdag/analysis/schema.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# artdag/analysis/schema.py
|
||||
"""
|
||||
Data structures for analysis results.
|
||||
|
||||
Analysis extracts features from input media that inform downstream processing.
|
||||
Results are cached by: analysis_cache_id = SHA3-256(input_hash + sorted(features))
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""Create stable hash from arbitrary data."""
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class BeatInfo:
    """
    Beat detection results.

    Attributes:
        beat_times: List of beat positions in seconds
        tempo: Estimated tempo in BPM
        confidence: Tempo detection confidence (0-1)
        downbeat_times: First beat of each bar (if detected)
        time_signature: Detected or assumed time signature (e.g., 4)
    """
    beat_times: List[float]
    tempo: float
    confidence: float = 1.0
    downbeat_times: Optional[List[float]] = None
    time_signature: int = 4

    # Serialized key order matches the field declaration order.
    _FIELDS = ("beat_times", "tempo", "confidence", "downbeat_times", "time_signature")

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping of all fields."""
        return {name: getattr(self, name) for name in self._FIELDS}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BeatInfo":
        """Rebuild from a mapping; optional fields fall back to their defaults."""
        kwargs = dict(
            beat_times=data["beat_times"],
            tempo=data["tempo"],
            confidence=data.get("confidence", 1.0),
            downbeat_times=data.get("downbeat_times"),
            time_signature=data.get("time_signature", 4),
        )
        return cls(**kwargs)
|
||||
|
||||
|
||||
@dataclass
class EnergyEnvelope:
    """
    Energy (loudness) over time.

    Attributes:
        times: Time points in seconds
        values: Energy values (0-1, normalized)
        window_ms: Analysis window size in milliseconds
    """
    times: List[float]
    values: List[float]
    window_ms: float = 50.0

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping of all fields."""
        return {name: getattr(self, name) for name in ("times", "values", "window_ms")}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EnergyEnvelope":
        """Rebuild from a mapping; window_ms falls back to its default."""
        return cls(
            times=data["times"],
            values=data["values"],
            window_ms=data.get("window_ms", 50.0),
        )

    def at_time(self, t: float) -> float:
        """Interpolate energy value at given time."""
        times, values = self.times, self.values
        if not times:
            return 0.0
        # Clamp queries outside the sampled range to the edge values.
        if t <= times[0]:
            return values[0]
        if t >= times[-1]:
            return values[-1]

        # Binary search for the pair of samples bracketing t.
        left, right = 0, len(times) - 1
        while right - left > 1:
            mid = (left + right) // 2
            if times[mid] <= t:
                left = mid
            else:
                right = mid

        # Linear interpolation between the bracketing samples.
        t0, t1 = times[left], times[right]
        v0, v1 = values[left], values[right]
        frac = (t - t0) / (t1 - t0) if t1 != t0 else 0
        return v0 + frac * (v1 - v0)
|
||||
|
||||
|
||||
@dataclass
class SpectrumBands:
    """
    Frequency band envelopes over time.

    Attributes:
        bass: Low frequency envelope (20-200 Hz typical)
        mid: Mid frequency envelope (200-2000 Hz typical)
        high: High frequency envelope (2000-20000 Hz typical)
        times: Time points in seconds
        band_ranges: Frequency ranges for each band in Hz
    """
    bass: List[float]
    mid: List[float]
    high: List[float]
    times: List[float]
    band_ranges: Dict[str, Tuple[float, float]] = field(
        default_factory=lambda: {
            "bass": (20, 200),
            "mid": (200, 2000),
            "high": (2000, 20000),
        }
    )

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping of all fields."""
        return {
            name: getattr(self, name)
            for name in ("bass", "mid", "high", "times", "band_ranges")
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SpectrumBands":
        """Rebuild from a mapping; missing band_ranges use the standard split."""
        fallback = {
            "bass": (20, 200),
            "mid": (200, 2000),
            "high": (2000, 20000),
        }
        return cls(
            bass=data["bass"],
            mid=data["mid"],
            high=data["high"],
            times=data["times"],
            band_ranges=data.get("band_ranges", fallback),
        )
|
||||
|
||||
|
||||
@dataclass
class AudioFeatures:
    """
    All extracted audio features.

    Attributes:
        duration: Audio duration in seconds
        sample_rate: Sample rate in Hz
        channels: Number of audio channels
        beats: Beat detection results
        energy: Energy envelope
        spectrum: Frequency band envelopes
        onsets: Note/sound onset times
    """
    duration: float
    sample_rate: int
    channels: int
    beats: Optional[BeatInfo] = None
    energy: Optional[EnergyEnvelope] = None
    spectrum: Optional[SpectrumBands] = None
    onsets: Optional[List[float]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping; nested features become dicts."""
        def serialize(obj):
            # Absent features serialize as None rather than being dropped.
            return obj.to_dict() if obj else None

        return {
            "duration": self.duration,
            "sample_rate": self.sample_rate,
            "channels": self.channels,
            "beats": serialize(self.beats),
            "energy": serialize(self.energy),
            "spectrum": serialize(self.spectrum),
            "onsets": self.onsets,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AudioFeatures":
        """Rebuild from a mapping produced by to_dict()."""
        beats_data = data.get("beats")
        energy_data = data.get("energy")
        spectrum_data = data.get("spectrum")
        return cls(
            duration=data["duration"],
            sample_rate=data["sample_rate"],
            channels=data["channels"],
            beats=BeatInfo.from_dict(beats_data) if beats_data else None,
            energy=EnergyEnvelope.from_dict(energy_data) if energy_data else None,
            spectrum=SpectrumBands.from_dict(spectrum_data) if spectrum_data else None,
            onsets=data.get("onsets"),
        )
|
||||
|
||||
|
||||
@dataclass
class VideoFeatures:
    """
    Extracted video features.

    Attributes:
        duration: Video duration in seconds
        frame_rate: Frames per second
        width: Frame width in pixels
        height: Frame height in pixels
        codec: Video codec name
        motion_tempo: Estimated tempo from motion analysis (optional)
        scene_changes: Times of detected scene changes
    """
    duration: float
    frame_rate: float
    width: int
    height: int
    codec: str = ""
    motion_tempo: Optional[float] = None
    scene_changes: Optional[List[float]] = None

    # Serialized key order matches the field declaration order.
    _FIELDS = ("duration", "frame_rate", "width", "height",
               "codec", "motion_tempo", "scene_changes")

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping of all fields."""
        return {name: getattr(self, name) for name in self._FIELDS}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "VideoFeatures":
        """Rebuild from a mapping; optional fields fall back to their defaults."""
        kwargs = dict(
            duration=data["duration"],
            frame_rate=data["frame_rate"],
            width=data["width"],
            height=data["height"],
            codec=data.get("codec", ""),
            motion_tempo=data.get("motion_tempo"),
            scene_changes=data.get("scene_changes"),
        )
        return cls(**kwargs)
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """
    Complete analysis result for an input.

    Combines audio and video features with metadata for caching.

    Attributes:
        input_hash: Content hash of the analyzed input
        features_requested: List of features that were requested
        audio: Audio features (if input has audio)
        video: Video features (if input has video)
        cache_id: Computed cache ID for this analysis
        analyzed_at: Timestamp of analysis
    """
    input_hash: str
    features_requested: List[str]
    audio: Optional[AudioFeatures] = None
    video: Optional[VideoFeatures] = None
    cache_id: Optional[str] = None
    analyzed_at: Optional[str] = None

    def __post_init__(self):
        """Fill in cache_id when the caller did not supply one."""
        if self.cache_id is None:
            self.cache_id = self._compute_cache_id()

    def _compute_cache_id(self) -> str:
        """
        Compute cache ID from input hash and requested features.

        cache_id = SHA3-256(input_hash + sorted(features_requested))
        """
        # Sorting makes the ID independent of the order features were asked for.
        return _stable_hash({
            "input_hash": self.input_hash,
            "features": sorted(self.features_requested),
        })

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable mapping; nested features become dicts."""
        audio_dict = self.audio.to_dict() if self.audio else None
        video_dict = self.video.to_dict() if self.video else None
        return {
            "input_hash": self.input_hash,
            "features_requested": self.features_requested,
            "audio": audio_dict,
            "video": video_dict,
            "cache_id": self.cache_id,
            "analyzed_at": self.analyzed_at,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnalysisResult":
        """Rebuild from a mapping produced by to_dict()."""
        audio_data = data.get("audio")
        video_data = data.get("video")
        return cls(
            input_hash=data["input_hash"],
            features_requested=data["features_requested"],
            audio=AudioFeatures.from_dict(audio_data) if audio_data else None,
            video=VideoFeatures.from_dict(video_data) if video_data else None,
            cache_id=data.get("cache_id"),
            analyzed_at=data.get("analyzed_at"),
        )

    def to_json(self) -> str:
        """Serialize to JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "AnalysisResult":
        """Deserialize from JSON string."""
        return cls.from_dict(json.loads(json_str))

    # Convenience accessors

    @property
    def tempo(self) -> Optional[float]:
        """Get tempo if beats were analyzed."""
        beats = self.audio.beats if self.audio else None
        return beats.tempo if beats else None

    @property
    def beat_times(self) -> Optional[List[float]]:
        """Get beat times if beats were analyzed."""
        beats = self.audio.beats if self.audio else None
        return beats.beat_times if beats else None

    @property
    def downbeat_times(self) -> Optional[List[float]]:
        """Get downbeat times if analyzed."""
        beats = self.audio.beats if self.audio else None
        return beats.downbeat_times if beats else None

    @property
    def duration(self) -> float:
        """Get duration from video or audio (video wins; 0.0 if neither)."""
        for media in (self.video, self.audio):
            if media:
                return media.duration
        return 0.0

    @property
    def dimensions(self) -> Optional[Tuple[int, int]]:
        """Get video dimensions if available."""
        return (self.video.width, self.video.height) if self.video else None
|
||||
266
core/artdag/analysis/video.py
Normal file
266
core/artdag/analysis/video.py
Normal file
@@ -0,0 +1,266 @@
|
||||
# artdag/analysis/video.py
|
||||
"""
|
||||
Video feature extraction.
|
||||
|
||||
Uses ffprobe for basic metadata and optional OpenCV for motion analysis.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from .schema import VideoFeatures
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Feature names
|
||||
FEATURE_METADATA = "metadata"
|
||||
FEATURE_MOTION_TEMPO = "motion_tempo"
|
||||
FEATURE_SCENE_CHANGES = "scene_changes"
|
||||
FEATURE_ALL = "all"
|
||||
|
||||
|
||||
def _parse_frame_rate(rate_str: str) -> float:
|
||||
"""Parse frame rate string like '30000/1001' or '30'."""
|
||||
try:
|
||||
if "/" in rate_str:
|
||||
frac = Fraction(rate_str)
|
||||
return float(frac)
|
||||
return float(rate_str)
|
||||
except (ValueError, ZeroDivisionError):
|
||||
return 30.0 # Default
|
||||
|
||||
|
||||
def analyze_metadata(path: Path) -> VideoFeatures:
    """
    Extract video metadata using ffprobe.

    Args:
        path: Path to video file

    Returns:
        VideoFeatures with basic metadata

    Raises:
        ValueError: If ffprobe fails, its output is not JSON, or the file
            contains no video stream.
    """
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_streams",
        "-show_format",
        "-select_streams", "v:0",
        str(path)
    ]

    try:
        probe = subprocess.run(cmd, capture_output=True, text=True, check=True)
        info = json.loads(probe.stdout)
    except (subprocess.CalledProcessError, json.JSONDecodeError) as e:
        raise ValueError(f"Could not read video info: {e}")

    streams = info.get("streams")
    if not streams:
        raise ValueError("No video stream found")

    stream = streams[0]
    container = info.get("format", {})

    # Prefer the container duration; fall back to the stream's own value.
    duration = float(container.get("duration", stream.get("duration", 0)))
    frame_rate = _parse_frame_rate(stream.get("avg_frame_rate", "30"))

    return VideoFeatures(
        duration=duration,
        frame_rate=frame_rate,
        width=int(stream.get("width", 0)),
        height=int(stream.get("height", 0)),
        codec=stream.get("codec_name", ""),
    )
|
||||
|
||||
|
||||
def analyze_scene_changes(path: Path, threshold: float = 0.3) -> List[float]:
    """
    Detect scene changes using ffmpeg scene detection.

    Args:
        path: Path to video file
        threshold: Scene change threshold (0-1, lower = more sensitive)

    Returns:
        List of scene change times in seconds; empty if ffmpeg is missing
        or fails to run.
    """
    cmd = [
        "ffmpeg", "-i", str(path),
        "-vf", f"select='gt(scene,{threshold})',showinfo",
        "-f", "null", "-"
    ]

    try:
        # showinfo logs per-frame info (including pts_time) to stderr;
        # check=True is deliberately omitted so we still parse stderr.
        result = subprocess.run(cmd, capture_output=True, text=True)
        stderr = result.stderr
    except (OSError, subprocess.SubprocessError) as e:
        # Covers FileNotFoundError when ffmpeg is not installed — the
        # original only caught CalledProcessError, which run() without
        # check=True never raises, so a missing binary crashed the caller.
        logger.warning(f"Scene detection failed: {e}")
        return []

    # Parse scene change times from ffmpeg output
    scene_times = []
    for line in stderr.split("\n"):
        if "pts_time:" in line:
            try:
                # Extract the value of the "pts_time:<seconds>" token.
                for part in line.split():
                    if part.startswith("pts_time:"):
                        scene_times.append(float(part.split(":", 1)[1]))
                        break
            except (ValueError, IndexError):
                continue

    return scene_times
|
||||
|
||||
|
||||
def analyze_motion_tempo(path: Path, sample_duration: float = 30.0) -> Optional[float]:
    """
    Estimate tempo from video motion periodicity.

    Analyzes optical flow or frame differences to detect rhythmic motion.
    This is useful for matching video speed to audio tempo.

    Args:
        path: Path to video file
        sample_duration: Duration to analyze (seconds)

    Returns:
        Estimated motion tempo in BPM, or None if not detectable
    """
    try:
        import cv2
        import numpy as np
    except ImportError:
        logger.warning("OpenCV not available, skipping motion tempo analysis")
        return None

    cap = cv2.VideoCapture(str(path))
    if not cap.isOpened():
        logger.warning(f"Could not open video: {path}")
        return None

    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # OpenCV reports 0 (or negative) FPS for some containers; assume 30.
        if fps <= 0:
            fps = 30.0

        max_frames = int(sample_duration * fps)
        frame_diffs = []
        prev_gray = None

        frame_count = 0
        while frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break

            # Convert to grayscale and resize for speed
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            gray = cv2.resize(gray, (160, 90))

            if prev_gray is not None:
                # Calculate frame difference (mean absolute pixel change)
                diff = cv2.absdiff(gray, prev_gray)
                frame_diffs.append(np.mean(diff))

            prev_gray = gray
            frame_count += 1

        if len(frame_diffs) < 60:  # Need at least 2 seconds at 30fps
            return None

        # Convert to numpy array
        motion = np.array(frame_diffs)

        # Normalize to zero mean / unit variance so autocorrelation peaks
        # reflect periodicity rather than absolute motion magnitude.
        motion = motion - motion.mean()
        if motion.std() > 0:
            motion = motion / motion.std()

        # Autocorrelation to find periodicity
        n = len(motion)
        acf = np.correlate(motion, motion, mode="full")[n-1:]
        # NOTE(review): if the clip is completely static, motion is all zeros
        # and acf[0] == 0, making this a 0/0 division — confirm inputs always
        # contain some motion, or guard upstream.
        acf = acf / acf[0]  # Normalize

        # Find peaks in autocorrelation (potential beat periods)
        # Look for periods between 0.3s (200 BPM) and 2s (30 BPM)
        min_lag = int(0.3 * fps)
        max_lag = min(int(2.0 * fps), len(acf) - 1)

        if max_lag <= min_lag:
            return None

        # Find the highest peak in the valid range
        search_range = acf[min_lag:max_lag]
        if len(search_range) == 0:
            return None

        peak_idx = np.argmax(search_range) + min_lag
        peak_value = acf[peak_idx]

        # Only report if peak is significant
        if peak_value < 0.1:
            return None

        # Convert lag to BPM
        period_seconds = peak_idx / fps
        bpm = 60.0 / period_seconds

        # Sanity check: discard estimates outside a plausible musical range
        if 30 <= bpm <= 200:
            return round(bpm, 1)

        return None

    finally:
        # Always release the capture handle, even on early returns above.
        cap.release()
|
||||
|
||||
|
||||
def analyze_video(
    path: Path,
    features: Optional[List[str]] = None,
) -> VideoFeatures:
    """
    Extract video features from file.

    Args:
        path: Path to video file
        features: List of features to extract. Options:
            - "metadata": Basic video info (always included)
            - "motion_tempo": Estimated tempo from motion
            - "scene_changes": Scene change detection
            - "all": All features

    Returns:
        VideoFeatures with requested analysis
    """
    requested = [FEATURE_METADATA] if features is None else features
    if FEATURE_ALL in requested:
        requested = [FEATURE_METADATA, FEATURE_MOTION_TEMPO, FEATURE_SCENE_CHANGES]

    # Basic metadata is always extracted
    result = analyze_metadata(path)

    # The optional analyses are best-effort: failures are logged, not raised.
    if FEATURE_MOTION_TEMPO in requested:
        try:
            result.motion_tempo = analyze_motion_tempo(path)
        except Exception as e:
            logger.warning(f"Motion tempo analysis failed: {e}")

    if FEATURE_SCENE_CHANGES in requested:
        try:
            result.scene_changes = analyze_scene_changes(path)
        except Exception as e:
            logger.warning(f"Scene change detection failed: {e}")

    return result
|
||||
464
core/artdag/cache.py
Normal file
464
core/artdag/cache.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# primitive/cache.py
|
||||
"""
|
||||
Content-addressed file cache for node outputs.
|
||||
|
||||
Each node's output is stored at: cache_dir / node_id / output_file
|
||||
This enables automatic reuse when the same operation is requested.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _file_hash(path: Path, algorithm: str = "sha3_256") -> str:
|
||||
"""
|
||||
Compute content hash of a file.
|
||||
|
||||
Uses SHA-3 (Keccak) by default for quantum resistance.
|
||||
"""
|
||||
import hashlib
|
||||
hasher = hashlib.new(algorithm)
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """Metadata about a cached output."""
    node_id: str
    output_path: Path
    created_at: float
    size_bytes: int
    node_type: str
    cid: str = ""  # Content identifier (IPFS CID or local hash)
    execution_time: float = 0.0

    def to_dict(self) -> Dict:
        """Return a JSON-serializable mapping (output_path as a string)."""
        payload: Dict = {
            "node_id": self.node_id,
            "output_path": str(self.output_path),
        }
        payload.update(
            created_at=self.created_at,
            size_bytes=self.size_bytes,
            node_type=self.node_type,
            cid=self.cid,
            execution_time=self.execution_time,
        )
        return payload

    @classmethod
    def from_dict(cls, data: Dict) -> "CacheEntry":
        """Rebuild from a mapping, accepting the legacy "content_hash" key."""
        # Support both "cid" and legacy "content_hash"
        content_id = data.get("cid") or data.get("content_hash", "")
        return cls(
            node_id=data["node_id"],
            output_path=Path(data["output_path"]),
            created_at=data["created_at"],
            size_bytes=data["size_bytes"],
            node_type=data["node_type"],
            cid=content_id,
            execution_time=data.get("execution_time", 0.0),
        )
|
||||
|
||||
|
||||
@dataclass
class CacheStats:
    """Statistics about cache usage."""
    total_entries: int = 0
    total_size_bytes: int = 0
    hits: int = 0
    misses: int = 0
    hit_rate: float = 0.0

    def record_hit(self):
        """Count a cache hit and refresh the hit rate."""
        self.hits += 1
        self._update_rate()

    def record_miss(self):
        """Count a cache miss and refresh the hit rate."""
        self.misses += 1
        self._update_rate()

    def _update_rate(self):
        """Recompute hits / (hits + misses), guarding against zero lookups."""
        attempts = self.hits + self.misses
        self.hit_rate = self.hits / attempts if attempts else 0.0
|
||||
|
||||
|
||||
class Cache:
|
||||
"""
|
||||
Code-addressed file cache.
|
||||
|
||||
The filesystem IS the index - no JSON index files needed.
|
||||
Each node's hash is its directory name.
|
||||
|
||||
Structure:
|
||||
cache_dir/
|
||||
<hash>/
|
||||
output.ext # Actual output file
|
||||
metadata.json # Per-node metadata (optional)
|
||||
"""
|
||||
|
||||
    def __init__(self, cache_dir: Path | str) -> None:
        """Create (if needed) and open the cache rooted at cache_dir."""
        # Root directory holding one subdirectory per cached node hash.
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Hit/miss counters live only in this process; disk is the real index.
        self.stats = CacheStats()
|
||||
|
||||
    def _node_dir(self, node_id: str) -> Path:
        """Get the cache directory for a node (its hash IS the directory name)."""
        return self.cache_dir / node_id
|
||||
|
||||
def _find_output_file(self, node_dir: Path) -> Optional[Path]:
|
||||
"""Find the output file in a node directory."""
|
||||
if not node_dir.exists() or not node_dir.is_dir():
|
||||
return None
|
||||
for f in node_dir.iterdir():
|
||||
if f.is_file() and f.name.startswith("output."):
|
||||
return f
|
||||
return None
|
||||
|
||||
    def get(self, node_id: str) -> Optional[Path]:
        """
        Get cached output path for a node.

        Checks filesystem directly - no in-memory index.
        Returns the output path if cached, None otherwise.
        """
        node_dir = self._node_dir(node_id)
        output_file = self._find_output_file(node_dir)

        # Every lookup updates the process-local hit/miss statistics.
        if output_file:
            self.stats.record_hit()
            logger.debug(f"Cache hit: {node_id[:16]}...")
            return output_file

        self.stats.record_miss()
        return None
|
||||
|
||||
    def put(self, node_id: str, source_path: Path, node_type: str,
            execution_time: float = 0.0, move: bool = False) -> Path:
        """
        Store a file in the cache.

        Args:
            node_id: The code-addressed node ID (hash)
            source_path: Path to the file to cache
            node_type: Type of the node (for metadata)
            execution_time: How long the node took to execute
            move: If True, move the file instead of copying

        Returns:
            Path to the cached file
        """
        node_dir = self._node_dir(node_id)
        node_dir.mkdir(parents=True, exist_ok=True)

        # Preserve extension so downstream tools can infer the format.
        ext = source_path.suffix or ".out"
        output_path = node_dir / f"output{ext}"

        # Copy or move file (skip if already in place — an executor may have
        # written directly into the cache slot via get_output_path()).
        source_resolved = Path(source_path).resolve()
        output_resolved = output_path.resolve()
        if source_resolved != output_resolved:
            if move:
                shutil.move(source_path, output_path)
            else:
                shutil.copy2(source_path, output_path)

        # Compute content hash (IPFS CID of the result)
        cid = _file_hash(output_path)

        # Store per-node metadata (optional, for stats/debugging; the cache
        # works without it — see _load_entry_from_disk's fallback).
        metadata = {
            "node_id": node_id,
            "output_path": str(output_path),
            "created_at": time.time(),
            "size_bytes": output_path.stat().st_size,
            "node_type": node_type,
            "cid": cid,
            "execution_time": execution_time,
        }
        metadata_path = node_dir / "metadata.json"
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        logger.debug(f"Cached: {node_id[:16]}... ({metadata['size_bytes']} bytes)")
        return output_path
|
||||
|
||||
def has(self, node_id: str) -> bool:
|
||||
"""Check if a node is cached (without affecting stats)."""
|
||||
return self._find_output_file(self._node_dir(node_id)) is not None
|
||||
|
||||
def remove(self, node_id: str) -> bool:
|
||||
"""Remove a node from the cache."""
|
||||
node_dir = self._node_dir(node_id)
|
||||
if node_dir.exists():
|
||||
shutil.rmtree(node_dir)
|
||||
return True
|
||||
return False
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached entries."""
|
||||
for node_dir in self.cache_dir.iterdir():
|
||||
if node_dir.is_dir() and not node_dir.name.startswith("_"):
|
||||
shutil.rmtree(node_dir)
|
||||
self.stats = CacheStats()
|
||||
|
||||
def get_stats(self) -> CacheStats:
|
||||
"""Get cache statistics (scans filesystem)."""
|
||||
stats = CacheStats()
|
||||
for node_dir in self.cache_dir.iterdir():
|
||||
if node_dir.is_dir() and not node_dir.name.startswith("_"):
|
||||
output_file = self._find_output_file(node_dir)
|
||||
if output_file:
|
||||
stats.total_entries += 1
|
||||
stats.total_size_bytes += output_file.stat().st_size
|
||||
stats.hits = self.stats.hits
|
||||
stats.misses = self.stats.misses
|
||||
stats.hit_rate = self.stats.hit_rate
|
||||
return stats
|
||||
|
||||
def list_entries(self) -> List[CacheEntry]:
|
||||
"""List all cache entries (scans filesystem)."""
|
||||
entries = []
|
||||
for node_dir in self.cache_dir.iterdir():
|
||||
if node_dir.is_dir() and not node_dir.name.startswith("_"):
|
||||
entry = self._load_entry_from_disk(node_dir.name)
|
||||
if entry:
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
def _load_entry_from_disk(self, node_id: str) -> Optional[CacheEntry]:
|
||||
"""Load entry metadata from disk."""
|
||||
node_dir = self._node_dir(node_id)
|
||||
metadata_path = node_dir / "metadata.json"
|
||||
output_file = self._find_output_file(node_dir)
|
||||
|
||||
if not output_file:
|
||||
return None
|
||||
|
||||
if metadata_path.exists():
|
||||
try:
|
||||
with open(metadata_path) as f:
|
||||
data = json.load(f)
|
||||
return CacheEntry.from_dict(data)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
pass
|
||||
|
||||
# Fallback: create entry from filesystem
|
||||
return CacheEntry(
|
||||
node_id=node_id,
|
||||
output_path=output_file,
|
||||
created_at=output_file.stat().st_mtime,
|
||||
size_bytes=output_file.stat().st_size,
|
||||
node_type="unknown",
|
||||
cid=_file_hash(output_file),
|
||||
)
|
||||
|
||||
    def get_entry(self, node_id: str) -> Optional[CacheEntry]:
        """Get cache entry metadata (without affecting stats); None if absent."""
        return self._load_entry_from_disk(node_id)
|
||||
|
||||
def find_by_cid(self, cid: str) -> Optional[CacheEntry]:
|
||||
"""Find a cache entry by its content hash (scans filesystem)."""
|
||||
for entry in self.list_entries():
|
||||
if entry.cid == cid:
|
||||
return entry
|
||||
return None
|
||||
|
||||
def prune(self, max_size_bytes: int = None, max_age_seconds: float = None) -> int:
|
||||
"""
|
||||
Prune cache based on size or age.
|
||||
|
||||
Args:
|
||||
max_size_bytes: Remove oldest entries until under this size
|
||||
max_age_seconds: Remove entries older than this
|
||||
|
||||
Returns:
|
||||
Number of entries removed
|
||||
"""
|
||||
removed = 0
|
||||
now = time.time()
|
||||
entries = self.list_entries()
|
||||
|
||||
# Remove by age first
|
||||
if max_age_seconds is not None:
|
||||
for entry in entries:
|
||||
if now - entry.created_at > max_age_seconds:
|
||||
self.remove(entry.node_id)
|
||||
removed += 1
|
||||
|
||||
# Then by size (remove oldest first)
|
||||
if max_size_bytes is not None:
|
||||
stats = self.get_stats()
|
||||
if stats.total_size_bytes > max_size_bytes:
|
||||
sorted_entries = sorted(entries, key=lambda e: e.created_at)
|
||||
total_size = stats.total_size_bytes
|
||||
for entry in sorted_entries:
|
||||
if total_size <= max_size_bytes:
|
||||
break
|
||||
self.remove(entry.node_id)
|
||||
total_size -= entry.size_bytes
|
||||
removed += 1
|
||||
|
||||
return removed
|
||||
|
||||
def get_output_path(self, node_id: str, extension: str = ".mkv") -> Path:
    """Return the output file path for *node_id*, creating its directory if needed."""
    target_dir = self._node_dir(node_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / ("output" + extension)
|
||||
|
||||
# Effect storage methods
|
||||
|
||||
def _effects_dir(self) -> Path:
    """Return the effects subdirectory, creating it on first use."""
    path = self.cache_dir / "_effects"
    path.mkdir(parents=True, exist_ok=True)
    return path
|
||||
|
||||
def store_effect(self, source: str) -> str:
    """
    Store an effect's source (and parsed metadata, when available) in the cache.

    Args:
        source: Effect source code

    Returns:
        Content hash (cache ID) of the effect
    """
    import hashlib as _hashlib

    # Content-address the effect by its source bytes.
    cid = _hashlib.sha3_256(source.encode("utf-8")).hexdigest()

    # Try to load full metadata if effects module available
    effect_name = None
    try:
        from .effects.loader import load_effect
        loaded = load_effect(source)
        meta_dict = loaded.meta.to_dict()
        dependencies = loaded.dependencies
        requires_python = loaded.requires_python
        effect_name = loaded.meta.name
    except ImportError:
        # Fallback: store without parsed metadata
        meta_dict = {}
        dependencies = []
        requires_python = ">=3.10"

    effect_dir = self._effects_dir() / cid
    effect_dir.mkdir(parents=True, exist_ok=True)

    # Store source
    source_path = effect_dir / "effect.py"
    source_path.write_text(source, encoding="utf-8")

    # Store metadata
    metadata = {
        "cid": cid,
        "meta": meta_dict,
        "dependencies": dependencies,
        "requires_python": requires_python,
        "stored_at": time.time(),
    }
    metadata_path = effect_dir / "metadata.json"
    with open(metadata_path, "w") as f:
        json.dump(metadata, f, indent=2)

    # BUG FIX: the original logged `loaded.meta.name` unconditionally, which
    # raised NameError whenever the ImportError fallback path was taken
    # (`loaded` is only bound on the successful-parse path).
    logger.info(f"Stored effect '{effect_name or cid[:16]}' with hash {cid[:16]}...")
    return cid
|
||||
|
||||
def get_effect(self, cid: str) -> Optional[str]:
    """
    Fetch cached effect source code by content hash.

    Args:
        cid: SHA3-256 hash of effect source

    Returns:
        Effect source code if found, None otherwise
    """
    source_path = self._effects_dir() / cid / "effect.py"
    if source_path.exists():
        return source_path.read_text(encoding="utf-8")
    return None
|
||||
|
||||
def get_effect_path(self, cid: str) -> Optional[Path]:
    """
    Locate the cached source file for an effect.

    Args:
        cid: SHA3-256 hash of effect source

    Returns:
        Path to effect.py if found, None otherwise
    """
    candidate = self._effects_dir() / cid / "effect.py"
    return candidate if candidate.exists() else None
|
||||
|
||||
def get_effect_metadata(self, cid: str) -> Optional[dict]:
    """
    Load stored metadata for a cached effect.

    Args:
        cid: SHA3-256 hash of effect source

    Returns:
        Metadata dict if found and parseable, None otherwise
    """
    metadata_path = self._effects_dir() / cid / "metadata.json"
    if not metadata_path.exists():
        return None

    try:
        with open(metadata_path) as fh:
            return json.load(fh)
    except (json.JSONDecodeError, KeyError):
        # Corrupt/unreadable metadata is treated the same as missing.
        return None
|
||||
|
||||
def has_effect(self, cid: str) -> bool:
    """Return True when the effect's source file is present in the cache."""
    source_file = self._effects_dir() / cid / "effect.py"
    return source_file.exists()
|
||||
|
||||
def list_effects(self) -> List[dict]:
    """Return the metadata dict of every cached effect."""
    root = self._effects_dir()
    if not root.exists():
        return []

    found = []
    for child in root.iterdir():
        if not child.is_dir():
            continue
        meta = self.get_effect_metadata(child.name)
        if meta:
            found.append(meta)
    return found
|
||||
|
||||
def remove_effect(self, cid: str) -> bool:
    """Delete a cached effect; returns False when it was not present."""
    target = self._effects_dir() / cid
    if not target.exists():
        return False

    shutil.rmtree(target)
    logger.info(f"Removed effect {cid[:16]}...")
    return True
|
||||
724
core/artdag/cli.py
Normal file
724
core/artdag/cli.py
Normal file
@@ -0,0 +1,724 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Art DAG CLI
|
||||
|
||||
Command-line interface for the 3-phase execution model:
|
||||
artdag analyze - Extract features from inputs
|
||||
artdag plan - Generate execution plan
|
||||
artdag execute - Run the plan
|
||||
artdag run-recipe - Full pipeline
|
||||
|
||||
Usage:
|
||||
artdag analyze <recipe> -i <name>:<hash>[@<path>] [--features <list>]
|
||||
artdag plan <recipe> -i <name>:<hash> [--analysis <file>]
|
||||
artdag execute <plan.json> [--dry-run]
|
||||
artdag run-recipe <recipe> -i <name>:<hash>[@<path>]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def parse_input(input_str: str) -> Tuple[str, str, Optional[str]]:
    """
    Parse an input specification of the form ``name:hash[@path]``.

    The path component, when present, is everything after the *last* ``@``.

    Returns:
        (name, hash, path or None)

    Raises:
        ValueError: when the spec has no ``name:hash`` separator.
    """
    before, sep, after = input_str.rpartition("@")
    if sep:
        name_hash, path = before, after
    else:
        name_hash, path = input_str, None

    if ":" not in name_hash:
        raise ValueError(f"Invalid input format: {input_str}. Expected name:hash[@path]")

    name, hash_value = name_hash.split(":", 1)
    return name, hash_value, path
|
||||
|
||||
|
||||
def parse_inputs(input_list: List[str]) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Parse a list of ``name:hash[@path]`` input specifications.

    Returns:
        (input_hashes, input_paths) — a hash for every input, and a path
        only for inputs that supplied one.

    Raises:
        ValueError: when a spec lacks the ``name:hash`` separator.
    """
    input_hashes: Dict[str, str] = {}
    input_paths: Dict[str, str] = {}

    for spec in input_list:
        # Split off an optional trailing @path (last '@' wins).
        before, sep, after = spec.rpartition("@")
        name_hash, path = (before, after) if sep else (spec, None)

        if ":" not in name_hash:
            raise ValueError(f"Invalid input format: {spec}. Expected name:hash[@path]")

        name, hash_value = name_hash.split(":", 1)
        input_hashes[name] = hash_value
        if path:
            input_paths[name] = path

    return input_hashes, input_paths
|
||||
|
||||
|
||||
def cmd_analyze(args):
    """Run analysis phase: extract features from each input.

    Writes a JSON mapping of input hash -> analysis dict to --output
    (default: analysis.json) and prints a per-input summary.
    """
    from .analysis import Analyzer

    # Parse inputs
    input_hashes, input_paths = parse_inputs(args.input)

    # Parse features (comma-separated; "all" when unspecified)
    features = args.features.split(",") if args.features else ["all"]

    # Create analyzer
    cache_dir = Path(args.cache_dir) if args.cache_dir else Path("./analysis_cache")
    analyzer = Analyzer(cache_dir=cache_dir)

    # Analyze each input
    results = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            path = Path(path)

        print(f"Analyzing {name} ({hash_value[:16]}...)...")

        result = analyzer.analyze(
            input_hash=hash_value,
            features=features,
            input_path=path,
        )

        # Keyed by hash so downstream phases can correlate by content.
        results[hash_value] = result.to_dict()

        # Print summary
        if result.audio and result.audio.beats:
            print(f" Tempo: {result.audio.beats.tempo:.1f} BPM")
            print(f" Beats: {len(result.audio.beats.beat_times)}")
        if result.video:
            print(f" Duration: {result.video.duration:.1f}s")
            print(f" Dimensions: {result.video.width}x{result.video.height}")

    # Write output
    output_path = Path(args.output) if args.output else Path("analysis.json")
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"\nAnalysis saved to: {output_path}")
|
||||
|
||||
|
||||
def cmd_plan(args):
    """Run planning phase: expand a recipe into an execution plan.

    Optionally consumes a prior analysis JSON (--analysis) so the planner
    can make feature-driven decisions. Writes the plan to --output
    (default: plan.json).
    """
    from .analysis import AnalysisResult
    from .planning import RecipePlanner, Recipe

    # Load recipe
    recipe = Recipe.from_file(Path(args.recipe))
    print(f"Recipe: {recipe.name} v{recipe.version}")

    # Parse inputs (paths are not needed for planning)
    input_hashes, _ = parse_inputs(args.input)

    # Load analysis if provided
    analysis = {}
    if args.analysis:
        with open(args.analysis, "r") as f:
            analysis_data = json.load(f)
            for hash_value, data in analysis_data.items():
                analysis[hash_value] = AnalysisResult.from_dict(data)

    # Create planner
    planner = RecipePlanner(use_tree_reduction=not args.no_tree_reduction)

    # Generate plan
    print("Generating execution plan...")
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis,
    )

    # Print summary
    print(f"\nPlan ID: {plan.plan_id[:16]}...")
    print(f"Steps: {len(plan.steps)}")

    # Steps within one level have no mutual dependencies and can run in parallel.
    steps_by_level = plan.get_steps_by_level()
    max_level = max(steps_by_level.keys()) if steps_by_level else 0
    print(f"Levels: {max_level + 1}")

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        print(f" Level {level}: {len(steps)} steps (parallel)")

    # Write output
    output_path = Path(args.output) if args.output else Path("plan.json")
    with open(output_path, "w") as f:
        f.write(plan.to_json())

    print(f"\nPlan saved to: {output_path}")
|
||||
|
||||
|
||||
def cmd_execute(args):
    """Run execution phase: load a plan JSON and run it level by level.

    Cached steps are skipped; outputs are stored in the local Cache. With
    --dry-run, only reports which steps are cached vs pending. This is the
    in-process executor intended for testing (production uses Celery).
    """
    from .planning import ExecutionPlan
    from .cache import Cache
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Load plan
    with open(args.plan, "r") as f:
        plan = ExecutionPlan.from_json(f.read())

    print(f"Executing plan: {plan.plan_id[:16]}...")
    print(f"Steps: {len(plan.steps)}")

    if args.dry_run:
        print("\n=== DRY RUN ===")

        # Check cache status
        cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache"))
        steps_by_level = plan.get_steps_by_level()

        cached_count = 0
        pending_count = 0

        for level in sorted(steps_by_level.keys()):
            steps = steps_by_level[level]
            print(f"\nLevel {level}:")
            for step in steps:
                if cache.has(step.cache_id):
                    print(f" [CACHED] {step.step_id}: {step.node_type}")
                    cached_count += 1
                else:
                    print(f" [PENDING] {step.step_id}: {step.node_type}")
                    pending_count += 1

        print(f"\nSummary: {cached_count} cached, {pending_count} pending")
        return

    # Execute locally (for testing - production uses Celery)
    cache = Cache(Path(args.cache_dir) if args.cache_dir else Path("./cache"))

    # Maps both cache_ids and step_ids to local filesystem paths of
    # resolved artifacts, so later steps can find their inputs either way.
    cache_paths = {}
    for name, hash_value in plan.input_hashes.items():
        if cache.has(hash_value):
            entry = cache.get(hash_value)
            cache_paths[hash_value] = str(entry.output_path)

    steps_by_level = plan.get_steps_by_level()
    executed = 0
    cached = 0

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        print(f"\nLevel {level}: {len(steps)} steps")

        for step in steps:
            if cache.has(step.cache_id):
                # NOTE(review): cache.get(...) is stringified directly here,
                # but above the same call's .output_path attribute is used —
                # confirm Cache.get's return type; one of the two usages is
                # presumably wrong.
                cached_path = cache.get(step.cache_id)
                cache_paths[step.cache_id] = str(cached_path)
                cache_paths[step.step_id] = str(cached_path)
                print(f" [CACHED] {step.step_id}")
                cached += 1
                continue

            print(f" [RUNNING] {step.step_id}: {step.node_type}...")

            # Get executor (enum member when built-in; raw string otherwise)
            try:
                node_type = NodeType[step.node_type]
            except KeyError:
                node_type = step.node_type

            executor = get_executor(node_type)
            if executor is None:
                print(f" ERROR: No executor for {step.node_type}")
                continue

            # Resolve inputs: look up by step_id first, then by the
            # producing step's cache_id.
            input_paths = []
            for input_id in step.input_steps:
                if input_id in cache_paths:
                    input_paths.append(Path(cache_paths[input_id]))
                else:
                    input_step = plan.get_step(input_id)
                    if input_step and input_step.cache_id in cache_paths:
                        input_paths.append(Path(cache_paths[input_step.cache_id]))

            if len(input_paths) != len(step.input_steps):
                print(f" ERROR: Missing inputs")
                continue

            # Execute
            output_path = cache.get_output_path(step.cache_id)
            try:
                result_path = executor.execute(step.config, input_paths, output_path)
                cache.put(step.cache_id, result_path, node_type=step.node_type)
                cache_paths[step.cache_id] = str(result_path)
                cache_paths[step.step_id] = str(result_path)
                print(f" [DONE] -> {result_path}")
                executed += 1
            except Exception as e:
                # A failed step is reported but doesn't abort the plan.
                print(f" [FAILED] {e}")

    # Final output
    output_step = plan.get_step(plan.output_step)
    output_path = cache_paths.get(output_step.cache_id) if output_step else None

    print(f"\n=== Complete ===")
    print(f"Cached: {cached}")
    print(f"Executed: {executed}")
    if output_path:
        print(f"Output: {output_path}")
|
||||
|
||||
|
||||
def cmd_run_recipe(args):
    """Run complete pipeline: analyze → plan → execute.

    Uses a local cache directory for analysis results, plans, and step
    outputs; previously computed artifacts are reused.
    """
    from .analysis import Analyzer, AnalysisResult
    from .planning import RecipePlanner, Recipe
    from .cache import Cache
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Load recipe
    recipe = Recipe.from_file(Path(args.recipe))
    print(f"Recipe: {recipe.name} v{recipe.version}")

    # Parse inputs
    input_hashes, input_paths = parse_inputs(args.input)

    # Parse features
    features = args.features.split(",") if args.features else ["beats", "energy"]

    cache_dir = Path(args.cache_dir) if args.cache_dir else Path("./cache")

    # Phase 1: Analyze
    print("\n=== Phase 1: Analysis ===")
    analyzer = Analyzer(cache_dir=cache_dir / "analysis")

    analysis = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            path = Path(path)
        print(f"Analyzing {name}...")

        result = analyzer.analyze(
            input_hash=hash_value,
            features=features,
            input_path=path,
        )
        analysis[hash_value] = result

        if result.audio and result.audio.beats:
            print(f" Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats")

    # Phase 2: Plan
    print("\n=== Phase 2: Planning ===")

    # Check for cached plan
    plans_dir = cache_dir / "plans"
    plans_dir.mkdir(parents=True, exist_ok=True)

    # Generate plan to get plan_id (deterministic hash)
    planner = RecipePlanner(use_tree_reduction=True)
    plan = planner.plan(
        recipe=recipe,
        input_hashes=input_hashes,
        analysis=analysis,
    )

    plan_cache_path = plans_dir / f"{plan.plan_id}.json"

    if plan_cache_path.exists():
        # Reload the cached copy so execution uses exactly what was saved.
        print(f"Plan cached: {plan.plan_id[:16]}...")
        from .planning import ExecutionPlan
        with open(plan_cache_path, "r") as f:
            plan = ExecutionPlan.from_json(f.read())
    else:
        # Save plan to cache
        with open(plan_cache_path, "w") as f:
            f.write(plan.to_json())
        print(f"Plan saved: {plan.plan_id[:16]}...")

    print(f"Plan: {len(plan.steps)} steps")
    steps_by_level = plan.get_steps_by_level()
    print(f"Levels: {len(steps_by_level)}")

    # Phase 3: Execute
    print("\n=== Phase 3: Execution ===")

    cache = Cache(cache_dir)

    # Build initial cache paths: inputs are addressable by both hash and name.
    cache_paths = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            cache_paths[hash_value] = path
            cache_paths[name] = path

    executed = 0
    cached = 0

    for level in sorted(steps_by_level.keys()):
        steps = steps_by_level[level]
        print(f"\nLevel {level}: {len(steps)} steps")

        for step in steps:
            if cache.has(step.cache_id):
                # NOTE(review): cache.get(...) is stringified directly here,
                # while cmd_execute's input-resolution path reads
                # .output_path on the returned value — confirm Cache.get's
                # return type; one of the two usages is presumably wrong.
                cached_path = cache.get(step.cache_id)
                cache_paths[step.cache_id] = str(cached_path)
                cache_paths[step.step_id] = str(cached_path)
                print(f" [CACHED] {step.step_id}")
                cached += 1
                continue

            # Handle SOURCE specially: it just aliases an already-known input
            # path; nothing is executed.
            if step.node_type == "SOURCE":
                cid = step.config.get("cid")
                if cid in cache_paths:
                    cache_paths[step.cache_id] = cache_paths[cid]
                    cache_paths[step.step_id] = cache_paths[cid]
                print(f" [SOURCE] {step.step_id}")
                continue

            print(f" [RUNNING] {step.step_id}: {step.node_type}...")

            # Enum member for built-in types; raw string for custom ones.
            try:
                node_type = NodeType[step.node_type]
            except KeyError:
                node_type = step.node_type

            executor = get_executor(node_type)
            if executor is None:
                print(f" SKIP: No executor for {step.node_type}")
                continue

            # Resolve inputs by step_id, falling back to the producing
            # step's cache_id.
            input_paths_list = []
            for input_id in step.input_steps:
                if input_id in cache_paths:
                    input_paths_list.append(Path(cache_paths[input_id]))
                else:
                    input_step = plan.get_step(input_id)
                    if input_step and input_step.cache_id in cache_paths:
                        input_paths_list.append(Path(cache_paths[input_step.cache_id]))

            if len(input_paths_list) != len(step.input_steps):
                print(f" ERROR: Missing inputs for {step.step_id}")
                continue

            output_path = cache.get_output_path(step.cache_id)
            try:
                result_path = executor.execute(step.config, input_paths_list, output_path)
                cache.put(step.cache_id, result_path, node_type=step.node_type)
                cache_paths[step.cache_id] = str(result_path)
                cache_paths[step.step_id] = str(result_path)
                print(f" [DONE]")
                executed += 1
            except Exception as e:
                # Report and keep going; downstream steps will then fail
                # their missing-input check.
                print(f" [FAILED] {e}")

    # Final output
    output_step = plan.get_step(plan.output_step)
    output_path = cache_paths.get(output_step.cache_id) if output_step else None

    print(f"\n=== Complete ===")
    print(f"Cached: {cached}")
    print(f"Executed: {executed}")
    if output_path:
        print(f"Output: {output_path}")
|
||||
|
||||
|
||||
def cmd_run_recipe_ipfs(args):
    """Run complete pipeline with IPFS-primary mode.

    Everything stored on IPFS:
    - Inputs (media files)
    - Analysis results (JSON)
    - Execution plans (JSON)
    - Step outputs (media files)

    Local disk is used only for a temporary work directory, removed at the
    end. Requires a running IPFS daemon and the ipfs_client module.
    """
    import hashlib  # NOTE(review): appears unused in this function
    import shutil
    import tempfile

    from .analysis import Analyzer, AnalysisResult
    from .planning import RecipePlanner, Recipe, ExecutionPlan
    from .executor import get_executor
    from .dag import NodeType
    from . import nodes  # Register built-in executors

    # Check for ipfs_client
    try:
        from art_celery import ipfs_client
    except ImportError:
        # Try relative import for when running from art-celery
        try:
            import ipfs_client
        except ImportError:
            print("Error: ipfs_client not available. Install art-celery or run from art-celery directory.")
            sys.exit(1)

    # Check IPFS availability
    if not ipfs_client.is_available():
        print("Error: IPFS daemon not available. Start IPFS with 'ipfs daemon'")
        sys.exit(1)

    print("=== IPFS-Primary Mode ===")
    print(f"IPFS Node: {ipfs_client.get_node_id()[:16]}...")

    # Load recipe
    recipe_path = Path(args.recipe)
    recipe = Recipe.from_file(recipe_path)
    print(f"\nRecipe: {recipe.name} v{recipe.version}")

    # Parse inputs
    input_hashes, input_paths = parse_inputs(args.input)

    # Parse features
    features = args.features.split(",") if args.features else ["beats", "energy"]

    # Phase 0: Register on IPFS
    print("\n=== Phase 0: Register on IPFS ===")

    # Register recipe
    recipe_bytes = recipe_path.read_bytes()
    recipe_cid = ipfs_client.add_bytes(recipe_bytes)
    print(f"Recipe CID: {recipe_cid}")

    # Register inputs (only those with a local path can be added)
    input_cids = {}
    for name, hash_value in input_hashes.items():
        path = input_paths.get(name)
        if path:
            cid = ipfs_client.add_file(Path(path))
            if cid:
                input_cids[name] = cid
                print(f"Input '{name}': {cid}")
            else:
                print(f"Error: Failed to add input '{name}' to IPFS")
                sys.exit(1)

    # Phase 1: Analyze
    print("\n=== Phase 1: Analysis ===")

    # Create temp dir for analysis
    work_dir = Path(tempfile.mkdtemp(prefix="artdag_ipfs_"))
    analysis_cids = {}
    analysis = {}

    try:
        for name, hash_value in input_hashes.items():
            input_cid = input_cids.get(name)
            if not input_cid:
                continue

            print(f"Analyzing {name}...")

            # Fetch from IPFS to temp
            temp_input = work_dir / f"input_{name}.mkv"
            if not ipfs_client.get_file(input_cid, temp_input):
                print(f" Error: Failed to fetch from IPFS")
                continue

            # Run analysis (no local analysis cache in IPFS-primary mode)
            analyzer = Analyzer(cache_dir=None)
            result = analyzer.analyze(
                input_hash=hash_value,
                features=features,
                input_path=temp_input,
            )

            if result.audio and result.audio.beats:
                print(f" Tempo: {result.audio.beats.tempo:.1f} BPM, {len(result.audio.beats.beat_times)} beats")

            # Store analysis on IPFS
            analysis_cid = ipfs_client.add_json(result.to_dict())
            if analysis_cid:
                analysis_cids[hash_value] = analysis_cid
                analysis[hash_value] = result
                print(f" Analysis CID: {analysis_cid}")

        # Phase 2: Plan
        print("\n=== Phase 2: Planning ===")

        planner = RecipePlanner(use_tree_reduction=True)
        plan = planner.plan(
            recipe=recipe,
            input_hashes=input_hashes,
            analysis=analysis if analysis else None,
        )

        # Store plan on IPFS
        import json
        plan_dict = json.loads(plan.to_json())
        plan_cid = ipfs_client.add_json(plan_dict)
        print(f"Plan ID: {plan.plan_id[:16]}...")
        print(f"Plan CID: {plan_cid}")
        print(f"Steps: {len(plan.steps)}")

        steps_by_level = plan.get_steps_by_level()
        print(f"Levels: {len(steps_by_level)}")

        # Phase 3: Execute
        print("\n=== Phase 3: Execution ===")

        # CID results: step_id -> output CID for completed steps;
        # cid_results seeds the mapping with the registered inputs.
        cid_results = dict(input_cids)
        step_cids = {}

        executed = 0
        cached = 0

        for level in sorted(steps_by_level.keys()):
            steps = steps_by_level[level]
            print(f"\nLevel {level}: {len(steps)} steps")

            for step in steps:
                # Handle SOURCE: alias the registered input CID, nothing runs.
                if step.node_type == "SOURCE":
                    source_name = step.config.get("name") or step.step_id
                    cid = cid_results.get(source_name)
                    if cid:
                        step_cids[step.step_id] = cid
                    print(f" [SOURCE] {step.step_id}")
                    continue

                print(f" [RUNNING] {step.step_id}: {step.node_type}...")

                # Enum member for built-in types; raw string for custom ones.
                try:
                    node_type = NodeType[step.node_type]
                except KeyError:
                    node_type = step.node_type

                executor = get_executor(node_type)
                if executor is None:
                    print(f" SKIP: No executor for {step.node_type}")
                    continue

                # Fetch inputs from IPFS into the work dir; any miss is
                # caught by the count check below.
                input_paths_list = []
                for i, input_step_id in enumerate(step.input_steps):
                    input_cid = step_cids.get(input_step_id) or cid_results.get(input_step_id)
                    if not input_cid:
                        print(f" ERROR: Missing input CID for {input_step_id}")
                        continue

                    temp_path = work_dir / f"step_{step.step_id}_input_{i}.mkv"
                    if not ipfs_client.get_file(input_cid, temp_path):
                        print(f" ERROR: Failed to fetch {input_cid}")
                        continue
                    input_paths_list.append(temp_path)

                if len(input_paths_list) != len(step.input_steps):
                    print(f" ERROR: Missing inputs")
                    continue

                # Execute
                output_path = work_dir / f"step_{step.step_id}_output.mkv"
                try:
                    result_path = executor.execute(step.config, input_paths_list, output_path)

                    # Add to IPFS
                    output_cid = ipfs_client.add_file(result_path)
                    if output_cid:
                        step_cids[step.step_id] = output_cid
                        print(f" [DONE] CID: {output_cid}")
                        executed += 1
                    else:
                        print(f" [FAILED] Could not add to IPFS")
                except Exception as e:
                    print(f" [FAILED] {e}")

        # Final output
        output_step = plan.get_step(plan.output_step)
        output_cid = step_cids.get(output_step.step_id) if output_step else None

        print(f"\n=== Complete ===")
        print(f"Executed: {executed}")
        if output_cid:
            print(f"Output CID: {output_cid}")
            print(f"Fetch with: ipfs get {output_cid}")

        # Summary of all CIDs
        print(f"\n=== All CIDs ===")
        print(f"Recipe: {recipe_cid}")
        print(f"Plan: {plan_cid}")
        for name, cid in input_cids.items():
            print(f"Input '{name}': {cid}")
        for hash_val, cid in analysis_cids.items():
            print(f"Analysis '{hash_val[:16]}...': {cid}")
        if output_cid:
            print(f"Output: {output_cid}")

    finally:
        # Cleanup temp
        shutil.rmtree(work_dir, ignore_errors=True)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and dispatch to a subcommand."""
    parser = argparse.ArgumentParser(
        prog="artdag",
        description="Art DAG - Declarative media composition",
    )
    subparsers = parser.add_subparsers(dest="command", help="Commands")

    # analyze command
    analyze_parser = subparsers.add_parser("analyze", help="Extract features from inputs")
    analyze_parser.add_argument("recipe", help="Recipe YAML file")
    analyze_parser.add_argument("-i", "--input", action="append", required=True,
                                help="Input: name:hash[@path]")
    analyze_parser.add_argument("--features", help="Features to extract (comma-separated)")
    analyze_parser.add_argument("-o", "--output", help="Output file (default: analysis.json)")
    analyze_parser.add_argument("--cache-dir", help="Analysis cache directory")

    # plan command
    plan_parser = subparsers.add_parser("plan", help="Generate execution plan")
    plan_parser.add_argument("recipe", help="Recipe YAML file")
    plan_parser.add_argument("-i", "--input", action="append", required=True,
                             help="Input: name:hash")
    plan_parser.add_argument("--analysis", help="Analysis JSON file")
    plan_parser.add_argument("-o", "--output", help="Output file (default: plan.json)")
    plan_parser.add_argument("--no-tree-reduction", action="store_true",
                             help="Disable tree reduction optimization")

    # execute command
    execute_parser = subparsers.add_parser("execute", help="Execute a plan")
    execute_parser.add_argument("plan", help="Plan JSON file")
    execute_parser.add_argument("--dry-run", action="store_true",
                                help="Show what would execute")
    execute_parser.add_argument("--cache-dir", help="Cache directory")

    # run-recipe command
    run_parser = subparsers.add_parser("run-recipe", help="Full pipeline: analyze → plan → execute")
    run_parser.add_argument("recipe", help="Recipe YAML file")
    run_parser.add_argument("-i", "--input", action="append", required=True,
                            help="Input: name:hash[@path]")
    run_parser.add_argument("--features", help="Features to extract (comma-separated)")
    run_parser.add_argument("--cache-dir", help="Cache directory")
    run_parser.add_argument("--ipfs-primary", action="store_true",
                            help="Use IPFS-primary mode (everything on IPFS, no local cache)")

    args = parser.parse_args()

    # Dispatch; missing/unknown command prints help and exits non-zero.
    if args.command == "analyze":
        cmd_analyze(args)
    elif args.command == "plan":
        cmd_plan(args)
    elif args.command == "execute":
        cmd_execute(args)
    elif args.command == "run-recipe":
        if getattr(args, 'ipfs_primary', False):
            cmd_run_recipe_ipfs(args)
        else:
            cmd_run_recipe(args)
    else:
        parser.print_help()
        sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
201
core/artdag/client.py
Normal file
201
core/artdag/client.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# primitive/client.py
|
||||
"""
|
||||
Client SDK for the primitive execution server.
|
||||
|
||||
Provides a simple API for submitting DAGs and retrieving results.
|
||||
|
||||
Usage:
|
||||
client = PrimitiveClient("http://localhost:8080")
|
||||
|
||||
# Build a DAG
|
||||
builder = DAGBuilder()
|
||||
source = builder.source("/path/to/video.mp4")
|
||||
segment = builder.segment(source, duration=5.0)
|
||||
builder.set_output(segment)
|
||||
dag = builder.build()
|
||||
|
||||
# Execute and wait for result
|
||||
result = client.execute(dag)
|
||||
print(f"Output: {result.output_path}")
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.request import urlopen, Request
|
||||
from urllib.error import HTTPError, URLError
|
||||
|
||||
from .dag import DAG, DAGBuilder
|
||||
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Result from server execution."""
    # Whether the DAG completed successfully.
    success: bool
    # Local path to the final output artifact; None when unavailable.
    output_path: Optional[Path] = None
    # Error message reported by the server, if any.
    error: Optional[str] = None
    # Wall-clock execution time in seconds, as reported by the server.
    execution_time: float = 0.0
    # Number of nodes that were actually executed.
    nodes_executed: int = 0
    # Number of nodes satisfied from the server-side cache.
    nodes_cached: int = 0
|
||||
|
||||
|
||||
@dataclass
class CacheStats:
    """Cache statistics from server."""
    # Number of entries currently in the cache.
    total_entries: int = 0
    # Total on-disk size of cached outputs, in bytes.
    total_size_bytes: int = 0
    # Cache hit count.
    hits: int = 0
    # Cache miss count.
    misses: int = 0
    # hits / (hits + misses); 0.0 when no lookups yet.
    hit_rate: float = 0.0
|
||||
|
||||
|
||||
class PrimitiveClient:
|
||||
"""
|
||||
Client for the primitive execution server.
|
||||
|
||||
Args:
|
||||
base_url: Server URL (e.g., "http://localhost:8080")
|
||||
timeout: Request timeout in seconds
|
||||
"""
|
||||
|
||||
def __init__(self, base_url: str = "http://localhost:8080", timeout: float = 300):
    # Normalize: strip any trailing slash so path joins don't double up.
    self.base_url = base_url.rstrip("/")
    # Per-request timeout in seconds.
    self.timeout = timeout
|
||||
|
||||
def _request(self, method: str, path: str, data: Optional[dict] = None) -> dict:
    """Make an HTTP request to the server and decode the JSON response.

    Args:
        method: HTTP verb ("GET", "POST", ...).
        path: URL path beginning with "/", e.g. "/health".
        data: Optional dict to send as a JSON request body.

    Returns:
        Decoded JSON response as a dict.

    Raises:
        RuntimeError: on an HTTP error response (server-provided message
            when the error body is JSON with an "error" key).
        ConnectionError: when the server cannot be reached.
    """
    url = f"{self.base_url}{path}"

    if data is not None:
        body = json.dumps(data).encode()
        headers = {"Content-Type": "application/json"}
    else:
        body = None
        headers = {}

    req = Request(url, data=body, headers=headers, method=method)

    try:
        with urlopen(req, timeout=self.timeout) as response:
            return json.loads(response.read().decode())
    except HTTPError as e:
        error_body = e.read().decode()
        # Parse outside the raise so a JSON error body and a plain-text
        # one take clearly separate paths; chain the original HTTPError
        # (the original code dropped the cause).
        try:
            error_data = json.loads(error_body)
        except json.JSONDecodeError:
            raise RuntimeError(f"HTTP {e.code}: {error_body}") from e
        raise RuntimeError(error_data.get("error", str(e))) from e
    except URLError as e:
        raise ConnectionError(f"Failed to connect to server: {e}") from e
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if server is healthy."""
|
||||
try:
|
||||
result = self._request("GET", "/health")
|
||||
return result.get("status") == "ok"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def submit(self, dag: DAG) -> str:
|
||||
"""
|
||||
Submit a DAG for execution.
|
||||
|
||||
Args:
|
||||
dag: The DAG to execute
|
||||
|
||||
Returns:
|
||||
Job ID for tracking
|
||||
"""
|
||||
result = self._request("POST", "/execute", dag.to_dict())
|
||||
return result["job_id"]
|
||||
|
||||
def status(self, job_id: str) -> str:
|
||||
"""
|
||||
Get job status.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from submit()
|
||||
|
||||
Returns:
|
||||
Status: "pending", "running", "completed", or "failed"
|
||||
"""
|
||||
result = self._request("GET", f"/status/{job_id}")
|
||||
return result["status"]
|
||||
|
||||
def result(self, job_id: str) -> Optional[ExecutionResult]:
|
||||
"""
|
||||
Get job result (non-blocking).
|
||||
|
||||
Args:
|
||||
job_id: Job ID from submit()
|
||||
|
||||
Returns:
|
||||
ExecutionResult if complete, None if still running
|
||||
"""
|
||||
data = self._request("GET", f"/result/{job_id}")
|
||||
|
||||
if not data.get("ready", False):
|
||||
return None
|
||||
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output_path=Path(data["output_path"]) if data.get("output_path") else None,
|
||||
error=data.get("error"),
|
||||
execution_time=data.get("execution_time", 0),
|
||||
nodes_executed=data.get("nodes_executed", 0),
|
||||
nodes_cached=data.get("nodes_cached", 0),
|
||||
)
|
||||
|
||||
def wait(self, job_id: str, poll_interval: float = 0.5) -> ExecutionResult:
|
||||
"""
|
||||
Wait for job completion and return result.
|
||||
|
||||
Args:
|
||||
job_id: Job ID from submit()
|
||||
poll_interval: Seconds between status checks
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
while True:
|
||||
result = self.result(job_id)
|
||||
if result is not None:
|
||||
return result
|
||||
time.sleep(poll_interval)
|
||||
|
||||
def execute(self, dag: DAG, poll_interval: float = 0.5) -> ExecutionResult:
|
||||
"""
|
||||
Submit DAG and wait for result.
|
||||
|
||||
Convenience method combining submit() and wait().
|
||||
|
||||
Args:
|
||||
dag: The DAG to execute
|
||||
poll_interval: Seconds between status checks
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
job_id = self.submit(dag)
|
||||
return self.wait(job_id, poll_interval)
|
||||
|
||||
def cache_stats(self) -> CacheStats:
|
||||
"""Get cache statistics."""
|
||||
data = self._request("GET", "/cache/stats")
|
||||
return CacheStats(
|
||||
total_entries=data.get("total_entries", 0),
|
||||
total_size_bytes=data.get("total_size_bytes", 0),
|
||||
hits=data.get("hits", 0),
|
||||
misses=data.get("misses", 0),
|
||||
hit_rate=data.get("hit_rate", 0.0),
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the server cache."""
|
||||
self._request("DELETE", "/cache")
|
||||
|
||||
|
||||
# Re-export DAGBuilder so callers can build DAGs from this module
# without importing .dag directly.
__all__ = ["PrimitiveClient", "ExecutionResult", "CacheStats", "DAGBuilder"]
|
||||
344
core/artdag/dag.py
Normal file
344
core/artdag/dag.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# primitive/dag.py
|
||||
"""
|
||||
Core DAG data structures.
|
||||
|
||||
Nodes are content-addressed: node_id = hash(type + config + input_ids)
|
||||
This enables automatic caching and deduplication.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class NodeType(Enum):
    """Built-in node types.

    Declaration order matters: auto() derives each member's integer value
    from its position, so members must not be reordered.
    """

    # -- Source operations --
    SOURCE = auto()     # Load file from path

    # -- Transform operations --
    SEGMENT = auto()    # Extract time range
    RESIZE = auto()     # Scale/crop/pad
    TRANSFORM = auto()  # Visual effects (color, blur, etc.)

    # -- Compose operations --
    SEQUENCE = auto()   # Concatenate in time
    LAYER = auto()      # Stack spatially (overlay)
    MUX = auto()        # Combine video + audio streams
    BLEND = auto()      # Blend two inputs
    AUDIO_MIX = auto()  # Mix multiple audio streams
    SWITCH = auto()     # Time-based input switching

    # -- Analysis operations --
    ANALYZE = auto()    # Extract features (audio, motion, etc.)

    # -- Generation operations --
    GENERATE = auto()   # Create content (text, graphics, etc.)
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""
|
||||
Create stable hash from arbitrary data.
|
||||
|
||||
Uses SHA-3 (Keccak) for quantum resistance.
|
||||
Returns full hash - no truncation.
|
||||
|
||||
Args:
|
||||
data: Data to hash (will be JSON serialized)
|
||||
algorithm: Hash algorithm (default: sha3_256)
|
||||
|
||||
Returns:
|
||||
Full hex digest
|
||||
"""
|
||||
# Convert to JSON with sorted keys for stability
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class Node:
    """
    A node in the execution DAG.

    Attributes:
        node_type: The operation type (NodeType enum or string for custom types)
        config: Operation-specific configuration
        inputs: List of input node IDs (resolved during execution)
        node_id: Content-addressed ID (computed from type + config + inputs)
        name: Optional human-readable name for debugging
    """
    node_type: NodeType | str
    config: Dict[str, Any] = field(default_factory=dict)
    inputs: List[str] = field(default_factory=list)
    node_id: Optional[str] = None
    name: Optional[str] = None

    def __post_init__(self):
        """Compute node_id if not provided (content addressing)."""
        if self.node_id is None:
            self.node_id = self._compute_id()

    def _compute_id(self) -> str:
        """Compute content-addressed ID from node contents.

        Hashes the type name, config, and sorted input IDs so that
        structurally identical nodes get the same ID and dedupe.
        """
        # Enum types hash by member name; custom types as their string form.
        type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type)
        content = {
            "type": type_str,
            "config": self.config,
            # Sort for stability.
            # NOTE(review): sorting discards input ORDER, so order-sensitive
            # nodes (e.g. SEQUENCE, BLEND, MUX) whose inputs are permutations
            # of each other hash to the same ID and would dedupe as if
            # identical — confirm this is intended.
            "inputs": sorted(self.inputs),
        }
        return _stable_hash(content)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize node to dictionary (enum types stored by name)."""
        type_str = self.node_type.name if isinstance(self.node_type, NodeType) else str(self.node_type)
        return {
            "node_id": self.node_id,
            "node_type": type_str,
            "config": self.config,
            "inputs": self.inputs,
            "name": self.name,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Node":
        """Deserialize node from dictionary.

        Unknown type names are preserved as strings (custom node types).
        """
        type_str = data["node_type"]
        try:
            node_type = NodeType[type_str]
        except KeyError:
            node_type = type_str  # Custom type as string

        return cls(
            node_type=node_type,
            config=data.get("config", {}),
            inputs=data.get("inputs", []),
            node_id=data.get("node_id"),
            name=data.get("name"),
        )
|
||||
|
||||
|
||||
@dataclass
class DAG:
    """
    A directed acyclic graph of nodes.

    Attributes:
        nodes: Dictionary mapping node_id -> Node
        output_id: The ID of the final output node
        metadata: Optional metadata about the DAG (source, version, etc.)
    """
    nodes: Dict[str, Node] = field(default_factory=dict)
    output_id: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_node(self, node: Node) -> str:
        """Add a node to the DAG, returning its ID.

        Adding a node whose content-addressed ID already exists is a no-op
        (deduplication); the existing node is kept.
        """
        if node.node_id in self.nodes:
            return node.node_id
        self.nodes[node.node_id] = node
        return node.node_id

    def set_output(self, node_id: str) -> None:
        """Set the output node.

        Raises:
            ValueError: If node_id is not in the DAG.
        """
        if node_id not in self.nodes:
            raise ValueError(f"Node {node_id} not in DAG")
        self.output_id = node_id

    def get_node(self, node_id: str) -> Node:
        """Get a node by ID.

        Raises:
            KeyError: If the node is not present.
        """
        if node_id not in self.nodes:
            raise KeyError(f"Node {node_id} not found")
        return self.nodes[node_id]

    def topological_order(self) -> List[str]:
        """Return node IDs in topological order (dependencies first).

        Iterative depth-first search: immune to Python's recursion limit on
        deep DAGs, and detects cycles explicitly (the previous recursive
        version returned silently on cyclic graphs because already-visited
        nodes short-circuited before a back-edge could be noticed).

        Raises:
            ValueError: If the graph contains a cycle.
            KeyError: If a node references an input not in the DAG.
        """
        order: List[str] = []
        # Marking: 1 = on the current DFS path, 2 = fully processed.
        state: Dict[str, int] = {}

        # Visit all nodes, not just the output, to cover disconnected parts.
        for root in self.nodes:
            if state.get(root) == 2:
                continue
            state[root] = 1
            stack = [(root, iter(self.nodes[root].inputs))]
            while stack:
                node_id, pending = stack[-1]
                advanced = False
                for input_id in pending:
                    mark = state.get(input_id)
                    if mark == 1:
                        # Back-edge onto the active path: cycle.
                        raise ValueError("DAG contains a cycle")
                    if mark == 2:
                        continue
                    state[input_id] = 1
                    stack.append((input_id, iter(self.nodes[input_id].inputs)))
                    advanced = True
                    break
                if not advanced:
                    # All inputs done: emit in post-order (dependencies first).
                    stack.pop()
                    state[node_id] = 2
                    order.append(node_id)

        return order

    def validate(self) -> List[str]:
        """Validate DAG structure. Returns list of errors (empty if valid)."""
        errors = []

        if self.output_id is None:
            errors.append("No output node set")
        elif self.output_id not in self.nodes:
            errors.append(f"Output node {self.output_id} not in DAG")

        # Check all input references are valid
        for node_id, node in self.nodes.items():
            for input_id in node.inputs:
                if input_id not in self.nodes:
                    errors.append(f"Node {node_id} references missing input {input_id}")

        # Check for cycles (skip if we already found missing inputs)
        if not any("missing" in e for e in errors):
            try:
                self.topological_order()
            except (ValueError, RecursionError, KeyError):
                errors.append("DAG contains cycles or invalid references")

        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Serialize DAG to dictionary."""
        return {
            "nodes": {nid: node.to_dict() for nid, node in self.nodes.items()},
            "output_id": self.output_id,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DAG":
        """Deserialize DAG from dictionary."""
        dag = cls(metadata=data.get("metadata", {}))
        for node_data in data.get("nodes", {}).values():
            dag.add_node(Node.from_dict(node_data))
        # Set directly (not via set_output) so a dangling output_id surfaces
        # through validate() rather than failing deserialization.
        dag.output_id = data.get("output_id")
        return dag

    def to_json(self) -> str:
        """Serialize DAG to JSON string."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> "DAG":
        """Deserialize DAG from JSON string."""
        return cls.from_dict(json.loads(json_str))
|
||||
|
||||
|
||||
class DAGBuilder:
    """
    Fluent builder for constructing DAGs.

    Example:
        builder = DAGBuilder()
        source = builder.source("/path/to/video.mp4")
        segment = builder.segment(source, duration=5.0)
        builder.set_output(segment)
        dag = builder.build()
    """

    def __init__(self):
        self.dag = DAG()

    def _add(self, node_type: NodeType | str, config: Dict[str, Any],
             inputs: List[str] = None, name: str = None) -> str:
        """Create a node, register it on the DAG, and return its content ID."""
        return self.dag.add_node(Node(
            node_type=node_type,
            config=config,
            inputs=inputs or [],
            name=name,
        ))

    # --- Source operations ---

    def source(self, path: str, name: str = None) -> str:
        """Add a SOURCE node that loads the file at *path*."""
        return self._add(NodeType.SOURCE, {"path": path}, name=name)

    # --- Transform operations ---

    def segment(self, input_id: str, duration: float = None,
                offset: float = 0, precise: bool = True, name: str = None) -> str:
        """Add a SEGMENT node extracting a time range from *input_id*."""
        cfg: Dict[str, Any] = {"offset": offset, "precise": precise}
        if duration is not None:
            cfg["duration"] = duration
        return self._add(NodeType.SEGMENT, cfg, [input_id], name=name)

    def resize(self, input_id: str, width: int, height: int,
               mode: str = "fit", name: str = None) -> str:
        """Add a RESIZE node (scale/crop/pad to width x height)."""
        cfg = {"width": width, "height": height, "mode": mode}
        return self._add(NodeType.RESIZE, cfg, [input_id], name=name)

    def transform(self, input_id: str, effects: Dict[str, Any],
                  name: str = None) -> str:
        """Add a TRANSFORM node applying visual *effects*."""
        return self._add(NodeType.TRANSFORM, {"effects": effects}, [input_id], name=name)

    # --- Compose operations ---

    def sequence(self, input_ids: List[str], transition: Dict[str, Any] = None,
                 name: str = None) -> str:
        """Add a SEQUENCE node concatenating inputs in time (default: hard cut)."""
        cfg = {"transition": transition or {"type": "cut"}}
        return self._add(NodeType.SEQUENCE, cfg, input_ids, name=name)

    def layer(self, input_ids: List[str], configs: List[Dict] = None,
              name: str = None) -> str:
        """Add a LAYER node stacking inputs spatially (overlay)."""
        cfg = {"inputs": configs or [{}] * len(input_ids)}
        return self._add(NodeType.LAYER, cfg, input_ids, name=name)

    def mux(self, video_id: str, audio_id: str, shortest: bool = True,
            name: str = None) -> str:
        """Add a MUX node combining a video stream with an audio stream."""
        cfg = {"video_stream": 0, "audio_stream": 1, "shortest": shortest}
        return self._add(NodeType.MUX, cfg, [video_id, audio_id], name=name)

    def blend(self, input1_id: str, input2_id: str, mode: str = "overlay",
              opacity: float = 0.5, name: str = None) -> str:
        """Add a BLEND node mixing two inputs."""
        cfg = {"mode": mode, "opacity": opacity}
        return self._add(NodeType.BLEND, cfg, [input1_id, input2_id], name=name)

    def audio_mix(self, input_ids: List[str], gains: List[float] = None,
                  normalize: bool = True, name: str = None) -> str:
        """Add an AUDIO_MIX node to mix multiple audio streams."""
        cfg: Dict[str, Any] = {"normalize": normalize}
        if gains is not None:
            cfg["gains"] = gains
        return self._add(NodeType.AUDIO_MIX, cfg, input_ids, name=name)

    # --- Output ---

    def set_output(self, node_id: str) -> "DAGBuilder":
        """Mark *node_id* as the DAG output; returns self for chaining."""
        self.dag.set_output(node_id)
        return self

    def build(self) -> DAG:
        """Validate and return the built DAG.

        Raises:
            ValueError: If validation reports any errors.
        """
        errors = self.dag.validate()
        if errors:
            raise ValueError(f"Invalid DAG: {errors}")
        return self.dag
|
||||
55
core/artdag/effects/__init__.py
Normal file
55
core/artdag/effects/__init__.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Cacheable effect system.
|
||||
|
||||
Effects are single Python files with:
|
||||
- PEP 723 embedded dependencies
|
||||
- @-tag metadata in docstrings
|
||||
- Frame-by-frame or whole-video API
|
||||
|
||||
Effects are cached by content hash (SHA3-256) and executed in
|
||||
sandboxed environments for determinism.
|
||||
"""
|
||||
|
||||
from .meta import EffectMeta, ParamSpec, ExecutionContext
|
||||
from .loader import load_effect, load_effect_file, LoadedEffect, compute_cid
|
||||
from .binding import (
|
||||
AnalysisData,
|
||||
ResolvedBinding,
|
||||
resolve_binding,
|
||||
resolve_all_bindings,
|
||||
bindings_to_lookup_table,
|
||||
has_bindings,
|
||||
extract_binding_sources,
|
||||
)
|
||||
from .sandbox import Sandbox, SandboxConfig, SandboxResult, is_bwrap_available, get_venv_path
|
||||
from .runner import run_effect, run_effect_from_cache, EffectExecutor
|
||||
|
||||
__all__ = [
|
||||
# Meta types
|
||||
"EffectMeta",
|
||||
"ParamSpec",
|
||||
"ExecutionContext",
|
||||
# Loader
|
||||
"load_effect",
|
||||
"load_effect_file",
|
||||
"LoadedEffect",
|
||||
"compute_cid",
|
||||
# Binding
|
||||
"AnalysisData",
|
||||
"ResolvedBinding",
|
||||
"resolve_binding",
|
||||
"resolve_all_bindings",
|
||||
"bindings_to_lookup_table",
|
||||
"has_bindings",
|
||||
"extract_binding_sources",
|
||||
# Sandbox
|
||||
"Sandbox",
|
||||
"SandboxConfig",
|
||||
"SandboxResult",
|
||||
"is_bwrap_available",
|
||||
"get_venv_path",
|
||||
# Runner
|
||||
"run_effect",
|
||||
"run_effect_from_cache",
|
||||
"EffectExecutor",
|
||||
]
|
||||
311
core/artdag/effects/binding.py
Normal file
311
core/artdag/effects/binding.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
Parameter binding resolution.
|
||||
|
||||
Resolves bind expressions to per-frame lookup tables at plan time.
|
||||
Binding options:
|
||||
- :range [lo hi] - map 0-1 to output range
|
||||
- :smooth N - smoothing window in seconds
|
||||
- :offset N - time offset in seconds
|
||||
- :on-event V - value on discrete events
|
||||
- :decay N - exponential decay after event
|
||||
- :noise N - add deterministic noise (seeded)
|
||||
- :seed N - explicit RNG seed
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import math
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
class AnalysisData:
    """
    Analysis data for binding resolution.

    Attributes:
        frame_rate: Video frame rate
        total_frames: Total number of frames
        features: Dict mapping feature name to per-frame values
        events: Dict mapping event name to list of frame indices
    """

    frame_rate: float
    total_frames: int
    features: Dict[str, List[float]]  # feature -> [value_per_frame]
    events: Dict[str, List[int]]  # event -> [frame_indices]

    def get_feature(self, name: str, frame: int) -> float:
        """
        Get feature value at frame.

        Unknown features and empty series return 0.0; frames past the end
        of the series clamp to the last recorded value. (The old docstring
        claimed interpolation — none is performed.)

        Note: a negative frame falls through to Python's negative list
        indexing and reads from the end of the series — assumes callers
        pass non-negative frames; TODO confirm.
        """
        if name not in self.features:
            return 0.0
        values = self.features[name]
        if not values:
            return 0.0
        if frame >= len(values):
            return values[-1]
        return values[frame]

    def get_events_in_range(
        self, name: str, start_frame: int, end_frame: int
    ) -> List[int]:
        """Return event frames f with start_frame <= f < end_frame (end exclusive)."""
        if name not in self.events:
            return []
        return [f for f in self.events[name] if start_frame <= f < end_frame]
|
||||
|
||||
|
||||
@dataclass
class ResolvedBinding:
    """
    A binding resolved to concrete per-frame values.

    Attributes:
        param_name: Parameter this binding applies to
        values: One value per frame
    """

    param_name: str
    values: List[float]

    def get(self, frame: int) -> float:
        """Return the value at *frame*; past-end frames clamp to the last value."""
        if frame < len(self.values):
            return self.values[frame]
        return self.values[-1] if self.values else 0.0
|
||||
|
||||
|
||||
def resolve_binding(
    binding: Dict[str, Any],
    analysis: AnalysisData,
    param_name: str,
    cache_id: str = None,
) -> ResolvedBinding:
    """
    Resolve a binding specification to per-frame values.

    Base values come from either a discrete event stream or a continuous
    feature, then the optional modifiers are applied in order:
    offset -> smooth -> range -> noise.

    Args:
        binding: Binding spec with source, feature, and options
        analysis: Analysis data with features and events
        param_name: Name of the parameter being bound
        cache_id: Cache ID for deterministic seeding

    Returns:
        ResolvedBinding with values for each frame

    Raises:
        ValueError: If the binding has no "feature" key.
    """
    feature = binding.get("feature")
    if not feature:
        raise ValueError(f"Binding for {param_name} missing feature")

    if feature in analysis.events:
        # Discrete events: impulse train with optional exponential decay.
        values = _resolve_event_binding(
            analysis.events.get(feature, []),
            analysis.total_frames,
            analysis.frame_rate,
            binding.get("on_event", 1.0),
            binding.get("decay", 0.0),
        )
    else:
        # Continuous feature: pad with the last sample out to total_frames;
        # an unknown/empty feature yields all zeros.
        values = list(analysis.features.get(feature, []))
        if not values:
            values = [0.0] * analysis.total_frames
        else:
            while len(values) < analysis.total_frames:
                values.append(values[-1])

    offset = binding.get("offset")
    if offset:
        values = _apply_offset(values, int(offset * analysis.frame_rate))

    smooth = binding.get("smooth")
    if smooth:
        values = _apply_smoothing(values, int(smooth * analysis.frame_rate))

    range_spec = binding.get("range")
    if range_spec:
        lo, hi = range_spec
        values = _apply_range(values, lo, hi)

    noise = binding.get("noise")
    if noise:
        seed = binding.get("seed")
        if seed is None and cache_id:
            # Derive the seed from cache_id so noise stays deterministic.
            seed = int(hashlib.sha256(cache_id.encode()).hexdigest()[:8], 16)
        values = _apply_noise(values, noise, seed or 0)

    return ResolvedBinding(param_name=param_name, values=values)
|
||||
|
||||
|
||||
def _resolve_event_binding(
|
||||
event_frames: List[int],
|
||||
total_frames: int,
|
||||
frame_rate: float,
|
||||
on_event: float,
|
||||
decay: float,
|
||||
) -> List[float]:
|
||||
"""
|
||||
Resolve event-based binding with optional decay.
|
||||
|
||||
Args:
|
||||
event_frames: List of frame indices where events occur
|
||||
total_frames: Total number of frames
|
||||
frame_rate: Video frame rate
|
||||
on_event: Value at event
|
||||
decay: Decay time constant in seconds (0 = instant)
|
||||
|
||||
Returns:
|
||||
List of values per frame
|
||||
"""
|
||||
values = [0.0] * total_frames
|
||||
|
||||
if not event_frames:
|
||||
return values
|
||||
|
||||
event_set = set(event_frames)
|
||||
|
||||
if decay <= 0:
|
||||
# No decay - just mark event frames
|
||||
for f in event_frames:
|
||||
if 0 <= f < total_frames:
|
||||
values[f] = on_event
|
||||
else:
|
||||
# Apply exponential decay
|
||||
decay_frames = decay * frame_rate
|
||||
for f in event_frames:
|
||||
if f < 0 or f >= total_frames:
|
||||
continue
|
||||
# Apply decay from this event forward
|
||||
for i in range(f, total_frames):
|
||||
elapsed = i - f
|
||||
decayed = on_event * math.exp(-elapsed / decay_frames)
|
||||
if decayed < 0.001:
|
||||
break
|
||||
values[i] = max(values[i], decayed)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
def _apply_offset(values: List[float], offset_frames: int) -> List[float]:
|
||||
"""Shift values by offset frames (positive = delay)."""
|
||||
if offset_frames == 0:
|
||||
return values
|
||||
|
||||
n = len(values)
|
||||
result = [0.0] * n
|
||||
|
||||
for i in range(n):
|
||||
src = i - offset_frames
|
||||
if 0 <= src < n:
|
||||
result[i] = values[src]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _apply_smoothing(values: List[float], window_frames: int) -> List[float]:
|
||||
"""Apply moving average smoothing."""
|
||||
if window_frames <= 1:
|
||||
return values
|
||||
|
||||
n = len(values)
|
||||
result = []
|
||||
half = window_frames // 2
|
||||
|
||||
for i in range(n):
|
||||
start = max(0, i - half)
|
||||
end = min(n, i + half + 1)
|
||||
avg = sum(values[start:end]) / (end - start)
|
||||
result.append(avg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _apply_range(values: List[float], lo: float, hi: float) -> List[float]:
|
||||
"""Map values from 0-1 to lo-hi range."""
|
||||
return [lo + v * (hi - lo) for v in values]
|
||||
|
||||
|
||||
def _apply_noise(values: List[float], amount: float, seed: int) -> List[float]:
|
||||
"""Add deterministic noise to values."""
|
||||
rng = random.Random(seed)
|
||||
return [v + rng.uniform(-amount, amount) for v in values]
|
||||
|
||||
|
||||
def resolve_all_bindings(
    config: Dict[str, Any],
    analysis: AnalysisData,
    cache_id: str = None,
) -> Dict[str, ResolvedBinding]:
    """
    Resolve every binding found in a config dict.

    A config value is treated as a binding when it is a dict whose
    "_binding" key is truthy.

    Args:
        config: Node config with potential bindings
        analysis: Analysis data
        cache_id: Cache ID for seeding

    Returns:
        Dict mapping param name to resolved binding
    """
    return {
        key: resolve_binding(value, analysis, key, cache_id)
        for key, value in config.items()
        if isinstance(value, dict) and value.get("_binding")
    }
|
||||
|
||||
|
||||
def bindings_to_lookup_table(
    bindings: Dict[str, ResolvedBinding],
) -> Dict[str, List[float]]:
    """
    Convert resolved bindings to simple lookup tables.

    Returns a dict mapping param name to the list of per-frame values —
    a JSON-serializable form suitable for inclusion in execution plans.
    """
    table: Dict[str, List[float]] = {}
    for name, binding in bindings.items():
        table[name] = binding.values
    return table
|
||||
|
||||
|
||||
def has_bindings(config: Dict[str, Any]) -> bool:
    """Return True if any config value is a binding spec (dict with truthy "_binding")."""
    return any(
        isinstance(value, dict) and value.get("_binding")
        for value in config.values()
    )
|
||||
|
||||
|
||||
def extract_binding_sources(config: Dict[str, Any]) -> List[str]:
    """
    Extract all analysis source references from bindings in *config*.

    Returns the node IDs providing analysis data, deduplicated and in
    first-seen order.
    """
    seen: List[str] = []
    for value in config.values():
        if not (isinstance(value, dict) and value.get("_binding")):
            continue
        src = value.get("source")
        if src and src not in seen:
            seen.append(src)
    return seen
|
||||
347
core/artdag/effects/frame_processor.py
Normal file
347
core/artdag/effects/frame_processor.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""
|
||||
FFmpeg pipe-based frame processing.
|
||||
|
||||
Processes video through Python frame-by-frame effects using FFmpeg pipes:
|
||||
FFmpeg decode -> Python process_frame -> FFmpeg encode
|
||||
|
||||
This avoids writing intermediate frames to disk.
|
||||
"""
|
||||
|
||||
import logging
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class VideoInfo:
    """Video metadata as reported by ffprobe."""

    # Frame dimensions in pixels.
    width: int
    height: int
    # Playback rate in frames per second.
    frame_rate: float
    # Frame count and duration in seconds (may be derived from each other).
    total_frames: int
    duration: float
    # Raw pixel layout used when piping frames.
    pixel_format: str = "rgb24"
|
||||
|
||||
|
||||
def probe_video(path: Path) -> VideoInfo:
    """
    Get video information using ffprobe.

    Args:
        path: Path to video file

    Returns:
        VideoInfo with dimensions, frame rate, etc. When nb_frames is not
        reported (e.g. "N/A"), the frame count is estimated from
        duration * frame_rate.

    Raises:
        RuntimeError: If ffprobe fails or its output cannot be parsed.
    """
    proc = subprocess.run(
        [
            "ffprobe",
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=width,height,r_frame_rate,nb_frames,duration",
            "-of", "csv=p=0",
            str(path),
        ],
        capture_output=True,
        text=True,
    )
    if proc.returncode != 0:
        raise RuntimeError(f"ffprobe failed: {proc.stderr}")

    fields = proc.stdout.strip().split(",")
    if len(fields) < 4:
        raise RuntimeError(f"Unexpected ffprobe output: {proc.stdout}")

    width = int(fields[0])
    height = int(fields[1])

    # r_frame_rate may be rational ("30000/1001") or plain ("30").
    num, _, den = fields[2].partition("/")
    frame_rate = float(num) / float(den) if den else float(num)

    # nb_frames / duration may be "N/A" or absent; fall back to defaults.
    total_frames = 0
    duration = 0.0
    try:
        total_frames = int(fields[3])
    except (ValueError, IndexError):
        pass
    try:
        duration = float(fields[4]) if len(fields) > 4 else 0.0
    except (ValueError, IndexError):
        pass

    if total_frames == 0 and duration > 0:
        total_frames = int(duration * frame_rate)

    return VideoInfo(
        width=width,
        height=height,
        frame_rate=frame_rate,
        total_frames=total_frames,
        duration=duration,
    )
|
||||
|
||||
|
||||
# Per-frame effect signature: (frame, params, state) -> (processed_frame, new_state).
FrameProcessor = Callable[[np.ndarray, Dict[str, Any], Any], Tuple[np.ndarray, Any]]
|
||||
|
||||
|
||||
def process_video(
    input_path: Path,
    output_path: Path,
    process_frame: FrameProcessor,
    params: Dict[str, Any],
    bindings: Dict[str, List[float]] = None,
    initial_state: Any = None,
    pixel_format: str = "rgb24",
    output_codec: str = "libx264",
    output_options: List[str] = None,
) -> Tuple[Path, Any]:
    """
    Process video through a frame-by-frame effect via FFmpeg pipes.

    Args:
        input_path: Input video path
        output_path: Output video path
        process_frame: Function (frame, params, state) -> (frame, state)
        params: Static parameter dict
        bindings: Per-frame parameter lookup tables
        initial_state: Initial state for process_frame
        pixel_format: Pixel format for frame data
        output_codec: Video codec for output
        output_options: Additional ffmpeg output options

    Returns:
        Tuple of (output_path, final_state)

    Raises:
        ValueError: If process_frame returns a frame of a different shape.
        RuntimeError: If the encoder exits non-zero.
    """
    bindings = bindings or {}
    output_options = output_options or []

    # Probe input
    info = probe_video(input_path)
    logger.info(f"Processing {info.width}x{info.height} @ {info.frame_rate}fps")

    # Bytes per frame of raw pixel data.
    if pixel_format == "rgb24":
        bytes_per_pixel = 3
    elif pixel_format == "rgba":
        bytes_per_pixel = 4
    else:
        bytes_per_pixel = 3  # Default to RGB

    frame_size = info.width * info.height * bytes_per_pixel

    # Decoder: input video -> raw frames on stdout.
    decode_cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-",
    ]

    # Encoder: raw frames on stdin + original audio -> output file.
    encode_cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-s", f"{info.width}x{info.height}",
        "-r", str(info.frame_rate),
        "-i", "-",
        "-i", str(input_path),  # For audio
        "-map", "0:v",
        "-map", "1:a?",
        "-c:v", output_codec,
        "-c:a", "aac",
        *output_options,
        str(output_path),
    ]

    logger.debug(f"Decoder: {' '.join(decode_cmd)}")
    logger.debug(f"Encoder: {' '.join(encode_cmd)}")

    decoder = subprocess.Popen(
        decode_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )

    # Spool encoder stderr to an unnamed temp file rather than a PIPE:
    # an undrained PIPE can fill the OS buffer, blocking ffmpeg's stderr
    # writes, which stops it reading stdin and deadlocks our frame writes.
    stderr_spool = tempfile.TemporaryFile()
    encoder = subprocess.Popen(
        encode_cmd,
        stdin=subprocess.PIPE,
        stderr=stderr_spool,
    )

    state = initial_state
    frame_idx = 0

    try:
        while True:
            # Read one raw frame from the decoder; a short read means EOF.
            raw_frame = decoder.stdout.read(frame_size)
            if len(raw_frame) < frame_size:
                break

            # Convert to numpy (H, W, C) uint8.
            frame = np.frombuffer(raw_frame, dtype=np.uint8)
            frame = frame.reshape((info.height, info.width, bytes_per_pixel))

            # Build per-frame params: static params overridden by bound values.
            frame_params = dict(params)
            for param_name, values in bindings.items():
                if frame_idx < len(values):
                    frame_params[param_name] = values[frame_idx]

            # Process frame
            processed, state = process_frame(frame, frame_params, state)

            # Ensure correct shape and dtype before piping to the encoder.
            if processed.shape != frame.shape:
                raise ValueError(
                    f"Frame shape mismatch: {processed.shape} vs {frame.shape}"
                )
            processed = processed.astype(np.uint8)

            encoder.stdin.write(processed.tobytes())
            frame_idx += 1

            if frame_idx % 100 == 0:
                logger.debug(f"Processed frame {frame_idx}")

    except Exception as e:
        logger.error(f"Frame processing failed at frame {frame_idx}: {e}")
        raise
    finally:
        decoder.stdout.close()
        decoder.wait()
        encoder.stdin.close()
        encoder.wait()

    try:
        if encoder.returncode != 0:
            stderr_spool.seek(0)
            stderr = stderr_spool.read().decode(errors="replace")
            raise RuntimeError(f"Encoder failed: {stderr}")
    finally:
        stderr_spool.close()

    logger.info(f"Processed {frame_idx} frames")
    return output_path, state
|
||||
|
||||
|
||||
def process_video_batch(
    input_path: Path,
    output_path: Path,
    process_frames: Callable[[List[np.ndarray], Dict[str, Any]], List[np.ndarray]],
    params: Dict[str, Any],
    batch_size: int = 30,
    pixel_format: str = "rgb24",
    output_codec: str = "libx264",
) -> Path:
    """
    Process video in batches for effects that need temporal context.

    Streams raw frames out of an ffmpeg decoder, hands them to
    ``process_frames`` in groups of ``batch_size``, and pipes the results
    into an ffmpeg encoder. Audio (if present) is muxed back in from the
    original input file.

    Args:
        input_path: Input video path
        output_path: Output video path
        process_frames: Function (frames_batch, params) -> processed_batch
        params: Parameter dict
        batch_size: Number of frames per batch
        pixel_format: Pixel format ("rgb24" or "rgba"; anything else is
            treated as 3 bytes/pixel)
        output_codec: Output codec

    Returns:
        Output path

    Raises:
        RuntimeError: If the encoder process exits non-zero.
    """
    info = probe_video(input_path)

    # Bytes per pixel for the raw frame buffer; unknown formats fall back
    # to 3 (same assumption as "rgb24").
    if pixel_format == "rgb24":
        bytes_per_pixel = 3
    elif pixel_format == "rgba":
        bytes_per_pixel = 4
    else:
        bytes_per_pixel = 3

    frame_size = info.width * info.height * bytes_per_pixel

    # Decoder: emit raw frames on stdout.
    decode_cmd = [
        "ffmpeg",
        "-i", str(input_path),
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-",
    ]

    # Encoder: read raw frames on stdin (input 0), pull optional audio
    # from the original file (input 1, "1:a?" ignores it if absent).
    encode_cmd = [
        "ffmpeg",
        "-y",
        "-f", "rawvideo",
        "-pix_fmt", pixel_format,
        "-s", f"{info.width}x{info.height}",
        "-r", str(info.frame_rate),
        "-i", "-",
        "-i", str(input_path),
        "-map", "0:v",
        "-map", "1:a?",
        "-c:v", output_codec,
        "-c:a", "aac",
        str(output_path),
    ]

    decoder = subprocess.Popen(
        decode_cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )

    # NOTE(review): encoder stderr is a PIPE that is only drained after the
    # loop finishes; a very chatty encoder could fill the pipe and stall —
    # confirm this is acceptable for the codecs in use.
    encoder = subprocess.Popen(
        encode_cmd,
        stdin=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    batch = []
    total_processed = 0

    try:
        while True:
            # Buffered read: returns exactly frame_size bytes unless EOF.
            raw_frame = decoder.stdout.read(frame_size)
            if len(raw_frame) < frame_size:
                # EOF (or truncated tail) — flush the final partial batch.
                # Process remaining batch
                if batch:
                    processed = process_frames(batch, params)
                    for frame in processed:
                        encoder.stdin.write(frame.astype(np.uint8).tobytes())
                        total_processed += 1
                break

            frame = np.frombuffer(raw_frame, dtype=np.uint8)
            frame = frame.reshape((info.height, info.width, bytes_per_pixel))
            batch.append(frame)

            if len(batch) >= batch_size:
                processed = process_frames(batch, params)
                for frame in processed:
                    encoder.stdin.write(frame.astype(np.uint8).tobytes())
                    total_processed += 1
                batch = []

    finally:
        # Close the decoder first, then the encoder's stdin so it sees EOF
        # and finishes writing the container.
        decoder.stdout.close()
        decoder.wait()
        encoder.stdin.close()
        encoder.wait()

    if encoder.returncode != 0:
        stderr = encoder.stderr.read().decode() if encoder.stderr else ""
        raise RuntimeError(f"Encoder failed: {stderr}")

    logger.info(f"Processed {total_processed} frames in batches of {batch_size}")
    return output_path
|
||||
455
core/artdag/effects/loader.py
Normal file
455
core/artdag/effects/loader.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
Effect file loader.
|
||||
|
||||
Parses effect files with:
|
||||
- PEP 723 inline script metadata for dependencies
|
||||
- @-tag docstrings for effect metadata
|
||||
- META object for programmatic access
|
||||
"""
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from .meta import EffectMeta, ParamSpec
|
||||
|
||||
|
||||
@dataclass
class LoadedEffect:
    """
    An effect that has been parsed and hashed, ready for execution.

    Attributes:
        source: Original source code
        cid: SHA3-256 hash of source (content id)
        meta: Extracted EffectMeta
        dependencies: List of pip dependencies
        requires_python: Python version requirement
        module: Compiled module (if loaded)
    """

    source: str
    cid: str
    meta: EffectMeta
    dependencies: List[str] = field(default_factory=list)
    requires_python: str = ">=3.10"
    module: Any = None

    def has_frame_api(self) -> bool:
        """Return True when the effect uses the frame-by-frame API."""
        return "frame" == self.meta.api_type

    def has_video_api(self) -> bool:
        """Return True when the effect uses the whole-video API."""
        return "video" == self.meta.api_type
|
||||
|
||||
|
||||
def compute_cid(source: str) -> str:
    """Return the SHA3-256 content id (hex digest) of effect source."""
    digest = hashlib.sha3_256()
    digest.update(source.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def parse_pep723_metadata(source: str) -> Tuple[List[str], str]:
    """
    Parse PEP 723 inline script metadata.

    Looks for a block of the form:
        # /// script
        # requires-python = ">=3.10"
        # dependencies = ["numpy", "opencv-python"]
        # ///

    Returns:
        Tuple of (dependencies list, requires_python string); defaults to
        ([], ">=3.10") when the block or a field is absent.
    """
    deps: List[str] = []
    python_req = ">=3.10"

    # Locate the "# /// script" ... "# ///" block.
    block_match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL)
    if block_match is None:
        return deps, python_req

    body = block_match.group(1)

    # Dependencies: pull every double-quoted entry out of the list literal.
    deps_match = re.search(r'# dependencies = \[(.*?)\]', body, re.DOTALL)
    if deps_match is not None:
        deps = re.findall(r'"([^"]+)"', deps_match.group(1))

    # Python version requirement.
    req_match = re.search(r'# requires-python = "([^"]+)"', body)
    if req_match is not None:
        python_req = req_match.group(1)

    return deps, python_req
|
||||
|
||||
|
||||
def parse_docstring_metadata(docstring: str) -> Dict[str, Any]:
    """
    Parse @-tag metadata from docstring.

    Supports:
        @effect name
        @version 1.0.0
        @author @user@domain
        @temporal false
        @description
            Multi-line description text.

        @param name type
            @range lo hi
            @default value
            Description text.

        @example
            (fx effect :param value)

    Returns:
        Dictionary with extracted metadata ({} for an empty docstring).
    """
    if not docstring:
        return {}

    result = {
        "name": "",
        "version": "1.0.0",
        "author": "",
        "temporal": False,
        "description": "",
        "params": [],
        "examples": [],
    }

    lines = docstring.strip().split("\n")
    i = 0

    while i < len(lines):
        line = lines[i].strip()

        if line.startswith("@effect "):
            result["name"] = line[8:].strip()

        elif line.startswith("@version "):
            result["version"] = line[9:].strip()

        elif line.startswith("@author "):
            result["author"] = line[8:].strip()

        elif line.startswith("@temporal "):
            val = line[10:].strip().lower()
            result["temporal"] = val in ("true", "yes", "1")

        elif line.startswith("@description"):
            # Collect multi-line description until the next @-tag.
            desc_lines = []
            i += 1
            while i < len(lines):
                next_line = lines[i]
                if next_line.strip().startswith("@"):
                    i -= 1  # Back up to process this tag
                    break
                desc_lines.append(next_line)
                i += 1
            result["description"] = "\n".join(desc_lines).strip()

        elif line.startswith("@param "):
            # Parse parameter: @param name type
            parts = line[7:].split()
            if len(parts) >= 2:
                current_param = {
                    "name": parts[0],
                    "type": parts[1],
                    "range": None,
                    "default": None,
                    "description": "",
                }
                # Collect @range/@default/description lines for this param.
                desc_lines = []
                i += 1
                while i < len(lines):
                    stripped = lines[i].strip()

                    if stripped.startswith("@range "):
                        range_parts = stripped[7:].split()
                        if len(range_parts) >= 2:
                            try:
                                current_param["range"] = (
                                    float(range_parts[0]),
                                    float(range_parts[1]),
                                )
                            except ValueError:
                                pass  # malformed range: ignore, keep None

                    elif stripped.startswith("@default "):
                        current_param["default"] = stripped[9:].strip()

                    elif stripped.startswith("@"):
                        i -= 1  # Back up: any other tag ends this param
                        break

                    elif stripped:
                        desc_lines.append(stripped)

                    i += 1

                current_param["description"] = " ".join(desc_lines)
                result["params"].append(current_param)

        elif line.startswith("@example"):
            # Collect example body until the next @-tag.
            # Fix: the previous guard ("@" and not "@example") let any other
            # tag fall through and be appended INTO the example text, so a
            # tag placed after an @example block was silently swallowed.
            example_lines = []
            i += 1
            while i < len(lines):
                next_line = lines[i]
                if next_line.strip().startswith("@"):
                    i -= 1  # Back up to process this tag
                    break
                example_lines.append(next_line)
                i += 1
            example = "\n".join(example_lines).strip()
            if example:
                result["examples"].append(example)

        i += 1

    return result
|
||||
|
||||
|
||||
def extract_meta_from_ast(source: str) -> Optional[Dict[str, Any]]:
    """
    Find a module-level ``META = EffectMeta(...)`` assignment in *source*.

    Returns the call's keyword arguments as a dict, or None when the
    source does not parse, has no META assignment, or META is not
    assigned from a call expression.
    """
    try:
        module = ast.parse(source)
    except SyntaxError:
        return None

    for stmt in ast.walk(module):
        if not isinstance(stmt, ast.Assign):
            continue
        assigns_meta = any(
            isinstance(target, ast.Name) and target.id == "META"
            for target in stmt.targets
        )
        if assigns_meta and isinstance(stmt.value, ast.Call):
            return _extract_call_kwargs(stmt.value)

    return None
|
||||
|
||||
|
||||
def _extract_call_kwargs(call: ast.Call) -> Dict[str, Any]:
    """Collect the literal keyword arguments of *call*.

    ``**kwargs`` splats (no name) and values that cannot be converted to
    a plain Python value are skipped.
    """
    kwargs: Dict[str, Any] = {}

    for kw in call.keywords:
        if kw.arg is None:
            continue  # **splat — nothing to record
        converted = _ast_to_value(kw.value)
        if converted is not None:
            kwargs[kw.arg] = converted

    return kwargs
|
||||
|
||||
|
||||
def _ast_to_value(node: ast.expr) -> Any:
    """Convert an AST expression node to a plain Python value.

    Supports constants, lists, tuples, dicts, and nested ``ParamSpec(...)``
    calls (returned as their kwargs dict). Returns None for anything else
    — callers treat None as "skip this value".

    Note: the old ``ast.Str`` / ``ast.Num`` / ``ast.NameConstant`` compat
    branches were removed: ``ast.parse`` has produced only ``ast.Constant``
    since Python 3.8, and those aliases no longer exist in Python 3.12+
    (touching them raised AttributeError there).
    """
    if isinstance(node, ast.Constant):
        return node.value
    elif isinstance(node, ast.List):
        return [_ast_to_value(elt) for elt in node.elts]
    elif isinstance(node, ast.Tuple):
        return tuple(_ast_to_value(elt) for elt in node.elts)
    elif isinstance(node, ast.Dict):
        return {
            _ast_to_value(k): _ast_to_value(v)
            for k, v in zip(node.keys, node.values)
            if k is not None  # skip **splat entries
        }
    elif isinstance(node, ast.Call):
        # Handle nested ParamSpec(...) calls
        if isinstance(node.func, ast.Name) and node.func.id == "ParamSpec":
            return _extract_call_kwargs(node)
    return None
|
||||
|
||||
|
||||
def get_module_docstring(source: str) -> str:
    """Extract the module-level docstring from source.

    Returns "" when the source does not parse or has no docstring.
    """
    try:
        tree = ast.parse(source)
    except SyntaxError:
        return ""

    # ast.get_docstring only returns actual string docstrings; the old
    # manual Constant check could return a non-str value (e.g. a leading
    # bare number literal), and its ast.Str compat branch no longer exists
    # on Python 3.12+. clean=False preserves the raw text for the @-tag
    # parser downstream.
    return ast.get_docstring(tree, clean=False) or ""
|
||||
|
||||
|
||||
def load_effect(source: str) -> LoadedEffect:
    """
    Load an effect from source code.

    Parses:
      1. PEP 723 metadata for dependencies
      2. Module docstring for @-tag metadata
      3. META object for programmatic metadata

    Priority: META object > docstring > defaults

    Args:
        source: Effect source code

    Returns:
        LoadedEffect with all metadata

    Raises:
        ValueError: If effect is invalid (has no name)
    """
    cid = compute_cid(source)

    # Parse PEP 723 metadata
    dependencies, requires_python = parse_pep723_metadata(source)

    # Parse docstring metadata
    docstring = get_module_docstring(source)
    doc_meta = parse_docstring_metadata(docstring)

    # Try to extract META from AST
    ast_meta = extract_meta_from_ast(source)

    def _pick(key: str, default: Any) -> Any:
        # META wins only when it actually provides a value; otherwise fall
        # back to the docstring, then the default. (Previously a META
        # object that merely omitted a field discarded the docstring value,
        # breaking the documented META > docstring > defaults priority.)
        if ast_meta is not None and ast_meta.get(key) is not None:
            return ast_meta[key]
        value = doc_meta.get(key)
        return default if value is None else value

    # Name is mandatory; empty strings count as "not provided".
    name = ""
    if ast_meta and ast_meta.get("name"):
        name = ast_meta["name"]
    elif doc_meta.get("name"):
        name = doc_meta["name"]

    if not name:
        raise ValueError("Effect must have a name (@effect or META.name)")

    version = _pick("version", "1.0.0")
    temporal = _pick("temporal", False)
    author = _pick("author", "")
    description = _pick("description", "")
    examples = _pick("examples", [])

    # Build params, preferring META's list over docstring @param tags.
    params: List[ParamSpec] = []
    type_map = {"float": float, "int": int, "bool": bool, "str": str}
    if ast_meta and "params" in ast_meta:
        for p in ast_meta["params"]:
            if isinstance(p, dict):
                # param_type may arrive as a name ("float") or a real type.
                param_type = type_map.get(p.get("param_type", "float"), float)
                if isinstance(p.get("param_type"), type):
                    param_type = p["param_type"]
                params.append(
                    ParamSpec(
                        name=p.get("name", ""),
                        param_type=param_type,
                        default=p.get("default"),
                        range=p.get("range"),
                        description=p.get("description", ""),
                    )
                )
    elif doc_meta.get("params"):
        for p in doc_meta["params"]:
            param_type = type_map.get(p.get("type", "float"), float)

            # Docstring defaults are strings; coerce where possible.
            default = p.get("default")
            if default is not None:
                try:
                    default = param_type(default)
                except (ValueError, TypeError):
                    pass

            params.append(
                ParamSpec(
                    name=p["name"],
                    param_type=param_type,
                    default=default,
                    range=p.get("range"),
                    description=p.get("description", ""),
                )
            )

    # Determine API type from which entry point the source defines;
    # first match in AST walk order wins.
    api_type = "frame"  # default
    try:
        tree = ast.parse(source)
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                if node.name == "process":
                    api_type = "video"
                    break
                elif node.name == "process_frame":
                    api_type = "frame"
                    break
    except SyntaxError:
        pass

    # Guard against META supplying wrong-typed values.
    meta = EffectMeta(
        name=name,
        version=version if isinstance(version, str) else "1.0.0",
        temporal=bool(temporal),
        params=params,
        author=author if isinstance(author, str) else "",
        description=description if isinstance(description, str) else "",
        examples=examples if isinstance(examples, list) else [],
        dependencies=dependencies,
        requires_python=requires_python,
        api_type=api_type,
    )

    return LoadedEffect(
        source=source,
        cid=cid,
        meta=meta,
        dependencies=dependencies,
        requires_python=requires_python,
    )
|
||||
|
||||
|
||||
def load_effect_file(path: Path) -> LoadedEffect:
    """Read *path* as UTF-8 and load it as an effect."""
    return load_effect(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def compute_deps_hash(dependencies: List[str]) -> str:
    """
    Return a SHA3-256 hash over the normalized dependency list.

    Dependencies are lower-cased, stripped, and sorted before hashing, so
    the same requirement set always maps to the same hash — used for venv
    caching (same deps = same hash = reuse venv).
    """
    normalized = sorted(dep.lower().strip() for dep in dependencies)
    return hashlib.sha3_256("\n".join(normalized).encode("utf-8")).hexdigest()
|
||||
247
core/artdag/effects/meta.py
Normal file
247
core/artdag/effects/meta.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Effect metadata types.
|
||||
|
||||
Defines the core dataclasses for effect metadata:
|
||||
- ParamSpec: Parameter specification with type, range, and default
|
||||
- EffectMeta: Complete effect metadata including params and flags
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
|
||||
@dataclass
class ParamSpec:
    """
    Specification for an effect parameter.

    Attributes:
        name: Parameter name (used in recipes as :name)
        param_type: Python type (float, int, bool, str)
        default: Default value if not specified
        range: Optional (min, max) tuple for numeric types
        description: Human-readable description
        choices: Optional list of allowed values (for enums)
    """

    name: str
    param_type: Type
    default: Any = None
    range: Optional[Tuple[float, float]] = None
    description: str = ""
    choices: Optional[List[Any]] = None

    def validate(self, value: Any) -> Any:
        """
        Validate and coerce a parameter value.

        None falls back to the default (if any) without further checks.

        Args:
            value: Input value to validate

        Returns:
            Validated and coerced value

        Raises:
            ValueError: If value is missing, uncoercible, out of range,
                or not among the allowed choices
        """
        if value is None:
            if self.default is not None:
                return self.default
            raise ValueError(f"Parameter '{self.name}' requires a value")

        # Type coercion; bool gets special string handling ("true"/"1"/"yes")
        # since bool("false") would be truthy.
        try:
            if self.param_type == bool:
                if isinstance(value, str):
                    value = value.lower() in ("true", "1", "yes")
                else:
                    value = bool(value)
            elif self.param_type == int:
                value = int(value)
            elif self.param_type == float:
                value = float(value)
            elif self.param_type == str:
                value = str(value)
            else:
                value = self.param_type(value)
        except (ValueError, TypeError) as e:
            # Chain the original coercion error so its cause is preserved.
            raise ValueError(
                f"Parameter '{self.name}' expects {self.param_type.__name__}, "
                f"got {type(value).__name__}: {e}"
            ) from e

        # Range check applies only to numeric types.
        if self.range is not None and self.param_type in (int, float):
            min_val, max_val = self.range
            if value < min_val or value > max_val:
                raise ValueError(
                    f"Parameter '{self.name}' must be in range "
                    f"[{min_val}, {max_val}], got {value}"
                )

        # Choices check (enum-style parameters).
        if self.choices is not None and value not in self.choices:
            raise ValueError(
                f"Parameter '{self.name}' must be one of {self.choices}, got {value}"
            )

        return value

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization.

        Optional fields (default/range/choices) are omitted when unset.
        """
        d = {
            "name": self.name,
            "type": self.param_type.__name__,
            "description": self.description,
        }
        if self.default is not None:
            d["default"] = self.default
        if self.range is not None:
            d["range"] = list(self.range)
        if self.choices is not None:
            d["choices"] = self.choices
        return d
|
||||
|
||||
|
||||
@dataclass
class EffectMeta:
    """
    Complete metadata for an effect.

    Attributes:
        name: Effect name (used in recipes)
        version: Semantic version string
        temporal: If True, effect needs complete input (can't be collapsed)
        params: List of parameter specifications
        author: Optional author identifier
        description: Human-readable description
        examples: List of example S-expression usages
        dependencies: List of Python package dependencies
        requires_python: Minimum Python version
        api_type: "frame" for frame-by-frame, "video" for whole-video
    """

    name: str
    version: str = "1.0.0"
    temporal: bool = False
    params: List[ParamSpec] = field(default_factory=list)
    author: str = ""
    description: str = ""
    examples: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    requires_python: str = ">=3.10"
    api_type: str = "frame"  # "frame" or "video"

    def get_param(self, name: str) -> Optional[ParamSpec]:
        """Return the ParamSpec called *name*, or None when absent."""
        return next((spec for spec in self.params if spec.name == name), None)

    def validate_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Validate all parameters against their specs.

        Args:
            params: Dictionary of parameter values

        Returns:
            Dictionary with validated/coerced values including defaults

        Raises:
            ValueError: If any parameter is invalid
        """
        return {
            spec.name: spec.validate(params.get(spec.name))
            for spec in self.params
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for serialization."""
        return {
            "name": self.name,
            "version": self.version,
            "temporal": self.temporal,
            "params": [spec.to_dict() for spec in self.params],
            "author": self.author,
            "description": self.description,
            "examples": self.examples,
            "dependencies": self.dependencies,
            "requires_python": self.requires_python,
            "api_type": self.api_type,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EffectMeta":
        """Create an EffectMeta from a dictionary (inverse of to_dict)."""
        # Map serialized type names back to Python types.
        type_map = {"float": float, "int": int, "bool": bool, "str": str}

        specs = []
        for raw in data.get("params", []):
            rng = raw.get("range")
            specs.append(
                ParamSpec(
                    name=raw["name"],
                    param_type=type_map.get(raw.get("type", "float"), float),
                    default=raw.get("default"),
                    range=tuple(rng) if rng else None,
                    description=raw.get("description", ""),
                    choices=raw.get("choices"),
                )
            )

        return cls(
            name=data["name"],
            version=data.get("version", "1.0.0"),
            temporal=data.get("temporal", False),
            params=specs,
            author=data.get("author", ""),
            description=data.get("description", ""),
            examples=data.get("examples", []),
            dependencies=data.get("dependencies", []),
            requires_python=data.get("requires_python", ">=3.10"),
            api_type=data.get("api_type", "frame"),
        )
|
||||
|
||||
|
||||
@dataclass
class ExecutionContext:
    """
    Context passed to effect execution.

    Provides controlled access to resources within the sandbox.
    """

    input_paths: List[str]
    output_path: str
    params: Dict[str, Any]
    seed: int  # Deterministic seed for RNG
    frame_rate: float = 30.0
    width: int = 1920
    height: int = 1080

    # Resolved bindings (frame -> param value lookup)
    bindings: Dict[str, List[float]] = field(default_factory=dict)

    def get_param_at_frame(self, param_name: str, frame: int) -> Any:
        """
        Get parameter value at a specific frame.

        Bound parameters use their per-frame lookup table; frames past the
        end of the table clamp to the last sample. Unbound parameters fall
        back to the static value in ``params``.
        """
        if param_name not in self.bindings:
            return self.params.get(param_name)

        values = self.bindings[param_name]
        if frame < len(values):
            return values[frame]

        # Past the end of the binding data: hold the last value, or fall
        # back to the static param when the table is empty.
        if not values:
            return self.params.get(param_name)
        return values[-1]

    def get_rng(self) -> "random.Random":
        """Return a random number generator seeded with ``self.seed``."""
        import random

        return random.Random(self.seed)
|
||||
259
core/artdag/effects/runner.py
Normal file
259
core/artdag/effects/runner.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Effect runner.
|
||||
|
||||
Main entry point for executing cached effects with sandboxing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .binding import AnalysisData, bindings_to_lookup_table, resolve_all_bindings
|
||||
from .loader import load_effect, LoadedEffect
|
||||
from .meta import ExecutionContext
|
||||
from .sandbox import Sandbox, SandboxConfig, SandboxResult, get_venv_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_effect(
    effect_source: str,
    input_paths: List[Path],
    output_path: Path,
    params: Dict[str, Any],
    analysis: Optional[AnalysisData] = None,
    cache_id: Optional[str] = None,
    seed: int = 0,
    trust_level: str = "untrusted",
    timeout: int = 3600,
) -> SandboxResult:
    """
    Run an effect with full sandboxing.

    This is the main entry point for effect execution: it loads and
    validates the effect, resolves per-frame parameter bindings, prepares
    a venv for the effect's dependencies, and executes the effect inside
    a sandbox.

    Args:
        effect_source: Effect source code
        input_paths: List of input file paths
        output_path: Output file path
        params: Effect parameters (may contain binding dicts)
        analysis: Optional analysis data for binding resolution
        cache_id: Cache ID passed to binding resolution
        seed: RNG seed forwarded to the sandboxed effect
        trust_level: "untrusted" or "trusted"
        timeout: Maximum execution time in seconds

    Returns:
        SandboxResult with success status and output

    Raises:
        ValueError: If the effect source or parameters are invalid.
    """
    # Load and validate effect
    loaded = load_effect(effect_source)
    logger.info(f"Running effect '{loaded.meta.name}' v{loaded.meta.version}")

    # Resolve bindings if analysis data available
    bindings = {}
    if analysis:
        resolved = resolve_all_bindings(params, analysis, cache_id)
        bindings = bindings_to_lookup_table(resolved)
        # Remove binding dicts from params, keeping only resolved values
        # (binding dicts are marked with a "_binding" key).
        params = {
            k: v for k, v in params.items()
            if not (isinstance(v, dict) and v.get("_binding"))
        }

    # Validate parameters against the effect's declared ParamSpecs.
    validated_params = loaded.meta.validate_params(params)

    # Get or create venv for dependencies (None = no extra deps needed).
    venv_path = None
    if loaded.dependencies:
        venv_path = get_venv_path(loaded.dependencies)

    # Configure sandbox
    config = SandboxConfig(
        trust_level=trust_level,
        venv_path=venv_path,
        timeout=timeout,
    )

    # Write effect to a temp file the sandbox can mount; delete=False so
    # it survives the context manager until we unlink it ourselves.
    import tempfile
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".py",
        delete=False,
    ) as f:
        f.write(effect_source)
        effect_path = Path(f.name)

    try:
        with Sandbox(config) as sandbox:
            result = sandbox.run_effect(
                effect_path=effect_path,
                input_paths=input_paths,
                output_path=output_path,
                params=validated_params,
                bindings=bindings,
                seed=seed,
            )
    finally:
        # Always clean up the temp copy of the effect source.
        effect_path.unlink(missing_ok=True)

    return result
|
||||
|
||||
|
||||
def run_effect_from_cache(
    cache,
    effect_hash: str,
    input_paths: List[Path],
    output_path: Path,
    params: Dict[str, Any],
    analysis: Optional[AnalysisData] = None,
    cache_id: Optional[str] = None,
    seed: int = 0,
    trust_level: str = "untrusted",
    timeout: int = 3600,
) -> SandboxResult:
    """
    Run an effect from cache by content hash.

    Thin wrapper around run_effect() that first fetches the effect source
    from the cache; a missing hash yields a failed SandboxResult instead
    of raising.

    Args:
        cache: Cache instance (must expose get_effect(hash) -> source)
        effect_hash: Content hash of effect
        input_paths: Input file paths
        output_path: Output file path
        params: Effect parameters
        analysis: Optional analysis data
        cache_id: Cache ID for seeding
        seed: RNG seed
        trust_level: "untrusted" or "trusted"
        timeout: Max execution time

    Returns:
        SandboxResult
    """
    effect_source = cache.get_effect(effect_hash)
    if not effect_source:
        # Unknown hash: report failure rather than raising so callers can
        # handle missing effects uniformly with execution failures.
        return SandboxResult(
            success=False,
            error=f"Effect not found in cache: {effect_hash[:16]}...",
        )

    return run_effect(
        effect_source=effect_source,
        input_paths=input_paths,
        output_path=output_path,
        params=params,
        analysis=analysis,
        cache_id=cache_id,
        seed=seed,
        trust_level=trust_level,
        timeout=timeout,
    )
|
||||
|
||||
|
||||
def check_effect_temporal(cache, effect_hash: str) -> bool:
    """
    Return True when the cached effect is marked temporal.

    Temporal effects need their complete input and cannot be collapsed.
    Unknown hashes are treated as non-temporal.

    Args:
        cache: Cache instance
        effect_hash: Content hash of effect
    """
    metadata = cache.get_effect_metadata(effect_hash)
    if not metadata:
        return False
    return metadata.get("meta", {}).get("temporal", False)
|
||||
|
||||
|
||||
def get_effect_api_type(cache, effect_hash: str) -> str:
    """
    Return the API type of a cached effect: "frame" or "video".

    Unknown hashes default to "frame".

    Args:
        cache: Cache instance
        effect_hash: Content hash of effect
    """
    metadata = cache.get_effect_metadata(effect_hash)
    if metadata:
        return metadata.get("meta", {}).get("api_type", "frame")
    return "frame"
|
||||
|
||||
|
||||
class EffectExecutor:
    """
    High-level executor for cached effects.

    Bundles a cache handle and a default trust level so callers do not
    have to thread them through every execution call.
    """

    def __init__(self, cache, trust_level: str = "untrusted"):
        """
        Initialize executor.

        Args:
            cache: Cache instance
            trust_level: Default trust level applied when effect metadata
                does not promote the effect to "trusted"
        """
        self.cache = cache
        self.trust_level = trust_level

    def execute(
        self,
        effect_hash: str,
        input_paths: List[Path],
        output_path: Path,
        params: Dict[str, Any],
        analysis: Optional[AnalysisData] = None,
        step_cache_id: str = None,
    ) -> SandboxResult:
        """
        Execute an effect by content hash.

        Args:
            effect_hash: Content hash of effect
            input_paths: Input file paths
            output_path: Output path
            params: Effect parameters
            analysis: Analysis data for bindings
            step_cache_id: Step cache ID for seeding

        Returns:
            SandboxResult
        """
        # Effect metadata may promote this run to "trusted" (set by the
        # L1 owner); otherwise the executor's default applies.
        metadata = self.cache.get_effect_metadata(effect_hash)
        trust_level = self.trust_level
        if metadata and metadata.get("trust_level") == "trusted":
            trust_level = "trusted"

        return run_effect_from_cache(
            cache=self.cache,
            effect_hash=effect_hash,
            input_paths=input_paths,
            output_path=output_path,
            params=params,
            analysis=analysis,
            cache_id=step_cache_id,
            trust_level=trust_level,
        )

    def is_temporal(self, effect_hash: str) -> bool:
        """Return True when the effect is temporal (cannot be collapsed)."""
        return check_effect_temporal(self.cache, effect_hash)

    def get_api_type(self, effect_hash: str) -> str:
        """Return the effect's API type ("frame" or "video")."""
        return get_effect_api_type(self.cache, effect_hash)
|
||||
431
core/artdag/effects/sandbox.py
Normal file
431
core/artdag/effects/sandbox.py
Normal file
@@ -0,0 +1,431 @@
|
||||
"""
|
||||
Sandbox for effect execution.
|
||||
|
||||
Uses bubblewrap (bwrap) for Linux namespace isolation.
|
||||
Provides controlled access to:
|
||||
- Input files (read-only)
|
||||
- Output file (write)
|
||||
- stderr (logging)
|
||||
- Seeded RNG
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class SandboxConfig:
    """
    Sandbox configuration.

    Attributes:
        trust_level: "untrusted" (full isolation) or "trusted" (allows subprocess)
        venv_path: Path to effect's virtual environment
        wheel_cache: Shared wheel cache directory
        timeout: Maximum execution time in seconds
        memory_limit: Memory limit in bytes (0 = unlimited)
        allow_network: Whether to allow network access
    """

    trust_level: str = "untrusted"
    venv_path: Optional[Path] = None
    # Shared pip wheel cache location used when installing effect deps.
    wheel_cache: Path = field(default_factory=lambda: Path("/var/cache/artdag/wheels"))
    timeout: int = 3600  # 1 hour default
    # NOTE(review): memory_limit is not enforced anywhere in Sandbox in this
    # module — confirm whether enforcement lives elsewhere or is TODO.
    memory_limit: int = 0
    # NOTE(review): the bwrap path always passes --unshare-net regardless of
    # this flag — confirm intended semantics of allow_network.
    allow_network: bool = False
|
||||
|
||||
|
||||
def is_bwrap_available() -> bool:
    """Return True when the bubblewrap (bwrap) binary can be invoked."""
    try:
        proc = subprocess.run(
            ["bwrap", "--version"],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # bwrap is not installed at all.
        return False
    return proc.returncode == 0
|
||||
|
||||
|
||||
def get_venv_path(dependencies: List[str], cache_dir: Optional[Path] = None) -> Path:
    """
    Get or create a virtual environment for the given dependencies.

    Uses a SHA3-256 hash of the sorted, normalized dependency list as the
    cache key, so identical dependency sets share one venv.

    Args:
        dependencies: List of pip package specifiers
        cache_dir: Base directory for the venv cache
            (default: /var/cache/artdag/venvs)

    Returns:
        Path to the venv directory

    Raises:
        subprocess.CalledProcessError: If venv creation or pip install fails.
    """
    # Local import: this module does not import sys at top level.
    import sys

    cache_dir = cache_dir or Path("/var/cache/artdag/venvs")
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Normalize (lowercase, strip) and sort so that ordering and casing
    # differences map to the same cached venv.
    sorted_deps = sorted(dep.lower().strip() for dep in dependencies)
    deps_str = "\n".join(sorted_deps)
    deps_hash = hashlib.sha3_256(deps_str.encode()).hexdigest()[:16]

    venv_path = cache_dir / deps_hash

    if venv_path.exists():
        logger.debug(f"Reusing venv at {venv_path}")
        return venv_path

    # Create new venv.
    logger.info(f"Creating venv for {len(dependencies)} deps at {venv_path}")

    # FIX: use the running interpreter rather than whatever bare "python"
    # resolves to on PATH, so the venv matches the host process's Python.
    subprocess.run(
        [sys.executable, "-m", "venv", str(venv_path)],
        check=True,
    )

    # Install dependencies through the venv's own pip.
    if dependencies:
        wheel_cache = Path("/var/cache/artdag/wheels")
        # FIX: ensure the wheel cache directory exists before pip uses it.
        wheel_cache.mkdir(parents=True, exist_ok=True)
        pip_path = venv_path / "bin" / "pip"
        cmd = [
            str(pip_path),
            "install",
            "--cache-dir", str(wheel_cache),
            *dependencies,
        ]
        subprocess.run(cmd, check=True)

    return venv_path
|
||||
|
||||
|
||||
@dataclass
class SandboxResult:
    """Result of sandboxed execution."""

    # True only when the runner exited 0 AND the output file exists.
    success: bool
    output_path: Optional[Path] = None
    stderr: str = ""  # captured stderr from the effect process
    exit_code: int = 0  # -1 for timeouts and launcher-level errors
    error: Optional[str] = None  # human-readable failure description
|
||||
|
||||
|
||||
class Sandbox:
    """
    Sandboxed effect execution environment.

    Uses bubblewrap for namespace isolation when available,
    falls back to subprocess with restricted permissions.

    Intended for use as a context manager so that per-run temporary
    directories are removed on exit.
    """

    def __init__(self, config: Optional[SandboxConfig] = None):
        self.config = config or SandboxConfig()
        # Temp dirs created by _create_temp_dir(); removed in cleanup().
        self._temp_dirs: List[Path] = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cleanup()

    def cleanup(self):
        """Clean up temporary directories."""
        for temp_dir in self._temp_dirs:
            if temp_dir.exists():
                shutil.rmtree(temp_dir, ignore_errors=True)
        self._temp_dirs = []

    def _create_temp_dir(self) -> Path:
        """Create a temporary directory for sandbox use."""
        temp_dir = Path(tempfile.mkdtemp(prefix="artdag_sandbox_"))
        self._temp_dirs.append(temp_dir)
        return temp_dir

    def run_effect(
        self,
        effect_path: Path,
        input_paths: List[Path],
        output_path: Path,
        params: Dict[str, Any],
        bindings: Optional[Dict[str, List[float]]] = None,
        seed: int = 0,
    ) -> SandboxResult:
        """
        Run an effect in the sandbox.

        The effect script and a JSON config are staged in a fresh work
        directory; a generated runner script then loads and invokes the
        effect either under bwrap (untrusted) or as a plain subprocess.

        Args:
            effect_path: Path to effect.py
            input_paths: List of input file paths
            output_path: Output file path
            params: Effect parameters
            bindings: Per-frame parameter bindings
            seed: RNG seed for determinism

        Returns:
            SandboxResult with success status and output
        """
        bindings = bindings or {}

        # Create work directory
        work_dir = self._create_temp_dir()
        config_path = work_dir / "config.json"
        effect_copy = work_dir / "effect.py"

        # Copy effect to work dir
        shutil.copy(effect_path, effect_copy)

        # Write config file (consumed by the runner script inside the sandbox)
        config_data = {
            "input_paths": [str(p) for p in input_paths],
            "output_path": str(output_path),
            "params": params,
            "bindings": bindings,
            "seed": seed,
        }
        config_path.write_text(json.dumps(config_data))

        # Only untrusted effects get bwrap isolation; trusted effects (and
        # hosts without bwrap) fall back to an unisolated subprocess.
        if is_bwrap_available() and self.config.trust_level == "untrusted":
            return self._run_with_bwrap(
                effect_copy, config_path, input_paths, output_path, work_dir
            )
        else:
            return self._run_subprocess(
                effect_copy, config_path, input_paths, output_path, work_dir
            )

    def _run_with_bwrap(
        self,
        effect_path: Path,
        config_path: Path,
        input_paths: List[Path],
        output_path: Path,
        work_dir: Path,
    ) -> SandboxResult:
        """Run effect with bubblewrap isolation."""
        logger.info("Running effect in bwrap sandbox")

        # Build bwrap command. NOTE(review): bwrap applies binds in order, so
        # the writable binds below are presumed to mount over the read-only
        # root bind — confirm against the bwrap man page.
        cmd = [
            "bwrap",
            # New PID namespace
            "--unshare-pid",
            # No network (config.allow_network is not consulted here)
            "--unshare-net",
            # Read-only root filesystem
            "--ro-bind", "/", "/",
            # Read-write work directory
            "--bind", str(work_dir), str(work_dir),
            # Read-only input files
        ]

        for input_path in input_paths:
            cmd.extend(["--ro-bind", str(input_path), str(input_path)])

        # Bind output directory as writable
        output_dir = output_path.parent
        output_dir.mkdir(parents=True, exist_ok=True)
        cmd.extend(["--bind", str(output_dir), str(output_dir)])

        # Bind venv if available; otherwise use the system interpreter.
        if self.config.venv_path and self.config.venv_path.exists():
            cmd.extend(["--ro-bind", str(self.config.venv_path), str(self.config.venv_path)])
            python_path = self.config.venv_path / "bin" / "python"
        else:
            python_path = Path("/usr/bin/python3")

        # Add runner script
        runner_script = self._get_runner_script()
        runner_path = work_dir / "runner.py"
        runner_path.write_text(runner_script)

        # Run the effect
        cmd.extend([
            str(python_path),
            str(runner_path),
            str(effect_path),
            str(config_path),
        ])

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.config.timeout,
            )

            # Success requires both a zero exit and an existing output file.
            if result.returncode == 0 and output_path.exists():
                return SandboxResult(
                    success=True,
                    output_path=output_path,
                    stderr=result.stderr,
                    exit_code=0,
                )
            else:
                return SandboxResult(
                    success=False,
                    stderr=result.stderr,
                    exit_code=result.returncode,
                    error=result.stderr or "Effect execution failed",
                )

        except subprocess.TimeoutExpired:
            return SandboxResult(
                success=False,
                error=f"Effect timed out after {self.config.timeout}s",
                exit_code=-1,
            )
        except Exception as e:
            # Launcher-level failure (e.g. bwrap missing mid-run).
            return SandboxResult(
                success=False,
                error=str(e),
                exit_code=-1,
            )

    def _run_subprocess(
        self,
        effect_path: Path,
        config_path: Path,
        input_paths: List[Path],
        output_path: Path,
        work_dir: Path,
    ) -> SandboxResult:
        """Run effect in subprocess (fallback without bwrap)."""
        logger.warning("Running effect without sandbox isolation")

        # Create runner script
        runner_script = self._get_runner_script()
        runner_path = work_dir / "runner.py"
        runner_path.write_text(runner_script)

        # Determine Python path: prefer the effect's venv interpreter.
        if self.config.venv_path and self.config.venv_path.exists():
            python_path = self.config.venv_path / "bin" / "python"
        else:
            python_path = "python3"

        cmd = [
            str(python_path),
            str(runner_path),
            str(effect_path),
            str(config_path),
        ]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.config.timeout,
                cwd=str(work_dir),
            )

            # Success requires both a zero exit and an existing output file.
            if result.returncode == 0 and output_path.exists():
                return SandboxResult(
                    success=True,
                    output_path=output_path,
                    stderr=result.stderr,
                    exit_code=0,
                )
            else:
                return SandboxResult(
                    success=False,
                    stderr=result.stderr,
                    exit_code=result.returncode,
                    error=result.stderr or "Effect execution failed",
                )

        except subprocess.TimeoutExpired:
            return SandboxResult(
                success=False,
                error=f"Effect timed out after {self.config.timeout}s",
                exit_code=-1,
            )
        except Exception as e:
            return SandboxResult(
                success=False,
                error=str(e),
                exit_code=-1,
            )

    def _get_runner_script(self) -> str:
        """Get the runner script that executes effects inside the sandbox."""
        return '''#!/usr/bin/env python3
"""Effect runner script - executed in sandbox."""

import importlib.util
import json
import sys
from pathlib import Path

def load_effect(effect_path):
    """Load effect module from path."""
    spec = importlib.util.spec_from_file_location("effect", effect_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def main():
    if len(sys.argv) < 3:
        print("Usage: runner.py <effect_path> <config_path>", file=sys.stderr)
        sys.exit(1)

    effect_path = Path(sys.argv[1])
    config_path = Path(sys.argv[2])

    # Load config
    config = json.loads(config_path.read_text())

    input_paths = [Path(p) for p in config["input_paths"]]
    output_path = Path(config["output_path"])
    params = config["params"]
    bindings = config.get("bindings", {})
    seed = config.get("seed", 0)

    # Load effect
    effect = load_effect(effect_path)

    # Check API type
    if hasattr(effect, "process"):
        # Whole-video API
        from artdag.effects.meta import ExecutionContext
        ctx = ExecutionContext(
            input_paths=[str(p) for p in input_paths],
            output_path=str(output_path),
            params=params,
            seed=seed,
            bindings=bindings,
        )
        effect.process(input_paths, output_path, params, ctx)

    elif hasattr(effect, "process_frame"):
        # Frame-by-frame API
        from artdag.effects.frame_processor import process_video

        result_path, _ = process_video(
            input_path=input_paths[0],
            output_path=output_path,
            process_frame=effect.process_frame,
            params=params,
            bindings=bindings,
        )

    else:
        print("Effect must have process() or process_frame()", file=sys.stderr)
        sys.exit(1)

    print(f"Success: {output_path}", file=sys.stderr)

if __name__ == "__main__":
    main()
'''
|
||||
246
core/artdag/engine.py
Normal file
246
core/artdag/engine.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# primitive/engine.py
|
||||
"""
|
||||
DAG execution engine.
|
||||
|
||||
Executes DAGs by:
|
||||
1. Resolving nodes in topological order
|
||||
2. Checking cache for each node
|
||||
3. Running executors for cache misses
|
||||
4. Storing results in cache
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from .dag import DAG, Node, NodeType
|
||||
from .cache import Cache
|
||||
from .executor import Executor, get_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Result of executing a DAG."""
    success: bool
    output_path: Optional[Path] = None  # output of the DAG's output node, when successful
    error: Optional[str] = None
    execution_time: float = 0.0  # wall-clock seconds for the whole run
    nodes_executed: int = 0  # cache misses actually run by executors
    nodes_cached: int = 0  # nodes satisfied directly from cache
    node_results: Dict[str, Path] = field(default_factory=dict)  # node_id -> result path
|
||||
|
||||
|
||||
@dataclass
class NodeProgress:
    """Progress update for a node, as emitted by Engine._report_progress."""
    node_id: str
    node_type: str
    status: str  # "pending", "running", "cached", "completed", "failed"
    progress: float = 0.0  # 0.0 to 1.0
    message: str = ""


# Progress callback type: receives one NodeProgress per status change.
ProgressCallback = Callable[[NodeProgress], None]
|
||||
|
||||
|
||||
class Engine:
    """
    DAG execution engine.

    Manages cache, resolves dependencies, and runs executors.
    """

    def __init__(self, cache_dir: Path | str):
        # Content-addressed result cache; node_id is the cache key.
        self.cache = Cache(cache_dir)
        self._progress_callback: Optional[ProgressCallback] = None

    def set_progress_callback(self, callback: ProgressCallback):
        """Set callback for progress updates."""
        self._progress_callback = callback

    def _report_progress(self, progress: NodeProgress):
        """Report progress to callback if set."""
        if self._progress_callback:
            try:
                self._progress_callback(progress)
            except Exception as e:
                # A faulty callback must never abort DAG execution.
                logger.warning(f"Progress callback error: {e}")

    def execute(self, dag: DAG) -> ExecutionResult:
        """
        Execute a DAG and return the result.

        Nodes are processed in topological order; each node is served
        from cache when possible, otherwise run via its registered
        executor and its result stored back in the cache. Execution
        stops at the first failing node.

        Args:
            dag: The DAG to execute

        Returns:
            ExecutionResult with output path or error
        """
        start_time = time.time()
        node_results: Dict[str, Path] = {}
        nodes_executed = 0
        nodes_cached = 0

        # Validate DAG
        errors = dag.validate()
        if errors:
            return ExecutionResult(
                success=False,
                error=f"Invalid DAG: {errors}",
                execution_time=time.time() - start_time,
            )

        # Get topological order
        try:
            order = dag.topological_order()
        except Exception as e:
            return ExecutionResult(
                success=False,
                error=f"Failed to order DAG: {e}",
                execution_time=time.time() - start_time,
            )

        # Execute each node
        for node_id in order:
            node = dag.get_node(node_id)
            # Normalize the node type to a display string for progress/errors.
            type_str = node.node_type.name if isinstance(node.node_type, NodeType) else str(node.node_type)

            # Report starting
            self._report_progress(NodeProgress(
                node_id=node_id,
                node_type=type_str,
                status="pending",
                message=f"Processing {type_str}",
            ))

            # Check cache first
            cached_path = self.cache.get(node_id)
            if cached_path is not None:
                node_results[node_id] = cached_path
                nodes_cached += 1
                self._report_progress(NodeProgress(
                    node_id=node_id,
                    node_type=type_str,
                    status="cached",
                    progress=1.0,
                    message="Using cached result",
                ))
                continue

            # Get executor
            executor = get_executor(node.node_type)
            if executor is None:
                return ExecutionResult(
                    success=False,
                    error=f"No executor for node type: {node.node_type}",
                    execution_time=time.time() - start_time,
                    node_results=node_results,
                )

            # Resolve input paths. Topological order guarantees inputs were
            # processed earlier, so a miss here indicates a structural bug.
            input_paths = []
            for input_id in node.inputs:
                if input_id not in node_results:
                    return ExecutionResult(
                        success=False,
                        error=f"Missing input {input_id} for node {node_id}",
                        execution_time=time.time() - start_time,
                        node_results=node_results,
                    )
                input_paths.append(node_results[input_id])

            # Determine output path
            output_path = self.cache.get_output_path(node_id, ".mkv")

            # Execute
            self._report_progress(NodeProgress(
                node_id=node_id,
                node_type=type_str,
                status="running",
                progress=0.5,
                message=f"Executing {type_str}",
            ))

            node_start = time.time()
            try:
                result_path = executor.execute(
                    config=node.config,
                    inputs=input_paths,
                    output_path=output_path,
                )
                node_time = time.time() - node_start

                # Store in cache (file is already at output_path)
                self.cache.put(
                    node_id=node_id,
                    source_path=result_path,
                    node_type=type_str,
                    execution_time=node_time,
                    move=False,  # Already in place
                )

                node_results[node_id] = result_path
                nodes_executed += 1

                self._report_progress(NodeProgress(
                    node_id=node_id,
                    node_type=type_str,
                    status="completed",
                    progress=1.0,
                    message=f"Completed in {node_time:.2f}s",
                ))

            except Exception as e:
                # First failure aborts the run; partial results are returned.
                logger.error(f"Node {node_id} failed: {e}")
                self._report_progress(NodeProgress(
                    node_id=node_id,
                    node_type=type_str,
                    status="failed",
                    message=str(e),
                ))
                return ExecutionResult(
                    success=False,
                    error=f"Node {node_id} ({type_str}) failed: {e}",
                    execution_time=time.time() - start_time,
                    node_results=node_results,
                    nodes_executed=nodes_executed,
                    nodes_cached=nodes_cached,
                )

        # Get final output
        output_path = node_results.get(dag.output_id)

        return ExecutionResult(
            success=True,
            output_path=output_path,
            execution_time=time.time() - start_time,
            nodes_executed=nodes_executed,
            nodes_cached=nodes_cached,
            node_results=node_results,
        )

    def execute_node(self, node: Node, inputs: List[Path]) -> Path:
        """
        Execute a single node (bypassing DAG structure).

        Useful for testing individual executors.
        """
        executor = get_executor(node.node_type)
        if executor is None:
            raise ValueError(f"No executor for node type: {node.node_type}")

        output_path = self.cache.get_output_path(node.node_id, ".mkv")
        return executor.execute(node.config, inputs, output_path)

    def get_cache_stats(self):
        """Get cache statistics."""
        return self.cache.get_stats()

    def clear_cache(self):
        """Clear the cache."""
        self.cache.clear()
|
||||
106
core/artdag/executor.py
Normal file
106
core/artdag/executor.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# primitive/executor.py
|
||||
"""
|
||||
Executor base class and registry.
|
||||
|
||||
Executors implement the actual operations for each node type.
|
||||
They are registered by node type and looked up during execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
from .dag import NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global executor registry
|
||||
_EXECUTORS: Dict[NodeType | str, Type["Executor"]] = {}
|
||||
|
||||
|
||||
class Executor(ABC):
    """
    Base class for node executors.

    Subclasses implement execute() to perform the actual operation.
    Instances are created fresh per lookup (see get_executor), so
    subclasses should not rely on per-instance state across calls.
    """

    @abstractmethod
    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """
        Execute the node operation.

        Args:
            config: Node configuration
            inputs: Paths to input files (from resolved input nodes)
            output_path: Where to write the output

        Returns:
            Path to the output file
        """
        pass

    def validate_config(self, config: Dict[str, Any]) -> List[str]:
        """
        Validate node configuration.

        Returns list of error messages (empty if valid).
        Override in subclasses for specific validation.
        """
        return []

    def estimate_output_size(self, config: Dict[str, Any], input_sizes: List[int]) -> int:
        """
        Estimate output size in bytes.

        Override for better estimates. Default returns sum of inputs.
        """
        return sum(input_sizes) if input_sizes else 0
|
||||
|
||||
|
||||
def register_executor(node_type: NodeType | str) -> Callable:
    """
    Decorator that registers an Executor class for a node type.

    Re-registering a node type replaces the previous executor (with a
    warning).

    Usage:
        @register_executor(NodeType.SOURCE)
        class SourceExecutor(Executor):
            ...
    """
    def _register(executor_cls: Type[Executor]) -> Type[Executor]:
        if node_type in _EXECUTORS:
            logger.warning(f"Overwriting executor for {node_type}")
        _EXECUTORS[node_type] = executor_cls
        return executor_cls

    return _register
|
||||
|
||||
|
||||
def get_executor(node_type: NodeType | str) -> Optional[Executor]:
    """
    Get a fresh executor instance for a node type.

    Returns None if no executor is registered.
    """
    registered = _EXECUTORS.get(node_type)
    return None if registered is None else registered()
|
||||
|
||||
|
||||
def list_executors() -> Dict[str, Type[Executor]]:
    """List all registered executors, keyed by node-type name."""
    listing: Dict[str, Type[Executor]] = {}
    for key, executor_cls in _EXECUTORS.items():
        name = key.name if isinstance(key, NodeType) else key
        listing[name] = executor_cls
    return listing
|
||||
|
||||
|
||||
def clear_executors():
    """Clear all registered executors (for testing).

    Mutates the module-level registry in place so existing references
    to _EXECUTORS observe the change.
    """
    _EXECUTORS.clear()
|
||||
11
core/artdag/nodes/__init__.py
Normal file
11
core/artdag/nodes/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# primitive/nodes/__init__.py
|
||||
"""
|
||||
Built-in node executors.
|
||||
|
||||
Import this module to register all built-in executors.
|
||||
"""
|
||||
|
||||
from . import source
|
||||
from . import transform
|
||||
from . import compose
|
||||
from . import effect
|
||||
548
core/artdag/nodes/compose.py
Normal file
548
core/artdag/nodes/compose.py
Normal file
@@ -0,0 +1,548 @@
|
||||
# primitive/nodes/compose.py
|
||||
"""
|
||||
Compose executors: Combine multiple media inputs.
|
||||
|
||||
Primitives: SEQUENCE, LAYER, MUX, BLEND
|
||||
"""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ..dag import NodeType
|
||||
from ..executor import Executor, register_executor
|
||||
from .encoding import WEB_ENCODING_ARGS_STR, get_web_encoding_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_duration(path: Path) -> float:
    """
    Get media duration in seconds via ffprobe.

    Args:
        path: Media file to probe.

    Returns:
        Duration in seconds.

    Raises:
        RuntimeError: If ffprobe fails or reports no duration.
    """
    cmd = [
        "ffprobe", "-v", "error",
        "-show_entries", "format=duration",
        "-of", "csv=p=0",
        str(path)
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    out = result.stdout.strip()
    # FIX: a failed probe previously surfaced as a cryptic ValueError from
    # float("") — report the actual ffprobe error instead.
    if result.returncode != 0 or not out:
        raise RuntimeError(f"ffprobe failed for {path}: {result.stderr.strip()}")
    return float(out)
|
||||
|
||||
|
||||
def _get_video_info(path: Path) -> dict:
    """Get video width, height, frame rate, and sample rate.

    Falls back to 1920x1080 @ 30fps / 44100 Hz when ffprobe returns
    nothing for a field.
    """
    video_probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=width,height,r_frame_rate",
            "-of", "csv=p=0",
            str(path)
        ],
        capture_output=True,
        text=True,
    )
    fields = video_probe.stdout.strip().split(",")

    width = int(fields[0]) if len(fields) > 0 and fields[0] else 1920
    height = int(fields[1]) if len(fields) > 1 and fields[1] else 1080

    # r_frame_rate is rational, e.g. "30/1" or "30000/1001".
    raw_fps = fields[2] if len(fields) > 2 else "30/1"
    if "/" in raw_fps:
        num, den = raw_fps.split("/")
        fps = float(num) / float(den) if float(den) != 0 else 30
    else:
        fps = float(raw_fps) if raw_fps else 30

    # Audio sample rate comes from a second probe of the first audio stream.
    audio_probe = subprocess.run(
        [
            "ffprobe", "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "stream=sample_rate",
            "-of", "csv=p=0",
            str(path)
        ],
        capture_output=True,
        text=True,
    )
    raw_sr = audio_probe.stdout.strip()
    sample_rate = int(raw_sr) if raw_sr else 44100

    return {"width": width, "height": height, "fps": fps, "sample_rate": sample_rate}
|
||||
|
||||
|
||||
@register_executor(NodeType.SEQUENCE)
|
||||
class SequenceExecutor(Executor):
|
||||
"""
|
||||
Concatenate inputs in time order.
|
||||
|
||||
Config:
|
||||
transition: Transition config
|
||||
type: "cut" | "crossfade" | "fade"
|
||||
duration: Transition duration in seconds
|
||||
target_size: How to determine output dimensions when inputs differ
|
||||
"first": Use first input's dimensions (default)
|
||||
"last": Use last input's dimensions
|
||||
"largest": Use largest width and height from all inputs
|
||||
"explicit": Use width/height config values
|
||||
width: Target width (when target_size="explicit")
|
||||
height: Target height (when target_size="explicit")
|
||||
background: Padding color for letterbox/pillarbox (default: "black")
|
||||
"""
|
||||
|
||||
def execute(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
inputs: List[Path],
|
||||
output_path: Path,
|
||||
) -> Path:
|
||||
if len(inputs) < 1:
|
||||
raise ValueError("SEQUENCE requires at least one input")
|
||||
|
||||
if len(inputs) == 1:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(inputs[0], output_path)
|
||||
return output_path
|
||||
|
||||
transition = config.get("transition", {"type": "cut"})
|
||||
transition_type = transition.get("type", "cut")
|
||||
transition_duration = transition.get("duration", 0.5)
|
||||
|
||||
# Size handling config
|
||||
target_size = config.get("target_size", "first")
|
||||
width = config.get("width")
|
||||
height = config.get("height")
|
||||
background = config.get("background", "black")
|
||||
|
||||
if transition_type == "cut":
|
||||
return self._concat_cut(inputs, output_path, target_size, width, height, background)
|
||||
elif transition_type == "crossfade":
|
||||
return self._concat_crossfade(inputs, output_path, transition_duration)
|
||||
elif transition_type == "fade":
|
||||
return self._concat_fade(inputs, output_path, transition_duration)
|
||||
else:
|
||||
raise ValueError(f"Unknown transition type: {transition_type}")
|
||||
|
||||
def _concat_cut(
|
||||
self,
|
||||
inputs: List[Path],
|
||||
output_path: Path,
|
||||
target_size: str = "first",
|
||||
width: int = None,
|
||||
height: int = None,
|
||||
background: str = "black",
|
||||
) -> Path:
|
||||
"""
|
||||
Concatenate with scaling/padding to handle different resolutions.
|
||||
|
||||
Args:
|
||||
inputs: Input video paths
|
||||
output_path: Output path
|
||||
target_size: How to determine output size:
|
||||
- "first": Use first input's dimensions (default)
|
||||
- "last": Use last input's dimensions
|
||||
- "largest": Use largest dimensions from all inputs
|
||||
- "explicit": Use width/height params
|
||||
width: Explicit width (when target_size="explicit")
|
||||
height: Explicit height (when target_size="explicit")
|
||||
background: Padding color (default: black)
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
n = len(inputs)
|
||||
input_args = []
|
||||
for p in inputs:
|
||||
input_args.extend(["-i", str(p)])
|
||||
|
||||
# Get video info for all inputs
|
||||
infos = [_get_video_info(p) for p in inputs]
|
||||
|
||||
# Determine target dimensions
|
||||
if target_size == "explicit" and width and height:
|
||||
target_w, target_h = width, height
|
||||
elif target_size == "last":
|
||||
target_w, target_h = infos[-1]["width"], infos[-1]["height"]
|
||||
elif target_size == "largest":
|
||||
target_w = max(i["width"] for i in infos)
|
||||
target_h = max(i["height"] for i in infos)
|
||||
else: # "first" or default
|
||||
target_w, target_h = infos[0]["width"], infos[0]["height"]
|
||||
|
||||
# Use common frame rate (from first input) and sample rate
|
||||
target_fps = infos[0]["fps"]
|
||||
target_sr = max(i["sample_rate"] for i in infos)
|
||||
|
||||
# Build filter for each input: scale to fit + pad to target size
|
||||
filter_parts = []
|
||||
for i in range(n):
|
||||
# Scale to fit within target, maintaining aspect ratio, then pad
|
||||
vf = (
|
||||
f"[{i}:v]scale={target_w}:{target_h}:force_original_aspect_ratio=decrease,"
|
||||
f"pad={target_w}:{target_h}:(ow-iw)/2:(oh-ih)/2:color={background},"
|
||||
f"setsar=1,fps={target_fps:.6f}[v{i}]"
|
||||
)
|
||||
# Resample audio to common rate
|
||||
af = f"[{i}:a]aresample={target_sr}[a{i}]"
|
||||
filter_parts.append(vf)
|
||||
filter_parts.append(af)
|
||||
|
||||
# Build concat filter
|
||||
stream_labels = "".join(f"[v{i}][a{i}]" for i in range(n))
|
||||
filter_parts.append(f"{stream_labels}concat=n={n}:v=1:a=1[outv][outa]")
|
||||
|
||||
filter_complex = ";".join(filter_parts)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
*input_args,
|
||||
"-filter_complex", filter_complex,
|
||||
"-map", "[outv]",
|
||||
"-map", "[outa]",
|
||||
*get_web_encoding_args(),
|
||||
str(output_path)
|
||||
]
|
||||
|
||||
logger.debug(f"SEQUENCE cut: {n} clips -> {target_w}x{target_h} (web-optimized)")
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Concat failed: {result.stderr}")
|
||||
|
||||
return output_path
|
||||
|
||||
def _concat_crossfade(
    self,
    inputs: List[Path],
    output_path: Path,
    duration: float,
) -> Path:
    """Concatenate clips with video ``xfade`` + audio ``acrossfade`` transitions.

    Args:
        inputs: Clips to join, in playback order.
        output_path: Destination file (parent dirs created as needed).
        duration: Crossfade overlap in seconds.

    Returns:
        Path to the concatenated output.

    Falls back to a hard cut (``_concat_cut``) when the crossfade filter
    graph fails (e.g. mismatched stream properties between clips).
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    durations = [_get_duration(p) for p in inputs]
    n = len(inputs)

    # Video xfade chain: each step merges the running result with the next clip.
    filter_parts = []
    current = "[0:v]"
    for i in range(1, n):
        # Offset is measured on the already-merged timeline: total duration so
        # far minus the overlap consumed by the earlier fades.
        offset = sum(durations[:i]) - duration * i
        output_label = f"[v{i}]" if i < n - 1 else "[outv]"
        filter_parts.append(
            f"{current}[{i}:v]xfade=transition=fade:duration={duration}:offset={offset}{output_label}"
        )
        current = output_label

    # Audio crossfade chain, mirroring the video chain.
    audio_current = "[0:a]"
    for i in range(1, n):
        output_label = f"[a{i}]" if i < n - 1 else "[outa]"
        filter_parts.append(
            f"{audio_current}[{i}:a]acrossfade=d={duration}{output_label}"
        )
        audio_current = output_label

    filter_complex = ";".join(filter_parts)

    # Build an argv-style command (no shell=True): robust against paths with
    # spaces/quotes and immune to shell injection via file names; also
    # consistent with _concat_cut.
    cmd = ["ffmpeg", "-y"]
    for p in inputs:
        cmd.extend(["-i", str(p)])
    cmd.extend([
        "-filter_complex", filter_complex,
        "-map", "[outv]",
        "-map", "[outa]",
        *get_web_encoding_args(),
        str(output_path),
    ])

    logger.debug(f"SEQUENCE crossfade: {n} clips (web-optimized)")
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        logger.warning(f"Crossfade failed, falling back to cut: {result.stderr[:200]}")
        return self._concat_cut(inputs, output_path)

    return output_path
|
||||
|
||||
def _concat_fade(
    self,
    inputs: List[Path],
    output_path: Path,
    duration: float,
) -> Path:
    """Concatenate with fade out/in transitions.

    Each clip is re-encoded with a fade-in at its start and a fade-out at its
    end, then the faded clips are hard-cut together via ``_concat_cut``.

    Args:
        inputs: Clips to join, in playback order.
        output_path: Destination file (parent dirs created as needed).
        duration: Fade length in seconds for both fade-in and fade-out.

    Returns:
        Path to the concatenated output.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)

    faded_paths = []
    try:
        for i, path in enumerate(inputs):
            clip_dur = _get_duration(path)
            # Clamp so the fade-out never starts before t=0 on clips shorter
            # than the fade duration (a negative start confuses ffmpeg).
            fade_out_start = max(0.0, clip_dur - duration)
            faded_path = output_path.parent / f"_faded_{i}.mkv"

            cmd = [
                "ffmpeg", "-y",
                "-i", str(path),
                "-vf", f"fade=in:st=0:d={duration},fade=out:st={fade_out_start}:d={duration}",
                "-af", f"afade=in:st=0:d={duration},afade=out:st={fade_out_start}:d={duration}",
                # Intermediate quality: near-lossless, fast to encode.
                "-c:v", "libx264", "-preset", "ultrafast", "-crf", "18",
                "-c:a", "aac",
                str(faded_path)
            ]
            subprocess.run(cmd, capture_output=True, check=True)
            faded_paths.append(faded_path)

        return self._concat_cut(faded_paths, output_path)
    finally:
        # Always remove intermediates, even when encoding or concat raised;
        # the previous version leaked them on any failure.
        for p in faded_paths:
            p.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@register_executor(NodeType.LAYER)
class LayerExecutor(Executor):
    """
    Layer inputs spatially (overlay/composite).

    Config:
        inputs: List of per-input configs
            position: [x, y] offset
            opacity: 0.0-1.0
            scale: Scale factor
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Overlay inputs[1..] onto inputs[0]; audio is taken from the base input.

        Raises:
            ValueError: if no inputs are given.
            RuntimeError: if ffmpeg fails.
        """
        if len(inputs) < 1:
            raise ValueError("LAYER requires at least one input")

        # Single input: nothing to composite - pass through unchanged.
        if len(inputs) == 1:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(inputs[0], output_path)
            return output_path

        input_configs = config.get("inputs", [{}] * len(inputs))
        output_path.parent.mkdir(parents=True, exist_ok=True)

        n = len(inputs)
        filter_parts = []
        current = "[0:v]"

        for i in range(1, n):
            cfg = input_configs[i] if i < len(input_configs) else {}
            x, y = cfg.get("position", [0, 0])
            opacity = cfg.get("opacity", 1.0)
            scale = cfg.get("scale", 1.0)

            # Optional pre-scale of the overlay layer.
            scale_label = f"[s{i}]"
            if scale != 1.0:
                filter_parts.append(f"[{i}:v]scale=iw*{scale}:ih*{scale}{scale_label}")
                overlay_input = scale_label
            else:
                overlay_input = f"[{i}:v]"

            output_label = f"[v{i}]" if i < n - 1 else "[outv]"

            # Optional alpha fade for translucent layers.
            if opacity < 1.0:
                filter_parts.append(
                    f"{overlay_input}format=rgba,colorchannelmixer=aa={opacity}[a{i}]"
                )
                overlay_input = f"[a{i}]"

            filter_parts.append(
                f"{current}{overlay_input}overlay=x={x}:y={y}:format=auto{output_label}"
            )
            current = output_label

        filter_complex = ";".join(filter_parts)

        # Build an argv-style command (no shell=True): robust against paths
        # with spaces/quotes and immune to shell injection via file names.
        cmd = ["ffmpeg", "-y"]
        for p in inputs:
            cmd.extend(["-i", str(p)])
        cmd.extend([
            "-filter_complex", filter_complex,
            "-map", "[outv]",
            "-map", "0:a?",
            *get_web_encoding_args(),
            str(output_path),
        ])

        logger.debug(f"LAYER: {n} inputs (web-optimized)")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Layer failed: {result.stderr}")

        return output_path
|
||||
|
||||
|
||||
@register_executor(NodeType.MUX)
class MuxExecutor(Executor):
    """
    Combine video and audio streams.

    Config:
        video_stream: Index of video input (default: 0)
        audio_stream: Index of audio input (default: 1)
        shortest: End when shortest stream ends (default: True)
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Mux inputs[video_stream]'s video with inputs[audio_stream]'s audio.

        Video is stream-copied (no re-encode); audio is re-encoded to AAC.

        Raises:
            ValueError: if fewer than 2 inputs, or a stream index is out of range.
            RuntimeError: if ffmpeg fails.
        """
        if len(inputs) < 2:
            raise ValueError("MUX requires at least 2 inputs (video + audio)")

        video_idx = config.get("video_stream", 0)
        audio_idx = config.get("audio_stream", 1)
        shortest = config.get("shortest", True)

        # Fail fast with a clear message instead of a raw IndexError (or a
        # silently-wrong negative index) when the config is malformed.
        for name, idx in (("video_stream", video_idx), ("audio_stream", audio_idx)):
            if not isinstance(idx, int) or not 0 <= idx < len(inputs):
                raise ValueError(
                    f"MUX {name}={idx!r} out of range for {len(inputs)} inputs"
                )

        video_path = inputs[video_idx]
        audio_path = inputs[audio_idx]

        output_path.parent.mkdir(parents=True, exist_ok=True)

        cmd = [
            "ffmpeg", "-y",
            "-i", str(video_path),
            "-i", str(audio_path),
            "-c:v", "copy",  # passthrough: avoid generation loss
            "-c:a", "aac",
            "-map", "0:v:0",
            "-map", "1:a:0",
        ]

        if shortest:
            cmd.append("-shortest")

        cmd.append(str(output_path))

        logger.debug(f"MUX: video={video_path.name} + audio={audio_path.name}")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Mux failed: {result.stderr}")

        return output_path
|
||||
|
||||
|
||||
@register_executor(NodeType.BLEND)
class BlendExecutor(Executor):
    """
    Blend two inputs using a blend mode.

    Config:
        mode: Blend mode (multiply, screen, overlay, add, etc.)
        opacity: 0.0-1.0 for second input
    """

    # Maps friendly config names to FFmpeg `blend` filter mode names.
    BLEND_MODES = {
        "multiply": "multiply",
        "screen": "screen",
        "overlay": "overlay",
        "add": "addition",
        "subtract": "subtract",
        "average": "average",
        "difference": "difference",
        "lighten": "lighten",
        "darken": "darken",
    }

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Blend inputs[1] onto inputs[0]; audio is passed through from the first input."""
        if len(inputs) != 2:
            raise ValueError("BLEND requires exactly 2 inputs")

        mode = config.get("mode", "overlay")
        opacity = config.get("opacity", 0.5)

        if mode not in self.BLEND_MODES:
            raise ValueError(f"Unknown blend mode: {mode}")

        blend_mode = self.BLEND_MODES[mode]
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # When the top layer should be translucent, pre-fade its alpha
        # before blending; otherwise blend the raw streams directly.
        if opacity < 1.0:
            filter_complex = (
                f"[1:v]format=rgba,colorchannelmixer=aa={opacity}[b];"
                f"[0:v][b]blend=all_mode={blend_mode}"
            )
        else:
            filter_complex = f"[0:v][1:v]blend=all_mode={blend_mode}"

        base, overlay = inputs
        cmd = [
            "ffmpeg", "-y",
            "-i", str(base),
            "-i", str(overlay),
            "-filter_complex", filter_complex,
            "-map", "0:a?",
            *get_web_encoding_args(),
            str(output_path)
        ]

        logger.debug(f"BLEND: {mode} (opacity={opacity}) (web-optimized)")
        proc = subprocess.run(cmd, capture_output=True, text=True)

        if proc.returncode != 0:
            raise RuntimeError(f"Blend failed: {proc.stderr}")

        return output_path
|
||||
|
||||
|
||||
@register_executor(NodeType.AUDIO_MIX)
class AudioMixExecutor(Executor):
    """
    Mix multiple audio streams.

    Config:
        gains: List of gain values per input (0.0-2.0, default 1.0)
        normalize: Normalize output to prevent clipping (default True)
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Mix all inputs' audio into a single AAC stream via FFmpeg ``amix``.

        Raises:
            ValueError: if fewer than 2 inputs.
            RuntimeError: if ffmpeg fails.
        """
        if len(inputs) < 2:
            raise ValueError("AUDIO_MIX requires at least 2 inputs")

        # Copy before padding: the list may live inside the caller's config
        # dict, and appending to it in place would mutate shared state (and
        # silently change a content-addressed node's config).
        gains = list(config.get("gains", [1.0] * len(inputs)))
        normalize = config.get("normalize", True)

        # Pad gains list if too short
        if len(gains) < len(inputs):
            gains.extend([1.0] * (len(inputs) - len(gains)))

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Build filter: apply volume to each input, then mix
        filter_parts = []
        mix_inputs = []

        for i, gain in enumerate(gains[:len(inputs)]):
            if gain != 1.0:
                filter_parts.append(f"[{i}:a]volume={gain}[a{i}]")
                mix_inputs.append(f"[a{i}]")
            else:
                # Unity gain: feed the raw stream straight into amix.
                mix_inputs.append(f"[{i}:a]")

        # amix filter
        normalize_flag = 1 if normalize else 0
        mix_filter = f"{''.join(mix_inputs)}amix=inputs={len(inputs)}:normalize={normalize_flag}[aout]"
        filter_parts.append(mix_filter)

        filter_complex = ";".join(filter_parts)

        cmd = [
            "ffmpeg", "-y",
        ]
        for p in inputs:
            cmd.extend(["-i", str(p)])

        cmd.extend([
            "-filter_complex", filter_complex,
            "-map", "[aout]",
            "-c:a", "aac",
            str(output_path)
        ])

        logger.debug(f"AUDIO_MIX: {len(inputs)} inputs, gains={gains[:len(inputs)]}")
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Audio mix failed: {result.stderr}")

        return output_path
|
||||
520
core/artdag/nodes/effect.py
Normal file
520
core/artdag/nodes/effect.py
Normal file
@@ -0,0 +1,520 @@
|
||||
# artdag/nodes/effect.py
|
||||
"""
|
||||
Effect executor: Apply effects from the registry or IPFS.
|
||||
|
||||
Primitives: EFFECT
|
||||
|
||||
Effects can be:
|
||||
1. Built-in (registered with @register_effect)
|
||||
2. Stored in IPFS (referenced by CID)
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
|
||||
import requests
|
||||
|
||||
from ..executor import Executor, register_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for effect functions: (input_path, output_path, config) -> output_path
|
||||
EffectFn = Callable[[Path, Path, Dict[str, Any]], Path]
|
||||
|
||||
# Type variable for decorator
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# IPFS API multiaddr - same as ipfs_client.py for consistency
|
||||
# Docker uses /dns/ipfs/tcp/5001, local dev uses /ip4/127.0.0.1/tcp/5001
|
||||
IPFS_API = os.environ.get("IPFS_API", "/ip4/127.0.0.1/tcp/5001")
|
||||
|
||||
# Connection timeout in seconds
|
||||
IPFS_TIMEOUT = int(os.environ.get("IPFS_TIMEOUT", "30"))
|
||||
|
||||
|
||||
def _get_ipfs_base_url() -> str:
|
||||
"""
|
||||
Convert IPFS multiaddr to HTTP URL.
|
||||
|
||||
Matches the conversion logic in ipfs_client.py for consistency.
|
||||
"""
|
||||
multiaddr = IPFS_API
|
||||
|
||||
# Handle /dns/hostname/tcp/port format (Docker)
|
||||
dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr)
|
||||
if dns_match:
|
||||
return f"http://{dns_match.group(1)}:{dns_match.group(2)}"
|
||||
|
||||
# Handle /ip4/address/tcp/port format (local)
|
||||
ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr)
|
||||
if ip4_match:
|
||||
return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}"
|
||||
|
||||
# Fallback: assume it's already a URL or use default
|
||||
if multiaddr.startswith("http"):
|
||||
return multiaddr
|
||||
return "http://127.0.0.1:5001"
|
||||
|
||||
|
||||
def _get_effects_cache_dir() -> Optional[Path]:
|
||||
"""Get the effects cache directory from environment or default."""
|
||||
# Check both env var names (CACHE_DIR used by art-celery, ARTDAG_CACHE_DIR for standalone)
|
||||
for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]:
|
||||
cache_dir = os.environ.get(env_var)
|
||||
if cache_dir:
|
||||
effects_dir = Path(cache_dir) / "_effects"
|
||||
if effects_dir.exists():
|
||||
return effects_dir
|
||||
|
||||
# Try default locations
|
||||
for base in [Path.home() / ".artdag" / "cache", Path("/var/cache/artdag")]:
|
||||
effects_dir = base / "_effects"
|
||||
if effects_dir.exists():
|
||||
return effects_dir
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_effect_from_ipfs(cid: str, effect_path: Path) -> bool:
    """
    Fetch an effect from IPFS and cache locally.

    Uses the IPFS API endpoint (/api/v0/cat) for consistency with ipfs_client.py.
    This works reliably in Docker where IPFS_API=/dns/ipfs/tcp/5001.

    Returns True on success, False on failure.
    """
    try:
        # /api/v0/cat streams the raw contents for a CID (same endpoint
        # ipfs_client.py uses).
        response = requests.post(
            f"{_get_ipfs_base_url()}/api/v0/cat",
            params={"arg": cid},
            timeout=IPFS_TIMEOUT,
        )
        response.raise_for_status()

        # Cache the fetched source next to other effects.
        effect_path.parent.mkdir(parents=True, exist_ok=True)
        effect_path.write_bytes(response.content)
        logger.info(f"Fetched effect from IPFS: {cid[:16]}...")
        return True

    except Exception as e:
        # Best-effort: any failure (network, HTTP status, filesystem) is
        # logged and reported as False so callers can fall back.
        logger.error(f"Failed to fetch effect from IPFS {cid[:16]}...: {e}")
        return False
|
||||
|
||||
|
||||
def _parse_pep723_dependencies(source: str) -> List[str]:
|
||||
"""
|
||||
Parse PEP 723 dependencies from effect source code.
|
||||
|
||||
Returns list of package names (e.g., ["numpy", "opencv-python"]).
|
||||
"""
|
||||
match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL)
|
||||
if not match:
|
||||
return []
|
||||
|
||||
block = match.group(1)
|
||||
deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL)
|
||||
if not deps_match:
|
||||
return []
|
||||
|
||||
return re.findall(r'"([^"]+)"', deps_match.group(1))
|
||||
|
||||
|
||||
def _ensure_dependencies(dependencies: List[str], effect_cid: str) -> bool:
|
||||
"""
|
||||
Ensure effect dependencies are installed.
|
||||
|
||||
Installs missing packages using pip. Returns True on success.
|
||||
"""
|
||||
if not dependencies:
|
||||
return True
|
||||
|
||||
missing = []
|
||||
for dep in dependencies:
|
||||
# Extract package name (strip version specifiers)
|
||||
pkg_name = re.split(r'[<>=!]', dep)[0].strip()
|
||||
# Normalize name (pip uses underscores, imports use underscores or hyphens)
|
||||
pkg_name_normalized = pkg_name.replace('-', '_').lower()
|
||||
|
||||
try:
|
||||
__import__(pkg_name_normalized)
|
||||
except ImportError:
|
||||
# Some packages have different import names
|
||||
try:
|
||||
# Try original name with hyphens replaced
|
||||
__import__(pkg_name.replace('-', '_'))
|
||||
except ImportError:
|
||||
missing.append(dep)
|
||||
|
||||
if not missing:
|
||||
return True
|
||||
|
||||
logger.info(f"Installing effect dependencies for {effect_cid[:16]}...: {missing}")
|
||||
|
||||
try:
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "--quiet"] + missing,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"Failed to install dependencies: {result.stderr}")
|
||||
return False
|
||||
|
||||
logger.info(f"Installed dependencies: {missing}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error installing dependencies: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _load_cached_effect(effect_cid: str) -> Optional[EffectFn]:
    """
    Load an effect by CID, fetching from IPFS if not cached locally.

    Pipeline: resolve (or create) the local effects cache dir, fetch the
    source from IPFS if absent, install its PEP 723 dependencies, then
    import the module and adapt whichever API it exposes
    (process_frame / process / effect) to the executor call signature.

    Returns the effect function or None if not found.
    """
    effects_dir = _get_effects_cache_dir()

    # Create cache dir if needed
    if not effects_dir:
        # Try to create default cache dir
        # (env-configured location first, mirroring _get_effects_cache_dir)
        for env_var in ["CACHE_DIR", "ARTDAG_CACHE_DIR"]:
            cache_dir = os.environ.get(env_var)
            if cache_dir:
                effects_dir = Path(cache_dir) / "_effects"
                effects_dir.mkdir(parents=True, exist_ok=True)
                break

        if not effects_dir:
            effects_dir = Path.home() / ".artdag" / "cache" / "_effects"
            effects_dir.mkdir(parents=True, exist_ok=True)

    # One sub-directory per CID; source file is always named effect.py.
    effect_path = effects_dir / effect_cid / "effect.py"

    # If not cached locally, fetch from IPFS
    if not effect_path.exists():
        if not _fetch_effect_from_ipfs(effect_cid, effect_path):
            logger.warning(f"Effect not found: {effect_cid[:16]}...")
            return None

    # Parse and install dependencies before loading
    try:
        source = effect_path.read_text()
        dependencies = _parse_pep723_dependencies(source)
        if dependencies:
            logger.info(f"Effect {effect_cid[:16]}... requires: {dependencies}")
            if not _ensure_dependencies(dependencies, effect_cid):
                logger.error(f"Failed to install dependencies for effect {effect_cid[:16]}...")
                return None
    except Exception as e:
        logger.error(f"Error parsing effect dependencies: {e}")
        # Continue anyway - the effect might work without the deps check

    # Load the effect module
    try:
        spec = importlib.util.spec_from_file_location("cached_effect", effect_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Check for frame-by-frame API
        if hasattr(module, "process_frame"):
            return _wrap_frame_effect(module, effect_path)

        # Check for whole-video API
        if hasattr(module, "process"):
            return _wrap_video_effect(module)

        # Check for old-style effect function
        # (already matches the executor signature, so no wrapper needed)
        if hasattr(module, "effect"):
            return module.effect

        logger.warning(f"Effect has no recognized API: {effect_cid[:16]}...")
        return None

    except Exception as e:
        logger.error(f"Failed to load effect {effect_cid[:16]}...: {e}")
        return None
|
||||
|
||||
|
||||
def _wrap_frame_effect(module: ModuleType, effect_path: Path) -> EffectFn:
    """Adapt a frame-by-frame effect module to the (input, output, config) executor API.

    NOTE(review): effect_path is currently unused inside the wrapper; it is
    kept for interface parity with the caller.
    """

    def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
        """Stream the video through the module's process_frame via FFmpeg pipes."""
        try:
            from ..effects.frame_processor import process_video
        except ImportError:
            # Degrade gracefully: without the frame processor, pass through.
            logger.error("Frame processor not available - falling back to copy")
            shutil.copy2(input_path, output_path)
            return output_path

        # Everything except the executor's bookkeeping keys is an effect param.
        internal_keys = ("effect", "hash", "_binding")
        params = {key: val for key, val in config.items() if key not in internal_keys}

        # Collect pre-resolved per-frame bindings, if any were attached.
        bindings = {
            key: val["_resolved_values"]
            for key, val in config.items()
            if isinstance(val, dict) and val.get("_resolved_values")
        }

        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Frame pipeline always emits an mp4 container.
        actual_output = output_path.with_suffix(".mp4")

        process_video(
            input_path=input_path,
            output_path=actual_output,
            process_frame=module.process_frame,
            params=params,
            bindings=bindings,
        )

        return actual_output

    return wrapped_effect
|
||||
|
||||
|
||||
def _wrap_video_effect(module: ModuleType) -> EffectFn:
    """Adapt a whole-video effect module (module.process) to the executor API."""

    def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
        """Run whole-video effect."""
        import hashlib

        from ..effects.meta import ExecutionContext

        # Everything except the executor's bookkeeping keys is an effect param.
        params = {k: v for k, v in config.items()
                  if k not in ("effect", "hash", "_binding")}

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Derive a stable 32-bit seed from the input path. The previous
        # hash(str(input_path)) was randomized per process (PYTHONHASHSEED),
        # making renders non-reproducible; SHA-3 matches the project's
        # hashing choice and is deterministic across runs.
        digest = hashlib.sha3_256(str(input_path).encode("utf-8")).digest()
        seed = int.from_bytes(digest[:4], "big")

        ctx = ExecutionContext(
            input_paths=[str(input_path)],
            output_path=str(output_path),
            params=params,
            seed=seed,
        )

        module.process([input_path], output_path, params, ctx)
        return output_path

    return wrapped_effect
|
||||
|
||||
|
||||
# Effect registry - maps effect names to implementations
_EFFECTS: Dict[str, EffectFn] = {}


def register_effect(name: str) -> Callable[[F], F]:
    """Decorator that registers an effect implementation under *name*."""
    def decorator(func: F) -> F:
        # Last registration wins, matching plain dict assignment semantics.
        _EFFECTS[name] = func  # type: ignore[assignment]
        return func
    return decorator


def get_effect(name: str) -> Optional[EffectFn]:
    """Look up a registered effect implementation; None when unregistered."""
    return _EFFECTS.get(name)
|
||||
|
||||
|
||||
# Built-in effects
|
||||
|
||||
@register_effect("identity")
|
||||
def effect_identity(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
|
||||
"""
|
||||
Identity effect - returns input unchanged.
|
||||
|
||||
This is the foundational effect: identity(x) = x
|
||||
"""
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Remove existing output if any
|
||||
if output_path.exists() or output_path.is_symlink():
|
||||
output_path.unlink()
|
||||
|
||||
# Preserve extension from input
|
||||
actual_output = output_path.with_suffix(input_path.suffix)
|
||||
if actual_output.exists() or actual_output.is_symlink():
|
||||
actual_output.unlink()
|
||||
|
||||
# Symlink to input (zero-copy identity)
|
||||
os.symlink(input_path.resolve(), actual_output)
|
||||
logger.debug(f"EFFECT identity: {input_path.name} -> {actual_output}")
|
||||
|
||||
return actual_output
|
||||
|
||||
|
||||
def _get_sexp_effect(effect_path: str, recipe_dir: Path = None) -> Optional[EffectFn]:
    """
    Load a sexp effect from a .sexp file.

    Args:
        effect_path: Relative path to the .sexp effect file
        recipe_dir: Base directory for resolving paths

    Returns:
        Effect function or None if not a sexp effect
    """
    # Only .sexp definitions are handled here; anything else belongs to the
    # other effect loaders.
    if not (effect_path and effect_path.endswith(".sexp")):
        return None

    try:
        from ..sexp.effect_loader import SexpEffectLoader
    except ImportError:
        logger.warning("Sexp effect loader not available")
        return None

    base_dir = recipe_dir or Path.cwd()
    try:
        return SexpEffectLoader(base_dir).load_effect_path(effect_path)
    except Exception as e:
        logger.error(f"Failed to load sexp effect from {effect_path}: {e}")
        return None
|
||||
|
||||
|
||||
def _get_python_primitive_effect(effect_name: str) -> Optional[EffectFn]:
    """
    Get a Python primitive frame processor effect.

    Checks if the effect has a python_primitive in FFmpegCompiler.EFFECT_MAPPINGS
    and wraps it for the executor API.
    """
    try:
        from ..sexp.ffmpeg_compiler import FFmpegCompiler
        from ..sexp.primitives import get_primitive
        from ..effects.frame_processor import process_video
    except ImportError:
        # Sexp / frame-processing support is optional.
        return None

    primitive_name = FFmpegCompiler().has_python_primitive(effect_name)
    if not primitive_name:
        return None

    primitive_fn = get_primitive(primitive_name)
    if not primitive_fn:
        logger.warning(f"Python primitive '{primitive_name}' not found for effect '{effect_name}'")
        return None

    def wrapped_effect(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
        """Run Python primitive effect via frame processor."""
        # Everything except the executor's bookkeeping keys is an effect param.
        internal_keys = ("effect", "cid", "hash", "effect_path", "_binding")
        params = {key: val for key, val in config.items() if key not in internal_keys}

        # Collect pre-resolved per-frame bindings, if any were attached.
        bindings = {
            key: val["_resolved_values"]
            for key, val in config.items()
            if isinstance(val, dict) and val.get("_resolved_values")
        }

        output_path.parent.mkdir(parents=True, exist_ok=True)
        # Frame pipeline always emits an mp4 container.
        actual_output = output_path.with_suffix(".mp4")

        def process_frame(frame, frame_params, state):
            # Adapt the primitive's (frame, **params) signature to the
            # frame processor's (frame, params, state) protocol.
            return primitive_fn(frame, **frame_params), state

        process_video(
            input_path=input_path,
            output_path=actual_output,
            process_frame=process_frame,
            params=params,
            bindings=bindings,
        )

        logger.info(f"Processed effect '{effect_name}' via Python primitive '{primitive_name}'")
        return actual_output

    return wrapped_effect
|
||||
|
||||
|
||||
@register_executor("EFFECT")
|
||||
class EffectExecutor(Executor):
|
||||
"""
|
||||
Apply an effect from the registry or IPFS.
|
||||
|
||||
Config:
|
||||
effect: Name of the effect to apply
|
||||
cid: IPFS CID for the effect (fetched from IPFS if not cached)
|
||||
hash: Legacy alias for cid (backwards compatibility)
|
||||
params: Optional parameters for the effect
|
||||
|
||||
Inputs:
|
||||
Single input file to transform
|
||||
"""
|
||||
|
||||
def execute(
|
||||
self,
|
||||
config: Dict[str, Any],
|
||||
inputs: List[Path],
|
||||
output_path: Path,
|
||||
) -> Path:
|
||||
effect_name = config.get("effect")
|
||||
# Support both "cid" (new) and "hash" (legacy)
|
||||
effect_cid = config.get("cid") or config.get("hash")
|
||||
|
||||
if not effect_name:
|
||||
raise ValueError("EFFECT requires 'effect' config")
|
||||
|
||||
if len(inputs) != 1:
|
||||
raise ValueError(f"EFFECT expects 1 input, got {len(inputs)}")
|
||||
|
||||
# Try IPFS effect first if CID provided
|
||||
effect_fn: Optional[EffectFn] = None
|
||||
if effect_cid:
|
||||
effect_fn = _load_cached_effect(effect_cid)
|
||||
if effect_fn:
|
||||
logger.info(f"Running effect '{effect_name}' (cid={effect_cid[:16]}...)")
|
||||
|
||||
# Try sexp effect from effect_path (.sexp file)
|
||||
if effect_fn is None:
|
||||
effect_path = config.get("effect_path")
|
||||
if effect_path and effect_path.endswith(".sexp"):
|
||||
effect_fn = _get_sexp_effect(effect_path)
|
||||
if effect_fn:
|
||||
logger.info(f"Running effect '{effect_name}' via sexp definition")
|
||||
|
||||
# Try Python primitive (from FFmpegCompiler.EFFECT_MAPPINGS)
|
||||
if effect_fn is None:
|
||||
effect_fn = _get_python_primitive_effect(effect_name)
|
||||
if effect_fn:
|
||||
logger.info(f"Running effect '{effect_name}' via Python primitive")
|
||||
|
||||
# Fall back to built-in effect
|
||||
if effect_fn is None:
|
||||
effect_fn = get_effect(effect_name)
|
||||
|
||||
if effect_fn is None:
|
||||
raise ValueError(f"Unknown effect: {effect_name}")
|
||||
|
||||
# Pass full config (effect can extract what it needs)
|
||||
return effect_fn(inputs[0], output_path, config)
|
||||
|
||||
def validate_config(self, config: Dict[str, Any]) -> List[str]:
|
||||
errors = []
|
||||
if "effect" not in config:
|
||||
errors.append("EFFECT requires 'effect' config")
|
||||
else:
|
||||
# If CID provided, we'll load from IPFS - skip built-in check
|
||||
has_cid = config.get("cid") or config.get("hash")
|
||||
if not has_cid and get_effect(config["effect"]) is None:
|
||||
errors.append(f"Unknown effect: {config['effect']}")
|
||||
return errors
|
||||
50
core/artdag/nodes/encoding.py
Normal file
50
core/artdag/nodes/encoding.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# artdag/nodes/encoding.py
|
||||
"""
|
||||
Web-optimized video encoding settings.
|
||||
|
||||
Provides common FFmpeg arguments for producing videos that:
|
||||
- Stream efficiently (faststart)
|
||||
- Play on all browsers (H.264 High profile)
|
||||
- Support seeking (regular keyframes)
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
# Standard web-optimized video encoding arguments
WEB_VIDEO_ARGS: List[str] = [
    "-c:v", "libx264",
    "-preset", "fast",
    "-crf", "18",
    "-profile:v", "high",
    "-level", "4.1",
    "-pix_fmt", "yuv420p",  # Ensure broad compatibility
    "-movflags", "+faststart",  # Enable streaming before full download
    "-g", "48",  # Keyframe every ~2 seconds at 24fps (for seeking)
]

# Standard audio encoding arguments
WEB_AUDIO_ARGS: List[str] = [
    "-c:a", "aac",
    "-b:a", "192k",
]


def get_web_encoding_args() -> List[str]:
    """Get FFmpeg args for web-optimized video+audio encoding."""
    # Fresh list each call so callers may extend it safely.
    return [*WEB_VIDEO_ARGS, *WEB_AUDIO_ARGS]


def get_web_video_args() -> List[str]:
    """Get FFmpeg args for web-optimized video encoding only."""
    return list(WEB_VIDEO_ARGS)


def get_web_audio_args() -> List[str]:
    """Get FFmpeg args for web-optimized audio encoding only."""
    return list(WEB_AUDIO_ARGS)


# For shell commands (string format)
WEB_VIDEO_ARGS_STR = " ".join(WEB_VIDEO_ARGS)
WEB_AUDIO_ARGS_STR = " ".join(WEB_AUDIO_ARGS)
WEB_ENCODING_ARGS_STR = f"{WEB_VIDEO_ARGS_STR} {WEB_AUDIO_ARGS_STR}"
||||
62
core/artdag/nodes/source.py
Normal file
62
core/artdag/nodes/source.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# primitive/nodes/source.py
|
||||
"""
|
||||
Source executors: Load media from paths.
|
||||
|
||||
Primitives: SOURCE
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ..dag import NodeType
|
||||
from ..executor import Executor, register_executor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_executor(NodeType.SOURCE)
class SourceExecutor(Executor):
    """
    Load source media from a path.

    Config:
        path: Path to source file

    Creates a symlink to the source file for zero-copy loading.
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        """Symlink the configured source file into the cache; returns the link path."""
        source_path = Path(config["path"])

        if not source_path.exists():
            raise FileNotFoundError(f"Source file not found: {source_path}")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # The link keeps the source's extension so downstream tools can sniff
        # the container format; clear stale entries (including dangling
        # symlinks, for which exists() is False) at both candidate locations.
        actual_output = output_path.with_suffix(source_path.suffix)
        for stale in (output_path, actual_output):
            if stale.exists() or stale.is_symlink():
                stale.unlink()

        # Use symlink for zero-copy
        os.symlink(source_path.resolve(), actual_output)
        logger.debug(f"SOURCE: {source_path.name} -> {actual_output}")

        return actual_output

    def validate_config(self, config: Dict[str, Any]) -> List[str]:
        """Static config validation: 'path' is mandatory."""
        return [] if "path" in config else ["SOURCE requires 'path' config"]
|
||||
224
core/artdag/nodes/transform.py
Normal file
224
core/artdag/nodes/transform.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# primitive/nodes/transform.py
|
||||
"""
|
||||
Transform executors: Modify single media inputs.
|
||||
|
||||
Primitives: SEGMENT, RESIZE, TRANSFORM
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from ..dag import NodeType
|
||||
from ..executor import Executor, register_executor
|
||||
from .encoding import get_web_encoding_args, get_web_video_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_executor(NodeType.SEGMENT)
class SegmentExecutor(Executor):
    """
    Extract a time segment from media.

    Config:
        offset: Start time in seconds (default: 0)
        duration: Duration in seconds (optional, default: to end)
        precise: Use frame-accurate seeking (default: True)
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        if len(inputs) != 1:
            raise ValueError("SEGMENT requires exactly one input")

        src = inputs[0]
        offset = config.get("offset", 0)
        duration = config.get("duration")
        precise = config.get("precise", True)

        output_path.parent.mkdir(parents=True, exist_ok=True)

        seek_args = ["-ss", str(offset)] if offset > 0 else []
        duration_args = ["-t", str(duration)] if duration else []

        if precise:
            # Output-side seek: decodes from the start, frame-accurate but
            # slower; re-encodes with the shared web settings.
            cmd = [
                "ffmpeg", "-y", "-i", str(src),
                *seek_args, *duration_args,
                *get_web_encoding_args(), str(output_path),
            ]
        else:
            # Input-side seek: jumps to the nearest keyframe and stream-copies;
            # fast but the cut point may be slightly off.
            cmd = [
                "ffmpeg", "-y", *seek_args, "-i", str(src),
                *duration_args, "-c", "copy", str(output_path),
            ]

        logger.debug(f"SEGMENT: offset={offset}, duration={duration}, precise={precise}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Segment failed: {result.stderr}")

        return output_path
|
||||
|
||||
|
||||
@register_executor(NodeType.RESIZE)
class ResizeExecutor(Executor):
    """
    Resize media to target dimensions.

    Config:
        width: Target width
        height: Target height
        mode: "fit" (letterbox), "fill" (crop), "stretch", "pad"
        background: Background color for pad mode (default: black)
    """

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        if len(inputs) != 1:
            raise ValueError("RESIZE requires exactly one input")

        src = inputs[0]
        width = config["width"]
        height = config["height"]
        mode = config.get("mode", "fit")
        background = config.get("background", "black")

        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Map each resize mode to its ffmpeg filter graph.
        filters = {
            # Scale to fit, add letterboxing
            "fit": f"scale={width}:{height}:force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}",
            # Scale to fill, crop excess
            "fill": f"scale={width}:{height}:force_original_aspect_ratio=increase,crop={width}:{height}",
            # Stretch to exact size (aspect ratio not preserved)
            "stretch": f"scale={width}:{height}",
            # Scale down only if larger, then pad to the target frame
            "pad": f"scale='min({width},iw)':'min({height},ih)':force_original_aspect_ratio=decrease,pad={width}:{height}:(ow-iw)/2:(oh-ih)/2:color={background}",
        }
        if mode not in filters:
            raise ValueError(f"Unknown resize mode: {mode}")

        cmd = [
            "ffmpeg", "-y",
            "-i", str(src),
            "-vf", filters[mode],
            *get_web_video_args(),
            "-c:a", "copy",  # audio is untouched by a resize
            str(output_path),
        ]

        logger.debug(f"RESIZE: {width}x{height} ({mode}) (web-optimized)")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Resize failed: {result.stderr}")

        return output_path
|
||||
|
||||
|
||||
@register_executor(NodeType.TRANSFORM)
class TransformExecutor(Executor):
    """
    Apply visual effects to media.

    Config:
        effects: Dict of effect -> value
            saturation: 0.0-2.0 (1.0 = normal)
            contrast: 0.0-2.0 (1.0 = normal)
            brightness: -1.0 to 1.0 (0.0 = normal)
            gamma: 0.1-10.0 (1.0 = normal)
            hue: degrees shift
            blur: blur radius
            sharpen: sharpen amount
            speed: playback speed multiplier (must be > 0)
    """

    @staticmethod
    def _atempo_chain(speed: float) -> List[str]:
        """
        Decompose a tempo factor into a chain of atempo filter stages.

        ffmpeg's atempo filter only accepts factors in [0.5, 2.0]; larger or
        smaller factors must be expressed as a product of in-range stages
        (e.g. 4x -> atempo=2.0,atempo=2.0).
        """
        stages: List[str] = []
        remaining = float(speed)
        while remaining > 2.0:
            stages.append("atempo=2.0")
            remaining /= 2.0
        while remaining < 0.5:
            stages.append("atempo=0.5")
            remaining *= 2.0
        stages.append(f"atempo={remaining}")
        return stages

    def execute(
        self,
        config: Dict[str, Any],
        inputs: List[Path],
        output_path: Path,
    ) -> Path:
        if len(inputs) != 1:
            raise ValueError("TRANSFORM requires exactly one input")

        input_path = inputs[0]
        effects = config.get("effects", {})

        output_path.parent.mkdir(parents=True, exist_ok=True)

        if not effects:
            # No effects - just copy the input through unchanged
            import shutil
            shutil.copy2(input_path, output_path)
            return output_path

        # Build filter chains (video and audio separately)
        vf_parts = []
        af_parts = []

        # Color adjustments are combined into a single eq filter
        eq_parts = [
            f"{name}={effects[name]}"
            for name in ("saturation", "contrast", "brightness", "gamma")
            if name in effects
        ]
        if eq_parts:
            vf_parts.append(f"eq={':'.join(eq_parts)}")

        # Hue adjustment
        if "hue" in effects:
            vf_parts.append(f"hue=h={effects['hue']}")

        # Blur
        if "blur" in effects:
            vf_parts.append(f"boxblur={effects['blur']}")

        # Sharpen
        if "sharpen" in effects:
            vf_parts.append(f"unsharp=5:5:{effects['sharpen']}:5:5:0")

        # Speed change (affects both video timestamps and audio tempo)
        if "speed" in effects:
            speed = effects["speed"]
            if speed <= 0:
                # Previously speed=0 crashed with ZeroDivisionError below
                raise ValueError(f"TRANSFORM speed must be positive, got {speed}")
            vf_parts.append(f"setpts={1/speed}*PTS")
            # BUG FIX: atempo only accepts 0.5-2.0, so factors outside that
            # range are chained instead of passed through raw (which made
            # ffmpeg reject the command for e.g. speed=3).
            af_parts.extend(self._atempo_chain(speed))

        cmd = ["ffmpeg", "-y", "-i", str(input_path)]
        if vf_parts:
            cmd.extend(["-vf", ",".join(vf_parts)])
        if af_parts:
            cmd.extend(["-af", ",".join(af_parts)])
        cmd.extend([*get_web_encoding_args(), str(output_path)])

        logger.debug(f"TRANSFORM: {list(effects.keys())} (web-optimized)")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Transform failed: {result.stderr}")

        return output_path
|
||||
29
core/artdag/planning/__init__.py
Normal file
29
core/artdag/planning/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# artdag/planning - Execution plan generation
|
||||
#
|
||||
# Provides the Planning phase of the 3-phase execution model:
|
||||
# 1. ANALYZE - Extract features from inputs
|
||||
# 2. PLAN - Generate execution plan with cache IDs
|
||||
# 3. EXECUTE - Run steps with caching
|
||||
|
||||
from .schema import (
|
||||
ExecutionStep,
|
||||
ExecutionPlan,
|
||||
StepStatus,
|
||||
StepOutput,
|
||||
StepInput,
|
||||
PlanInput,
|
||||
)
|
||||
from .planner import RecipePlanner, Recipe
|
||||
from .tree_reduction import TreeReducer
|
||||
|
||||
__all__ = [
|
||||
"ExecutionStep",
|
||||
"ExecutionPlan",
|
||||
"StepStatus",
|
||||
"StepOutput",
|
||||
"StepInput",
|
||||
"PlanInput",
|
||||
"RecipePlanner",
|
||||
"Recipe",
|
||||
"TreeReducer",
|
||||
]
|
||||
756
core/artdag/planning/planner.py
Normal file
756
core/artdag/planning/planner.py
Normal file
@@ -0,0 +1,756 @@
|
||||
# artdag/planning/planner.py
|
||||
"""
|
||||
Recipe planner - converts recipes into execution plans.
|
||||
|
||||
The planner is the second phase of the 3-phase execution model.
|
||||
It takes a recipe and analysis results and generates a complete
|
||||
execution plan with pre-computed cache IDs.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from .schema import ExecutionPlan, ExecutionStep, StepOutput, StepInput, PlanInput
|
||||
from .tree_reduction import TreeReducer, reduce_sequence
|
||||
from ..analysis import AnalysisResult
|
||||
|
||||
|
||||
def _infer_media_type(node_type: str, config: Dict[str, Any] = None) -> str:
|
||||
"""Infer media type from node type and config."""
|
||||
config = config or {}
|
||||
|
||||
# Audio operations
|
||||
if node_type in ("AUDIO", "MIX_AUDIO", "EXTRACT_AUDIO"):
|
||||
return "audio/wav"
|
||||
if "audio" in node_type.lower():
|
||||
return "audio/wav"
|
||||
|
||||
# Image operations
|
||||
if node_type in ("FRAME", "THUMBNAIL", "IMAGE"):
|
||||
return "image/png"
|
||||
|
||||
# Default to video
|
||||
return "video/mp4"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""Create stable hash from arbitrary data."""
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class RecipeNode:
    """A node in the recipe DAG."""
    # Unique node identifier within the recipe
    id: str
    # Node type name, e.g. "SOURCE", "SEQUENCE", "MAP"
    type: str
    # Type-specific configuration dictionary
    config: Dict[str, Any]
    # IDs of upstream nodes (or external input names) this node consumes
    inputs: List[str]
|
||||
|
||||
|
||||
@dataclass
class Recipe:
    """Parsed recipe structure."""
    name: str
    version: str
    description: str
    nodes: List[RecipeNode]
    output: str
    registry: Dict[str, Any]
    owner: str
    raw_yaml: str

    @property
    def recipe_hash(self) -> str:
        """Compute hash of recipe content."""
        return _stable_hash({"yaml": self.raw_yaml})

    @staticmethod
    def _normalize_inputs(raw: Any) -> List[str]:
        """Coerce an inputs spec (str | list | dict of str/list) into a flat list."""
        if isinstance(raw, str):
            return [raw]
        if isinstance(raw, dict):
            # Flatten dict values: strings become single entries, lists extend
            flat: List[str] = []
            for value in raw.values():
                if isinstance(value, str):
                    flat.append(value)
                elif isinstance(value, list):
                    flat.extend(value)
            return flat
        return raw

    @classmethod
    def from_yaml(cls, yaml_content: str) -> "Recipe":
        """Parse recipe from YAML string."""
        data = yaml.safe_load(yaml_content)
        dag = data.get("dag", {})

        nodes = [
            RecipeNode(
                id=entry["id"],
                type=entry["type"],
                config=entry.get("config", {}),
                inputs=cls._normalize_inputs(entry.get("inputs", [])),
            )
            for entry in dag.get("nodes", [])
        ]

        return cls(
            name=data.get("name", "unnamed"),
            version=data.get("version", "1.0"),
            description=data.get("description", ""),
            nodes=nodes,
            output=dag.get("output", ""),
            registry=data.get("registry", {}),
            owner=data.get("owner", ""),
            raw_yaml=yaml_content,
        )

    @classmethod
    def from_file(cls, path: Path) -> "Recipe":
        """Load recipe from YAML file."""
        with open(path, "r") as handle:
            return cls.from_yaml(handle.read())
|
||||
|
||||
|
||||
class RecipePlanner:
    """
    Generates execution plans from recipes.

    The planner:
    1. Parses the recipe
    2. Resolves fixed inputs from registry
    3. Maps variable inputs to provided hashes
    4. Expands MAP/iteration nodes
    5. Applies tree reduction for SEQUENCE nodes
    6. Computes cache IDs for all steps
    """

    def __init__(self, use_tree_reduction: bool = True):
        """
        Initialize the planner.

        Args:
            use_tree_reduction: Whether to use tree reduction for SEQUENCE
                nodes with more than two inputs (pairwise combination tree
                instead of a single N-way concat step).
        """
        self.use_tree_reduction = use_tree_reduction
|
||||
|
||||
    def plan(
        self,
        recipe: Recipe,
        input_hashes: Dict[str, str],
        analysis: Optional[Dict[str, AnalysisResult]] = None,
        seed: Optional[int] = None,
    ) -> ExecutionPlan:
        """
        Generate an execution plan from a recipe.

        Args:
            recipe: The parsed recipe
            input_hashes: Mapping from input name to content hash
            analysis: Analysis results for inputs (keyed by hash)
            seed: Random seed for deterministic planning

        Returns:
            ExecutionPlan with pre-computed cache IDs

        Raises:
            ValueError: If the recipe's declared output node produced no step.
        """
        logger.info(f"Planning recipe: {recipe.name}")

        # Build node lookup
        nodes_by_id = {n.id: n for n in recipe.nodes}

        # Topologically sort nodes so upstream steps exist before dependents
        sorted_ids = self._topological_sort(recipe.nodes)

        # Resolve registry references (fixed assets/effects pinned by hash)
        registry_hashes = self._resolve_registry(recipe.registry)

        # Build PlanInput objects from input_hashes
        plan_inputs = []
        for name, cid in input_hashes.items():
            # Try to find matching SOURCE node for media type; fall back to
            # a generic binary type when the input has no SOURCE node.
            media_type = "application/octet-stream"
            for node in recipe.nodes:
                if node.id == name and node.type == "SOURCE":
                    media_type = _infer_media_type("SOURCE", node.config)
                    break

            plan_inputs.append(PlanInput(
                name=name,
                cache_id=cid,
                cid=cid,
                media_type=media_type,
            ))

        # Generate steps
        steps = []
        step_id_map = {}  # Maps recipe node ID to step ID(s)
        step_name_map = {}  # Maps recipe node ID to human-readable name
        analysis_cache_ids = {}

        for node_id in sorted_ids:
            node = nodes_by_id[node_id]
            logger.debug(f"Processing node: {node.id} ({node.type})")

            # A single recipe node may expand to several steps (MAP,
            # SOURCE_LIST, tree-reduced SEQUENCE); output_step_id is the
            # one downstream nodes should reference.
            new_steps, output_step_id = self._process_node(
                node=node,
                step_id_map=step_id_map,
                step_name_map=step_name_map,
                input_hashes=input_hashes,
                registry_hashes=registry_hashes,
                analysis=analysis or {},
                recipe_name=recipe.name,
            )

            steps.extend(new_steps)
            step_id_map[node_id] = output_step_id
            # Track human-readable name for this node
            if new_steps:
                step_name_map[node_id] = new_steps[-1].name

        # Find output step
        output_step = step_id_map.get(recipe.output)
        if not output_step:
            raise ValueError(f"Output node '{recipe.output}' not found")

        # Determine output name (prefer the output step's own declared name)
        output_name = f"{recipe.name}.output"
        output_step_obj = next((s for s in steps if s.step_id == output_step), None)
        if output_step_obj and output_step_obj.outputs:
            output_name = output_step_obj.outputs[0].name

        # Build analysis cache IDs (only entries that actually have one)
        if analysis:
            analysis_cache_ids = {
                h: a.cache_id for h, a in analysis.items()
                if a.cache_id
            }

        # Create plan
        plan = ExecutionPlan(
            plan_id=None,  # Computed in __post_init__
            name=f"{recipe.name}_plan",
            recipe_id=recipe.name,
            recipe_name=recipe.name,
            recipe_hash=recipe.recipe_hash,
            seed=seed,
            inputs=plan_inputs,
            steps=steps,
            output_step=output_step,
            output_name=output_name,
            analysis_cache_ids=analysis_cache_ids,
            input_hashes=input_hashes,
            metadata={
                "recipe_version": recipe.version,
                "recipe_description": recipe.description,
                "owner": recipe.owner,
            },
        )

        # Compute all cache IDs and then generate outputs.
        # NOTE: ordering matters — outputs are named from cache IDs, and the
        # plan_id must be recomputed once outputs exist.
        plan.compute_all_cache_ids()
        plan.compute_levels()

        # Now add outputs to each step (needs cache_id to be computed first)
        self._add_step_outputs(plan, recipe.name)

        # Recompute plan_id after outputs are added
        plan.plan_id = plan._compute_plan_id()

        logger.info(f"Generated plan with {len(steps)} steps")
        return plan
|
||||
|
||||
def _add_step_outputs(self, plan: ExecutionPlan, recipe_name: str) -> None:
|
||||
"""Add output definitions to each step after cache_ids are computed."""
|
||||
for step in plan.steps:
|
||||
if step.outputs:
|
||||
continue # Already has outputs
|
||||
|
||||
# Generate output name from step name
|
||||
base_name = step.name or step.step_id
|
||||
output_name = f"{recipe_name}.{base_name}.out"
|
||||
|
||||
media_type = _infer_media_type(step.node_type, step.config)
|
||||
|
||||
step.add_output(
|
||||
name=output_name,
|
||||
media_type=media_type,
|
||||
index=0,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
def plan_from_yaml(
|
||||
self,
|
||||
yaml_content: str,
|
||||
input_hashes: Dict[str, str],
|
||||
analysis: Optional[Dict[str, AnalysisResult]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""
|
||||
Generate plan from YAML string.
|
||||
|
||||
Args:
|
||||
yaml_content: Recipe YAML content
|
||||
input_hashes: Mapping from input name to content hash
|
||||
analysis: Analysis results
|
||||
|
||||
Returns:
|
||||
ExecutionPlan
|
||||
"""
|
||||
recipe = Recipe.from_yaml(yaml_content)
|
||||
return self.plan(recipe, input_hashes, analysis)
|
||||
|
||||
def plan_from_file(
|
||||
self,
|
||||
recipe_path: Path,
|
||||
input_hashes: Dict[str, str],
|
||||
analysis: Optional[Dict[str, AnalysisResult]] = None,
|
||||
) -> ExecutionPlan:
|
||||
"""
|
||||
Generate plan from recipe file.
|
||||
|
||||
Args:
|
||||
recipe_path: Path to recipe YAML file
|
||||
input_hashes: Mapping from input name to content hash
|
||||
analysis: Analysis results
|
||||
|
||||
Returns:
|
||||
ExecutionPlan
|
||||
"""
|
||||
recipe = Recipe.from_file(recipe_path)
|
||||
return self.plan(recipe, input_hashes, analysis)
|
||||
|
||||
def _topological_sort(self, nodes: List[RecipeNode]) -> List[str]:
|
||||
"""Topologically sort recipe nodes."""
|
||||
nodes_by_id = {n.id: n for n in nodes}
|
||||
visited = set()
|
||||
order = []
|
||||
|
||||
def visit(node_id: str):
|
||||
if node_id in visited:
|
||||
return
|
||||
if node_id not in nodes_by_id:
|
||||
return # External input
|
||||
visited.add(node_id)
|
||||
node = nodes_by_id[node_id]
|
||||
for input_id in node.inputs:
|
||||
visit(input_id)
|
||||
order.append(node_id)
|
||||
|
||||
for node in nodes:
|
||||
visit(node.id)
|
||||
|
||||
return order
|
||||
|
||||
def _resolve_registry(self, registry: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""
|
||||
Resolve registry references to content hashes.
|
||||
|
||||
Args:
|
||||
registry: Registry section from recipe
|
||||
|
||||
Returns:
|
||||
Mapping from name to content hash
|
||||
"""
|
||||
hashes = {}
|
||||
|
||||
# Assets
|
||||
for name, asset_data in registry.get("assets", {}).items():
|
||||
if isinstance(asset_data, dict) and "hash" in asset_data:
|
||||
hashes[name] = asset_data["hash"]
|
||||
elif isinstance(asset_data, str):
|
||||
hashes[name] = asset_data
|
||||
|
||||
# Effects
|
||||
for name, effect_data in registry.get("effects", {}).items():
|
||||
if isinstance(effect_data, dict) and "hash" in effect_data:
|
||||
hashes[f"effect:{name}"] = effect_data["hash"]
|
||||
elif isinstance(effect_data, str):
|
||||
hashes[f"effect:{name}"] = effect_data
|
||||
|
||||
return hashes
|
||||
|
||||
def _process_node(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
step_name_map: Dict[str, str],
|
||||
input_hashes: Dict[str, str],
|
||||
registry_hashes: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process a recipe node into execution steps.
|
||||
|
||||
Args:
|
||||
node: Recipe node to process
|
||||
step_id_map: Mapping from processed node IDs to step IDs
|
||||
step_name_map: Mapping from node IDs to human-readable names
|
||||
input_hashes: User-provided input hashes
|
||||
registry_hashes: Registry-resolved hashes
|
||||
analysis: Analysis results
|
||||
recipe_name: Name of the recipe (for generating readable names)
|
||||
|
||||
Returns:
|
||||
Tuple of (new steps, output step ID)
|
||||
"""
|
||||
# SOURCE nodes
|
||||
if node.type == "SOURCE":
|
||||
return self._process_source(node, input_hashes, registry_hashes, recipe_name)
|
||||
|
||||
# SOURCE_LIST nodes
|
||||
if node.type == "SOURCE_LIST":
|
||||
return self._process_source_list(node, input_hashes, recipe_name)
|
||||
|
||||
# ANALYZE nodes
|
||||
if node.type == "ANALYZE":
|
||||
return self._process_analyze(node, step_id_map, analysis, recipe_name)
|
||||
|
||||
# MAP nodes
|
||||
if node.type == "MAP":
|
||||
return self._process_map(node, step_id_map, input_hashes, analysis, recipe_name)
|
||||
|
||||
# SEQUENCE nodes (may use tree reduction)
|
||||
if node.type == "SEQUENCE":
|
||||
return self._process_sequence(node, step_id_map, recipe_name)
|
||||
|
||||
# SEGMENT_AT nodes
|
||||
if node.type == "SEGMENT_AT":
|
||||
return self._process_segment_at(node, step_id_map, analysis, recipe_name)
|
||||
|
||||
# Standard nodes (SEGMENT, RESIZE, TRANSFORM, LAYER, MUX, BLEND, etc.)
|
||||
return self._process_standard(node, step_id_map, recipe_name)
|
||||
|
||||
def _process_source(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
input_hashes: Dict[str, str],
|
||||
registry_hashes: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""Process SOURCE node."""
|
||||
config = dict(node.config)
|
||||
|
||||
# Variable input?
|
||||
if config.get("input"):
|
||||
# Look up in user-provided inputs
|
||||
if node.id not in input_hashes:
|
||||
raise ValueError(f"Missing input for SOURCE node '{node.id}'")
|
||||
cid = input_hashes[node.id]
|
||||
# Fixed asset from registry?
|
||||
elif config.get("asset"):
|
||||
asset_name = config["asset"]
|
||||
if asset_name not in registry_hashes:
|
||||
raise ValueError(f"Asset '{asset_name}' not found in registry")
|
||||
cid = registry_hashes[asset_name]
|
||||
else:
|
||||
raise ValueError(f"SOURCE node '{node.id}' has no input or asset")
|
||||
|
||||
# Human-readable name
|
||||
display_name = config.get("name", node.id)
|
||||
step_name = f"{recipe_name}.inputs.{display_name}" if recipe_name else display_name
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="SOURCE",
|
||||
config={"input_ref": node.id, "cid": cid},
|
||||
input_steps=[],
|
||||
cache_id=cid, # SOURCE cache_id is just the content hash
|
||||
name=step_name,
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
|
||||
def _process_source_list(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
input_hashes: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process SOURCE_LIST node.
|
||||
|
||||
Creates individual SOURCE steps for each item in the list.
|
||||
"""
|
||||
# Look for list input
|
||||
if node.id not in input_hashes:
|
||||
raise ValueError(f"Missing input for SOURCE_LIST node '{node.id}'")
|
||||
|
||||
input_value = input_hashes[node.id]
|
||||
|
||||
# Parse as comma-separated list if string
|
||||
if isinstance(input_value, str):
|
||||
items = [h.strip() for h in input_value.split(",")]
|
||||
else:
|
||||
items = list(input_value)
|
||||
|
||||
display_name = node.config.get("name", node.id)
|
||||
base_name = f"{recipe_name}.{display_name}" if recipe_name else display_name
|
||||
|
||||
steps = []
|
||||
for i, cid in enumerate(items):
|
||||
step = ExecutionStep(
|
||||
step_id=f"{node.id}_{i}",
|
||||
node_type="SOURCE",
|
||||
config={"input_ref": f"{node.id}[{i}]", "cid": cid},
|
||||
input_steps=[],
|
||||
cache_id=cid,
|
||||
name=f"{base_name}[{i}]",
|
||||
)
|
||||
steps.append(step)
|
||||
|
||||
# Return list marker as output
|
||||
list_step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="_LIST",
|
||||
config={"items": [s.step_id for s in steps]},
|
||||
input_steps=[s.step_id for s in steps],
|
||||
name=f"{base_name}.list",
|
||||
)
|
||||
steps.append(list_step)
|
||||
|
||||
return steps, list_step.step_id
|
||||
|
||||
def _process_analyze(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process ANALYZE node.
|
||||
|
||||
ANALYZE nodes reference pre-computed analysis results.
|
||||
"""
|
||||
input_step = step_id_map.get(node.inputs[0]) if node.inputs else None
|
||||
if not input_step:
|
||||
raise ValueError(f"ANALYZE node '{node.id}' has no input")
|
||||
|
||||
feature = node.config.get("feature", "all")
|
||||
step_name = f"{recipe_name}.analysis.{feature}" if recipe_name else f"analysis.{feature}"
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="ANALYZE",
|
||||
config={
|
||||
"feature": feature,
|
||||
**node.config,
|
||||
},
|
||||
input_steps=[input_step],
|
||||
name=step_name,
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
|
||||
def _process_map(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
input_hashes: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process MAP node - expand iteration over list.
|
||||
|
||||
MAP applies an operation to each item in a list.
|
||||
"""
|
||||
operation = node.config.get("operation", "TRANSFORM")
|
||||
base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id
|
||||
|
||||
# Get items input
|
||||
items_ref = node.config.get("items") or (
|
||||
node.inputs[0] if isinstance(node.inputs, list) else
|
||||
node.inputs.get("items") if isinstance(node.inputs, dict) else None
|
||||
)
|
||||
|
||||
if not items_ref:
|
||||
raise ValueError(f"MAP node '{node.id}' has no items input")
|
||||
|
||||
# Resolve items to list of step IDs
|
||||
if items_ref in step_id_map:
|
||||
# Reference to SOURCE_LIST output
|
||||
items_step = step_id_map[items_ref]
|
||||
# TODO: expand list items
|
||||
logger.warning(f"MAP node '{node.id}' references list step, expansion TBD")
|
||||
item_steps = [items_step]
|
||||
else:
|
||||
item_steps = [items_ref]
|
||||
|
||||
# Generate step for each item
|
||||
steps = []
|
||||
output_steps = []
|
||||
|
||||
for i, item_step in enumerate(item_steps):
|
||||
step_id = f"{node.id}_{i}"
|
||||
|
||||
if operation == "RANDOM_SLICE":
|
||||
step = ExecutionStep(
|
||||
step_id=step_id,
|
||||
node_type="SEGMENT",
|
||||
config={
|
||||
"random": True,
|
||||
"seed_from": node.config.get("seed_from"),
|
||||
"index": i,
|
||||
},
|
||||
input_steps=[item_step],
|
||||
name=f"{base_name}.slice[{i}]",
|
||||
)
|
||||
elif operation == "TRANSFORM":
|
||||
step = ExecutionStep(
|
||||
step_id=step_id,
|
||||
node_type="TRANSFORM",
|
||||
config=node.config.get("effects", {}),
|
||||
input_steps=[item_step],
|
||||
name=f"{base_name}.transform[{i}]",
|
||||
)
|
||||
elif operation == "ANALYZE":
|
||||
step = ExecutionStep(
|
||||
step_id=step_id,
|
||||
node_type="ANALYZE",
|
||||
config={"feature": node.config.get("feature", "all")},
|
||||
input_steps=[item_step],
|
||||
name=f"{base_name}.analyze[{i}]",
|
||||
)
|
||||
else:
|
||||
step = ExecutionStep(
|
||||
step_id=step_id,
|
||||
node_type=operation,
|
||||
config=node.config,
|
||||
input_steps=[item_step],
|
||||
name=f"{base_name}.{operation.lower()}[{i}]",
|
||||
)
|
||||
|
||||
steps.append(step)
|
||||
output_steps.append(step_id)
|
||||
|
||||
# Create list output
|
||||
list_step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="_LIST",
|
||||
config={"items": output_steps},
|
||||
input_steps=output_steps,
|
||||
name=f"{base_name}.results",
|
||||
)
|
||||
steps.append(list_step)
|
||||
|
||||
return steps, list_step.step_id
|
||||
|
||||
    def _process_sequence(
        self,
        node: RecipeNode,
        step_id_map: Dict[str, str],
        recipe_name: str = "",
    ) -> Tuple[List[ExecutionStep], str]:
        """
        Process SEQUENCE node (concatenate multiple inputs in order).

        Uses tree reduction for parallel composition if enabled: with more
        than two inputs, the N-way concat is rewritten as a tree of pairwise
        SEQUENCE steps so intermediates can be cached and run in parallel.

        Raises:
            ValueError: If the node has no inputs.
        """
        base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id

        # Resolve input steps (already-processed node IDs map to step IDs;
        # anything unknown is passed through as-is)
        input_steps = []
        for input_id in node.inputs:
            if input_id in step_id_map:
                input_steps.append(step_id_map[input_id])
            else:
                input_steps.append(input_id)

        if len(input_steps) == 0:
            raise ValueError(f"SEQUENCE node '{node.id}' has no inputs")

        if len(input_steps) == 1:
            # Single input, no sequence needed - forward the upstream step
            return [], input_steps[0]

        # Default transition is a hard cut
        transition_config = node.config.get("transition", {"type": "cut"})
        config = {"transition": transition_config}

        if self.use_tree_reduction and len(input_steps) > 2:
            # Use tree reduction: reduce_sequence returns per-step
            # (step_id, inputs, config) tuples plus the root step's ID
            reduction_steps, output_id = reduce_sequence(
                input_steps,
                transition_config=config,
                id_prefix=node.id,
            )

            steps = []
            for i, (step_id, inputs, step_config) in enumerate(reduction_steps):
                step = ExecutionStep(
                    step_id=step_id,
                    node_type="SEQUENCE",
                    config=step_config,
                    input_steps=inputs,
                    name=f"{base_name}.reduce[{i}]",
                )
                steps.append(step)

            return steps, output_id
        else:
            # Direct N-way sequence in a single step
            step = ExecutionStep(
                step_id=node.id,
                node_type="SEQUENCE",
                config=config,
                input_steps=input_steps,
                name=f"{base_name}.concat",
            )
            return [step], step.step_id
|
||||
|
||||
def _process_segment_at(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
analysis: Dict[str, AnalysisResult],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""
|
||||
Process SEGMENT_AT node - cut at specific times.
|
||||
|
||||
Creates SEGMENT steps for each time range.
|
||||
"""
|
||||
base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id
|
||||
times_from = node.config.get("times_from")
|
||||
distribute = node.config.get("distribute", "round_robin")
|
||||
|
||||
# TODO: Resolve times from analysis
|
||||
# For now, create a placeholder
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type="SEGMENT_AT",
|
||||
config=node.config,
|
||||
input_steps=[step_id_map.get(i, i) for i in node.inputs],
|
||||
name=f"{base_name}.segment",
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
|
||||
def _process_standard(
|
||||
self,
|
||||
node: RecipeNode,
|
||||
step_id_map: Dict[str, str],
|
||||
recipe_name: str = "",
|
||||
) -> Tuple[List[ExecutionStep], str]:
|
||||
"""Process standard transformation/composition node."""
|
||||
base_name = f"{recipe_name}.{node.id}" if recipe_name else node.id
|
||||
input_steps = [step_id_map.get(i, i) for i in node.inputs]
|
||||
|
||||
step = ExecutionStep(
|
||||
step_id=node.id,
|
||||
node_type=node.type,
|
||||
config=node.config,
|
||||
input_steps=input_steps,
|
||||
name=f"{base_name}.{node.type.lower()}",
|
||||
)
|
||||
|
||||
return [step], step.step_id
|
||||
594
core/artdag/planning/schema.py
Normal file
594
core/artdag/planning/schema.py
Normal file
@@ -0,0 +1,594 @@
|
||||
# artdag/planning/schema.py
|
||||
"""
|
||||
Data structures for execution plans.
|
||||
|
||||
An ExecutionPlan contains all steps needed to execute a recipe,
|
||||
with pre-computed cache IDs for each step.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
# Cluster key for trust domains
|
||||
# Systems with the same key produce the same cache_ids and can share work
|
||||
# Systems with different keys have isolated cache namespaces
|
||||
CLUSTER_KEY: Optional[str] = os.environ.get("ARTDAG_CLUSTER_KEY")
|
||||
|
||||
|
||||
def set_cluster_key(key: Optional[str]) -> None:
|
||||
"""Set the cluster key programmatically."""
|
||||
global CLUSTER_KEY
|
||||
CLUSTER_KEY = key
|
||||
|
||||
|
||||
def get_cluster_key() -> Optional[str]:
|
||||
"""Get the current cluster key."""
|
||||
return CLUSTER_KEY
|
||||
|
||||
|
||||
def _stable_hash(data: Any, algorithm: str = "sha3_256") -> str:
|
||||
"""
|
||||
Create stable hash from arbitrary data.
|
||||
|
||||
If ARTDAG_CLUSTER_KEY is set, it's mixed into the hash to create
|
||||
isolated trust domains. Systems with the same key can share work;
|
||||
systems with different keys have separate cache namespaces.
|
||||
"""
|
||||
# Mix in cluster key if set
|
||||
if CLUSTER_KEY:
|
||||
data = {"_cluster_key": CLUSTER_KEY, "_data": data}
|
||||
|
||||
json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
||||
hasher = hashlib.new(algorithm)
|
||||
hasher.update(json_str.encode())
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
class StepStatus(Enum):
    """Status of an execution step."""
    PENDING = "pending"      # not yet picked up
    CLAIMED = "claimed"      # presumably reserved by a worker before running — confirm
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully in this run
    CACHED = "cached"        # result satisfied from cache, no execution needed
    FAILED = "failed"        # execution errored
    SKIPPED = "skipped"      # deliberately not executed
|
||||
|
||||
|
||||
@dataclass
class StepOutput:
    """
    A single output from an execution step.

    Multi-output nodes (e.g. split_on_beats) emit several of these, each
    with a human-readable name and its own cache_id for storage.

    Attributes:
        name: Human-readable name (e.g., "beats.split.segment[0]")
        cache_id: Content-addressed hash for caching
        media_type: MIME type of the output (e.g., "video/mp4", "audio/wav")
        index: Output index for multi-output nodes
        metadata: Optional additional metadata (time_range, etc.)
    """
    name: str
    cache_id: str
    media_type: str = "application/octet-stream"
    index: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return dict(
            name=self.name,
            cache_id=self.cache_id,
            media_type=self.media_type,
            index=self.index,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepOutput":
        """Rebuild a StepOutput from its dict form, applying field defaults."""
        kwargs = {
            "name": data["name"],
            "cache_id": data["cache_id"],
            "media_type": data.get("media_type", "application/octet-stream"),
            "index": data.get("index", 0),
            "metadata": data.get("metadata", {}),
        }
        return cls(**kwargs)
|
||||
|
||||
|
||||
@dataclass
class StepInput:
    """
    Reference to an input for a step.

    Inputs can reference outputs from other steps by name.

    Attributes:
        name: Input slot name (e.g., "video", "audio", "segments")
        source: Source output name (e.g., "beats.split.segment[0]")
        cache_id: Resolved cache_id of the source (populated during planning)
    """
    name: str
    source: str
    cache_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return dict(name=self.name, source=self.source, cache_id=self.cache_id)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StepInput":
        """Rebuild a StepInput from its dict form; cache_id is optional."""
        return cls(data["name"], data["source"], data.get("cache_id"))
|
||||
|
||||
|
||||
@dataclass
class ExecutionStep:
    """
    A single step in the execution plan.

    Each step has a pre-computed cache_id that uniquely identifies
    its output based on its configuration and input cache_ids.

    Steps can produce multiple outputs (e.g., split_on_beats produces N segments).
    Each output has its own cache_id derived from the step's cache_id + index.

    Attributes:
        name: Human-readable name relating to recipe (e.g., "beats.split")
        step_id: Unique identifier (hash) for this step
        node_type: The primitive type (SOURCE, SEQUENCE, TRANSFORM, etc.)
        config: Configuration for the primitive
        input_steps: IDs of steps this depends on (legacy, use inputs for new code)
        inputs: Structured input references with names and sources
        cache_id: Pre-computed cache ID (hash of config + input cache_ids)
        outputs: List of outputs this step produces
        estimated_duration: Optional estimated execution time
        level: Dependency level (0 = no dependencies, higher = more deps)
    """
    step_id: str
    node_type: str
    config: Dict[str, Any]
    input_steps: List[str] = field(default_factory=list)
    inputs: List[StepInput] = field(default_factory=list)
    cache_id: Optional[str] = None
    outputs: List[StepOutput] = field(default_factory=list)
    name: Optional[str] = None
    estimated_duration: Optional[float] = None
    level: int = 0

    def compute_cache_id(self, input_cache_ids: Dict[str, str]) -> str:
        """
        Compute cache ID from configuration and input cache IDs.

        cache_id = SHA3-256(node_type + config + sorted(input_cache_ids))

        Args:
            input_cache_ids: Mapping from input step_id/name to their cache_id

        Returns:
            The computed cache_id
        """
        # Use structured inputs if available, otherwise fall back to input_steps
        if self.inputs:
            # Pre-resolved cache_ids win; otherwise look up by source name,
            # falling back to the raw source string when unresolved.
            resolved_inputs = [
                inp.cache_id or input_cache_ids.get(inp.source, inp.source)
                for inp in sorted(self.inputs, key=lambda x: x.name)
            ]
        else:
            resolved_inputs = [input_cache_ids.get(s, s) for s in sorted(self.input_steps)]

        # NOTE(review): inputs are sorted before hashing, so the cache_id is
        # insensitive to input order — confirm this is intended for
        # order-dependent node types such as SEQUENCE.
        content = {
            "node_type": self.node_type,
            "config": self.config,
            "inputs": resolved_inputs,
        }
        self.cache_id = _stable_hash(content)
        return self.cache_id

    def compute_output_cache_id(self, index: int) -> str:
        """
        Compute cache ID for a specific output index.

        output_cache_id = SHA3-256(step_cache_id + index)

        Args:
            index: The output index

        Returns:
            Cache ID for that output

        Raises:
            ValueError: If compute_cache_id has not been called yet.
        """
        if not self.cache_id:
            raise ValueError("Step cache_id must be computed first")
        content = {"step_cache_id": self.cache_id, "output_index": index}
        return _stable_hash(content)

    def add_output(
        self,
        name: str,
        media_type: str = "application/octet-stream",
        index: Optional[int] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> StepOutput:
        """
        Add an output to this step.

        Args:
            name: Human-readable output name
            media_type: MIME type of the output
            index: Output index (defaults to next available)
            metadata: Optional metadata

        Returns:
            The created StepOutput
        """
        if index is None:
            # Default to appending at the end of the current output list.
            index = len(self.outputs)

        cache_id = self.compute_output_cache_id(index)
        output = StepOutput(
            name=name,
            cache_id=cache_id,
            media_type=media_type,
            index=index,
            metadata=metadata or {},
        )
        self.outputs.append(output)
        return output

    def get_output(self, index: int = 0) -> Optional[StepOutput]:
        """Get output by index, or None when out of range."""
        if index < len(self.outputs):
            return self.outputs[index]
        return None

    def get_output_by_name(self, name: str) -> Optional[StepOutput]:
        """Get output by name (linear scan), or None when absent."""
        for output in self.outputs:
            if output.name == name:
                return output
        return None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the step (including inputs/outputs) to a JSON-compatible dict."""
        return {
            "step_id": self.step_id,
            "name": self.name,
            "node_type": self.node_type,
            "config": self.config,
            "input_steps": self.input_steps,
            "inputs": [inp.to_dict() for inp in self.inputs],
            "cache_id": self.cache_id,
            "outputs": [out.to_dict() for out in self.outputs],
            "estimated_duration": self.estimated_duration,
            "level": self.level,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionStep":
        """Rebuild a step from its dict form, applying defaults for optional fields."""
        inputs = [StepInput.from_dict(i) for i in data.get("inputs", [])]
        outputs = [StepOutput.from_dict(o) for o in data.get("outputs", [])]
        return cls(
            step_id=data["step_id"],
            node_type=data["node_type"],
            config=data.get("config", {}),
            input_steps=data.get("input_steps", []),
            inputs=inputs,
            cache_id=data.get("cache_id"),
            outputs=outputs,
            name=data.get("name"),
            estimated_duration=data.get("estimated_duration"),
            level=data.get("level", 0),
        )

    def to_json(self) -> str:
        """Serialize the step to a compact JSON string."""
        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "ExecutionStep":
        """Rebuild a step from its JSON form."""
        return cls.from_dict(json.loads(json_str))
|
||||
|
||||
|
||||
@dataclass
class PlanInput:
    """
    An input to the execution plan.

    Attributes:
        name: Human-readable name from recipe (e.g., "source_video")
        cache_id: Content hash of the input file
        cid: Same as cache_id (for clarity)
        media_type: MIME type of the input
    """
    name: str
    cache_id: str
    cid: str
    media_type: str = "application/octet-stream"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict."""
        return dict(
            name=self.name,
            cache_id=self.cache_id,
            cid=self.cid,
            media_type=self.media_type,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PlanInput":
        """Rebuild a PlanInput; a missing "cid" mirrors "cache_id"."""
        cache_id = data["cache_id"]
        return cls(
            name=data["name"],
            cache_id=cache_id,
            cid=data.get("cid", cache_id),
            media_type=data.get("media_type", "application/octet-stream"),
        )
|
||||
|
||||
|
||||
@dataclass
class ExecutionPlan:
    """
    Complete execution plan for a recipe.

    Contains all steps in topological order with pre-computed cache IDs.
    The plan is deterministic: same recipe + same inputs = same plan.

    Attributes:
        name: Human-readable plan name from recipe
        plan_id: Hash of the entire plan (for deduplication)
        recipe_id: Source recipe identifier
        recipe_name: Human-readable recipe name
        recipe_hash: Hash of the recipe content
        seed: Random seed used for planning
        steps: List of steps in execution order
        output_step: ID of the final output step
        output_name: Human-readable name of the final output
        inputs: Structured input definitions
        analysis_cache_ids: Cache IDs of analysis results used
        input_hashes: Content hashes of input files (legacy, use inputs)
        created_at: When the plan was generated
        metadata: Optional additional metadata
    """
    plan_id: Optional[str]
    recipe_id: str
    recipe_hash: str
    steps: List[ExecutionStep]
    output_step: str
    name: Optional[str] = None
    recipe_name: Optional[str] = None
    seed: Optional[int] = None
    output_name: Optional[str] = None
    inputs: List[PlanInput] = field(default_factory=list)
    analysis_cache_ids: Dict[str, str] = field(default_factory=dict)
    input_hashes: Dict[str, str] = field(default_factory=dict)
    created_at: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Fill defaults that depend on runtime state: a UTC timestamp and a
        # content-derived plan_id (when the caller did not supply one).
        if self.created_at is None:
            self.created_at = datetime.now(timezone.utc).isoformat()
        if self.plan_id is None:
            self.plan_id = self._compute_plan_id()

    def _compute_plan_id(self) -> str:
        """Compute plan ID from contents."""
        # created_at and metadata are deliberately excluded so identical
        # recipe+inputs yield identical plan_ids across runs.
        content = {
            "recipe_hash": self.recipe_hash,
            "steps": [s.to_dict() for s in self.steps],
            "input_hashes": self.input_hashes,
            "analysis_cache_ids": self.analysis_cache_ids,
        }
        return _stable_hash(content)

    def compute_all_cache_ids(self) -> None:
        """
        Compute cache IDs for all steps in dependency order.

        Must be called after all steps are added to ensure
        cache IDs propagate correctly through dependencies.
        """
        # Build step lookup
        step_by_id = {s.step_id: s for s in self.steps}

        # Cache IDs start with input hashes
        cache_ids = dict(self.input_hashes)

        # Process in order (assumes topological order)
        for step in self.steps:
            # For SOURCE steps referencing inputs, use input hash
            if step.node_type == "SOURCE" and step.config.get("input_ref"):
                ref = step.config["input_ref"]
                if ref in self.input_hashes:
                    step.cache_id = self.input_hashes[ref]
                    cache_ids[step.step_id] = step.cache_id
                    continue

            # For other steps, compute from inputs
            input_cache_ids = {}
            for input_step_id in step.input_steps:
                if input_step_id in cache_ids:
                    input_cache_ids[input_step_id] = cache_ids[input_step_id]
                elif input_step_id in step_by_id:
                    # Step should have been processed already
                    # (may be None if the list is not topologically ordered).
                    input_cache_ids[input_step_id] = step_by_id[input_step_id].cache_id
                else:
                    raise ValueError(f"Input step {input_step_id} not found for {step.step_id}")

            step.compute_cache_id(input_cache_ids)
            cache_ids[step.step_id] = step.cache_id

        # Recompute plan_id with final cache IDs
        self.plan_id = self._compute_plan_id()

    def compute_levels(self) -> int:
        """
        Compute dependency levels for all steps.

        Level 0 = no dependencies (can start immediately)
        Level N = depends on steps at level N-1

        Returns:
            Maximum level (number of sequential dependency levels)
        """
        step_by_id = {s.step_id: s for s in self.steps}
        levels = {}

        # Recursive memoized walk; assumes the step graph is acyclic
        # (a cycle would recurse forever).
        def compute_level(step_id: str) -> int:
            if step_id in levels:
                return levels[step_id]

            step = step_by_id.get(step_id)
            if step is None:
                return 0  # Input from outside the plan

            if not step.input_steps:
                levels[step_id] = 0
                step.level = 0
                return 0

            max_input_level = max(compute_level(s) for s in step.input_steps)
            level = max_input_level + 1
            levels[step_id] = level
            step.level = level
            return level

        for step in self.steps:
            compute_level(step.step_id)

        return max(levels.values()) if levels else 0

    def get_steps_by_level(self) -> Dict[int, List[ExecutionStep]]:
        """
        Group steps by dependency level.

        Steps at the same level can execute in parallel.

        Returns:
            Dict mapping level -> list of steps at that level
        """
        by_level: Dict[int, List[ExecutionStep]] = {}
        for step in self.steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def get_step(self, step_id: str) -> Optional[ExecutionStep]:
        """Get step by ID."""
        for step in self.steps:
            if step.step_id == step_id:
                return step
        return None

    def get_step_by_cache_id(self, cache_id: str) -> Optional[ExecutionStep]:
        """Get step by cache ID."""
        for step in self.steps:
            if step.cache_id == cache_id:
                return step
        return None

    def get_step_by_name(self, name: str) -> Optional[ExecutionStep]:
        """Get step by human-readable name."""
        for step in self.steps:
            if step.name == name:
                return step
        return None

    def get_all_outputs(self) -> Dict[str, StepOutput]:
        """
        Get all outputs from all steps, keyed by output name.

        Returns:
            Dict mapping output name -> StepOutput
        """
        # Later steps win on duplicate output names.
        outputs = {}
        for step in self.steps:
            for output in step.outputs:
                outputs[output.name] = output
        return outputs

    def get_output_cache_ids(self) -> Dict[str, str]:
        """
        Get mapping of output names to cache IDs.

        Returns:
            Dict mapping output name -> cache_id
        """
        return {
            output.name: output.cache_id
            for step in self.steps
            for output in step.outputs
        }

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the whole plan (including steps and inputs) to a dict."""
        return {
            "plan_id": self.plan_id,
            "name": self.name,
            "recipe_id": self.recipe_id,
            "recipe_name": self.recipe_name,
            "recipe_hash": self.recipe_hash,
            "seed": self.seed,
            "inputs": [i.to_dict() for i in self.inputs],
            "steps": [s.to_dict() for s in self.steps],
            "output_step": self.output_step,
            "output_name": self.output_name,
            "analysis_cache_ids": self.analysis_cache_ids,
            "input_hashes": self.input_hashes,
            "created_at": self.created_at,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExecutionPlan":
        """Rebuild a plan from its dict form, applying defaults for optional fields."""
        inputs = [PlanInput.from_dict(i) for i in data.get("inputs", [])]
        return cls(
            plan_id=data.get("plan_id"),
            name=data.get("name"),
            recipe_id=data["recipe_id"],
            recipe_name=data.get("recipe_name"),
            recipe_hash=data["recipe_hash"],
            seed=data.get("seed"),
            inputs=inputs,
            steps=[ExecutionStep.from_dict(s) for s in data.get("steps", [])],
            output_step=data["output_step"],
            output_name=data.get("output_name"),
            analysis_cache_ids=data.get("analysis_cache_ids", {}),
            input_hashes=data.get("input_hashes", {}),
            created_at=data.get("created_at"),
            metadata=data.get("metadata", {}),
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize the plan to a JSON string."""
        return json.dumps(self.to_dict(), indent=indent)

    @classmethod
    def from_json(cls, json_str: str) -> "ExecutionPlan":
        """Rebuild a plan from its JSON form."""
        return cls.from_dict(json.loads(json_str))

    def summary(self) -> str:
        """Get a human-readable summary of the plan."""
        by_level = self.get_steps_by_level()
        max_level = max(by_level.keys()) if by_level else 0

        lines = [
            f"Execution Plan: {self.plan_id[:16]}...",
            f"Recipe: {self.recipe_id}",
            f"Steps: {len(self.steps)}",
            f"Levels: {max_level + 1}",
            "",
        ]

        for level in sorted(by_level.keys()):
            steps = by_level[level]
            lines.append(f"Level {level}: ({len(steps)} steps, can run in parallel)")
            for step in steps:
                cache_status = f"[{step.cache_id[:8]}...]" if step.cache_id else "[no cache_id]"
                lines.append(f"  - {step.step_id}: {step.node_type} {cache_status}")

        return "\n".join(lines)
|
||||
231
core/artdag/planning/tree_reduction.py
Normal file
231
core/artdag/planning/tree_reduction.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# artdag/planning/tree_reduction.py
|
||||
"""
|
||||
Tree reduction for parallel composition.
|
||||
|
||||
Instead of sequential pairwise composition:
|
||||
A → AB → ABC → ABCD (3 sequential steps)
|
||||
|
||||
Use parallel tree reduction:
|
||||
A ─┬─ AB ─┬─ ABCD
|
||||
B ─┘ │
|
||||
C ─┬─ CD ─┘
|
||||
D ─┘
|
||||
|
||||
This reduces O(N) to O(log N) levels of sequential dependency.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Any, Dict
|
||||
|
||||
|
||||
@dataclass
class ReductionNode:
    """A single pairwise-composition node in the reduction tree."""
    node_id: str
    input_ids: List[str]
    level: int
    position: int  # Position within level


class TreeReducer:
    """
    Generates tree reduction plans for parallel composition.

    Converts N inputs into pairwise composition operations arranged in
    O(log N) sequential levels instead of O(N).
    """

    def __init__(self, node_type: str = "SEQUENCE"):
        """
        Initialize the reducer.

        Args:
            node_type: The composition node type (SEQUENCE, AUDIO_MIX, etc.)
        """
        self.node_type = node_type

    def reduce(
        self,
        input_ids: List[str],
        id_prefix: str = "reduce",
    ) -> Tuple[List[ReductionNode], str]:
        """
        Generate a tree reduction plan for the given inputs.

        Args:
            input_ids: List of input step IDs to reduce
            id_prefix: Prefix for generated node IDs

        Returns:
            Tuple of (list of reduction nodes, final output node ID)

        Raises:
            ValueError: If input_ids is empty.
        """
        if not input_ids:
            raise ValueError("Cannot reduce empty input list")

        # One input: nothing to compose.
        if len(input_ids) == 1:
            return [], input_ids[0]

        # Two inputs: a single composition, named as the final node directly.
        if len(input_ids) == 2:
            final = ReductionNode(
                node_id=f"{id_prefix}_final",
                input_ids=input_ids,
                level=0,
                position=0,
            )
            return [final], final.node_id

        # General case: repeatedly pair up the frontier until one node remains.
        tree: List[ReductionNode] = []
        frontier = list(input_ids)
        depth = 0

        while len(frontier) > 1:
            merged: List[str] = []
            for slot, start in enumerate(range(0, len(frontier) - 1, 2)):
                pair_node = ReductionNode(
                    node_id=f"{id_prefix}_L{depth}_P{slot}",
                    input_ids=[frontier[start], frontier[start + 1]],
                    level=depth,
                    position=slot,
                )
                tree.append(pair_node)
                merged.append(pair_node.node_id)
            if len(frontier) % 2:
                # Odd element out: promote it unchanged to the next level.
                merged.append(frontier[-1])
            frontier = merged
            depth += 1

        root_id = frontier[0]

        # Give the root node a stable, recognizable name.
        if tree and tree[-1].node_id == root_id:
            root_id = f"{id_prefix}_final"
            tree[-1].node_id = root_id

        return tree, root_id

    def get_reduction_depth(self, n: int) -> int:
        """
        Calculate the number of reduction levels needed.

        Args:
            n: Number of inputs

        Returns:
            Number of sequential reduction levels (log2(n) ceiling)
        """
        return 0 if n <= 1 else math.ceil(math.log2(n))

    def get_total_operations(self, n: int) -> int:
        """
        Calculate total number of reduction operations.

        Args:
            n: Number of inputs

        Returns:
            Total composition operations (always n-1)
        """
        return n - 1 if n > 1 else 0

    def reduce_with_config(
        self,
        input_ids: List[str],
        base_config: Dict[str, Any],
        id_prefix: str = "reduce",
    ) -> Tuple[List[Tuple[ReductionNode, Dict[str, Any]]], str]:
        """
        Generate reduction plan with configuration for each node.

        Args:
            input_ids: List of input step IDs
            base_config: Base configuration to use for each reduction
            id_prefix: Prefix for generated node IDs

        Returns:
            Tuple of (list of (node, config) pairs, final output ID)
        """
        tree, root_id = self.reduce(input_ids, id_prefix)
        # Each node gets its own shallow copy of the base config.
        return [(node, dict(base_config)) for node in tree], root_id
|
||||
|
||||
|
||||
def reduce_sequence(
    input_ids: List[str],
    transition_config: Dict[str, Any] = None,
    id_prefix: str = "seq",
) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]:
    """
    Convenience function for SEQUENCE reduction.

    Args:
        input_ids: Input step IDs to sequence
        transition_config: Transition configuration (default: cut)
        id_prefix: Prefix for generated step IDs

    Returns:
        Tuple of (list of (step_id, inputs, config), final step ID)
    """
    # Default to a hard cut between segments.
    config = (
        transition_config
        if transition_config is not None
        else {"transition": {"type": "cut"}}
    )

    nodes, final_id = TreeReducer("SEQUENCE").reduce(input_ids, id_prefix)

    # Each step gets its own shallow copy of the transition config.
    plan = [(node.node_id, node.input_ids, dict(config)) for node in nodes]
    return plan, final_id
|
||||
|
||||
|
||||
def reduce_audio_mix(
    input_ids: List[str],
    mix_config: Dict[str, Any] = None,
    id_prefix: str = "mix",
) -> Tuple[List[Tuple[str, List[str], Dict[str, Any]]], str]:
    """
    Convenience function for AUDIO_MIX reduction.

    Args:
        input_ids: Input step IDs to mix
        mix_config: Mix configuration
        id_prefix: Prefix for generated step IDs

    Returns:
        Tuple of (list of (step_id, inputs, config), final step ID)
    """
    # Default mix behavior: normalize levels.
    config = mix_config if mix_config is not None else {"normalize": True}

    nodes, final_id = TreeReducer("AUDIO_MIX").reduce(input_ids, id_prefix)

    # Each step gets its own shallow copy of the mix config.
    plan = [(node.node_id, node.input_ids, dict(config)) for node in nodes]
    return plan, final_id
|
||||
20
core/artdag/registry/__init__.py
Normal file
20
core/artdag/registry/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# primitive/registry/__init__.py
|
||||
"""
|
||||
Art DAG Registry.
|
||||
|
||||
The registry is the foundational data structure that maps named assets
|
||||
to their source paths or content-addressed IDs. Assets in the registry
|
||||
can be referenced by DAGs.
|
||||
|
||||
Example:
|
||||
registry = Registry("/path/to/registry")
|
||||
registry.add("cat", "/path/to/cat.jpg", tags=["animal", "photo"])
|
||||
|
||||
# Later, in a DAG:
|
||||
builder = DAGBuilder()
|
||||
cat = builder.source(registry.get("cat").path)
|
||||
"""
|
||||
|
||||
from .registry import Registry, Asset
|
||||
|
||||
__all__ = ["Registry", "Asset"]
|
||||
294
core/artdag/registry/registry.py
Normal file
294
core/artdag/registry/registry.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# primitive/registry/registry.py
|
||||
"""
|
||||
Asset registry for the Art DAG.
|
||||
|
||||
The registry stores named assets with metadata, enabling:
|
||||
- Named references to source files
|
||||
- Tagging and categorization
|
||||
- Content-addressed deduplication
|
||||
- Asset discovery and search
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
def _file_hash(path: Path, algorithm: str = "sha3_256") -> str:
|
||||
"""
|
||||
Compute content hash of a file.
|
||||
|
||||
Uses SHA-3 (Keccak) by default for quantum resistance.
|
||||
SHA-3-256 provides 128-bit security against quantum attacks (Grover's algorithm).
|
||||
|
||||
Args:
|
||||
path: File to hash
|
||||
algorithm: Hash algorithm (sha3_256, sha3_512, sha256, blake2b)
|
||||
|
||||
Returns:
|
||||
Full hex digest (no truncation)
|
||||
"""
|
||||
hasher = hashlib.new(algorithm)
|
||||
with open(path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(65536), b""):
|
||||
hasher.update(chunk)
|
||||
return hasher.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class Asset:
    """
    A registered asset in the Art DAG.

    The cid is the true identifier. URL and local_path are
    locations where the content can be fetched.

    Attributes:
        name: Unique name for the asset
        cid: SHA-3-256 hash - the canonical identifier
        url: Public URL (canonical location)
        local_path: Optional local path (for local execution)
        asset_type: Type of asset (image, video, audio, etc.)
        tags: List of tags for categorization
        metadata: Additional metadata (dimensions, duration, etc.)
        created_at: Timestamp when added to registry
    """
    name: str
    cid: str
    url: Optional[str] = None
    local_path: Optional[Path] = None
    asset_type: str = "unknown"
    tags: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)

    @property
    def path(self) -> Optional[Path]:
        """Backwards compatible path property."""
        return self.local_path

    def to_dict(self) -> Dict[str, Any]:
        """Serialize; url/local_path keys appear only when set."""
        serialized: Dict[str, Any] = {
            "name": self.name,
            "cid": self.cid,
            "asset_type": self.asset_type,
            "tags": self.tags,
            "metadata": self.metadata,
            "created_at": self.created_at,
        }
        if self.url:
            serialized["url"] = self.url
        if self.local_path:
            serialized["local_path"] = str(self.local_path)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Asset":
        """Rebuild an Asset; accepts the legacy "path" key for local_path."""
        raw_path = data.get("local_path") or data.get("path")  # backwards compat
        return cls(
            name=data["name"],
            cid=data["cid"],
            url=data.get("url"),
            local_path=Path(raw_path) if raw_path else None,
            asset_type=data.get("asset_type", "unknown"),
            tags=data.get("tags", []),
            metadata=data.get("metadata", {}),
            created_at=data.get("created_at", time.time()),
        )
|
||||
|
||||
|
||||
class Registry:
|
||||
"""
|
||||
The Art DAG registry.
|
||||
|
||||
Stores named assets that can be referenced by DAGs.
|
||||
|
||||
Structure:
|
||||
registry_dir/
|
||||
registry.json # Index of all assets
|
||||
assets/ # Optional: copied asset files
|
||||
<hash>/
|
||||
<filename>
|
||||
"""
|
||||
|
||||
def __init__(self, registry_dir: Path | str, copy_assets: bool = False):
    """
    Initialize the registry.

    Args:
        registry_dir: Directory to store registry data
        copy_assets: If True, copy assets into registry (content-addressed)
    """
    self.registry_dir = Path(registry_dir)
    # Ensure the storage directory exists before any index I/O.
    self.registry_dir.mkdir(parents=True, exist_ok=True)
    self.copy_assets = copy_assets
    # In-memory index of assets by name; populated from disk below.
    self._assets: Dict[str, Asset] = {}
    # Best-effort: _load is a no-op when no registry.json exists yet.
    self._load()
|
||||
|
||||
def _index_path(self) -> Path:
    """Location of the JSON index file for this registry."""
    return Path(self.registry_dir, "registry.json")
|
||||
|
||||
def _assets_dir(self) -> Path:
    """Directory where copied (content-addressed) asset files live."""
    return Path(self.registry_dir, "assets")
|
||||
|
||||
def _load(self):
    """Populate the in-memory index from registry.json, if present."""
    index = self._index_path()
    if not index.exists():
        # Fresh registry: keep whatever _assets already holds.
        return
    with open(index) as fh:
        raw = json.load(fh)
    self._assets = {
        name: Asset.from_dict(entry)
        for name, entry in raw.get("assets", {}).items()
    }
|
||||
|
||||
def _save(self):
|
||||
"""Save registry to disk."""
|
||||
data = {
|
||||
"version": "1.0",
|
||||
"assets": {name: asset.to_dict() for name, asset in self._assets.items()},
|
||||
}
|
||||
with open(self._index_path(), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def add(
|
||||
self,
|
||||
name: str,
|
||||
cid: str,
|
||||
url: str = None,
|
||||
local_path: Path | str = None,
|
||||
asset_type: str = None,
|
||||
tags: List[str] = None,
|
||||
metadata: Dict[str, Any] = None,
|
||||
) -> Asset:
|
||||
"""
|
||||
Add an asset to the registry.
|
||||
|
||||
Args:
|
||||
name: Unique name for the asset
|
||||
cid: SHA-3-256 hash of the content (the canonical identifier)
|
||||
url: Public URL where the asset can be fetched
|
||||
local_path: Optional local path (for local execution)
|
||||
asset_type: Type of asset (image, video, audio, etc.)
|
||||
tags: List of tags for categorization
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
The created Asset
|
||||
"""
|
||||
# Auto-detect asset type from URL or path extension
|
||||
if asset_type is None:
|
||||
ext = None
|
||||
if url:
|
||||
ext = Path(url.split("?")[0]).suffix.lower()
|
||||
elif local_path:
|
||||
ext = Path(local_path).suffix.lower()
|
||||
if ext:
|
||||
type_map = {
|
||||
".jpg": "image", ".jpeg": "image", ".png": "image",
|
||||
".gif": "image", ".webp": "image", ".bmp": "image",
|
||||
".mp4": "video", ".mkv": "video", ".avi": "video",
|
||||
".mov": "video", ".webm": "video",
|
||||
".mp3": "audio", ".wav": "audio", ".flac": "audio",
|
||||
".ogg": "audio", ".aac": "audio",
|
||||
}
|
||||
asset_type = type_map.get(ext, "unknown")
|
||||
else:
|
||||
asset_type = "unknown"
|
||||
|
||||
asset = Asset(
|
||||
name=name,
|
||||
cid=cid,
|
||||
url=url,
|
||||
local_path=Path(local_path).resolve() if local_path else None,
|
||||
asset_type=asset_type,
|
||||
tags=tags or [],
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._assets[name] = asset
|
||||
self._save()
|
||||
return asset
|
||||
|
||||
def add_from_file(
|
||||
self,
|
||||
name: str,
|
||||
path: Path | str,
|
||||
url: str = None,
|
||||
asset_type: str = None,
|
||||
tags: List[str] = None,
|
||||
metadata: Dict[str, Any] = None,
|
||||
) -> Asset:
|
||||
"""
|
||||
Add an asset from a local file (computes hash automatically).
|
||||
|
||||
Args:
|
||||
name: Unique name for the asset
|
||||
path: Path to the source file
|
||||
url: Optional public URL
|
||||
asset_type: Type of asset (auto-detected if not provided)
|
||||
tags: List of tags for categorization
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
The created Asset
|
||||
"""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Asset file not found: {path}")
|
||||
|
||||
cid = _file_hash(path)
|
||||
|
||||
return self.add(
|
||||
name=name,
|
||||
cid=cid,
|
||||
url=url,
|
||||
local_path=path,
|
||||
asset_type=asset_type,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def get(self, name: str) -> Optional[Asset]:
|
||||
"""Get an asset by name."""
|
||||
return self._assets.get(name)
|
||||
|
||||
def remove(self, name: str) -> bool:
|
||||
"""Remove an asset from the registry."""
|
||||
if name not in self._assets:
|
||||
return False
|
||||
del self._assets[name]
|
||||
self._save()
|
||||
return True
|
||||
|
||||
def list(self) -> List[Asset]:
|
||||
"""List all assets."""
|
||||
return list(self._assets.values())
|
||||
|
||||
def find_by_tag(self, tag: str) -> List[Asset]:
|
||||
"""Find assets with a specific tag."""
|
||||
return [a for a in self._assets.values() if tag in a.tags]
|
||||
|
||||
def find_by_type(self, asset_type: str) -> List[Asset]:
|
||||
"""Find assets of a specific type."""
|
||||
return [a for a in self._assets.values() if a.asset_type == asset_type]
|
||||
|
||||
def find_by_hash(self, cid: str) -> Optional[Asset]:
|
||||
"""Find an asset by content hash."""
|
||||
for asset in self._assets.values():
|
||||
if asset.cid == cid:
|
||||
return asset
|
||||
return None
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._assets
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._assets)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._assets.values())
|
||||
253
core/artdag/server.py
Normal file
253
core/artdag/server.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# primitive/server.py
|
||||
"""
|
||||
HTTP server for primitive execution engine.
|
||||
|
||||
Provides a REST API for submitting DAGs and retrieving results.
|
||||
|
||||
Endpoints:
|
||||
POST /execute - Submit DAG for execution
|
||||
GET /status/:id - Get execution status
|
||||
GET /result/:id - Get execution result
|
||||
GET /cache/stats - Get cache statistics
|
||||
DELETE /cache - Clear cache
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .dag import DAG
|
||||
from .engine import Engine, ExecutionResult
|
||||
from . import nodes # Register built-in executors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Job:
    """A pending or completed execution job tracked by PrimitiveServer."""
    # Short identifier handed back to the client for polling.
    job_id: str
    # The DAG submitted for execution.
    dag: DAG
    status: str = "pending"  # pending, running, completed, failed
    # Engine output once execution finishes (None while pending/running).
    result: Optional[ExecutionResult] = None
    # Human-readable failure reason when status == "failed".
    error: Optional[str] = None
||||
|
||||
|
||||
class PrimitiveServer:
    """
    HTTP server for the primitive engine.

    Wraps an Engine behind a small REST API. Each submitted DAG becomes a
    Job executed on its own daemon thread while clients poll for status
    and results.

    Usage:
        server = PrimitiveServer(cache_dir="/tmp/primitive_cache", port=8080)
        server.start()  # Blocking
    """

    def __init__(self, cache_dir: Path | str, host: str = "127.0.0.1", port: int = 8080):
        """
        Args:
            cache_dir: Directory the Engine uses for its cache.
            host: Interface to bind to (default: loopback only).
            port: TCP port to listen on.
        """
        self.cache_dir = Path(cache_dir)
        self.host = host
        self.port = port
        self.engine = Engine(self.cache_dir)
        # Job table, guarded by _lock: handler threads and worker threads
        # mutate it concurrently.
        # NOTE(review): completed jobs are never evicted, so this grows
        # without bound in a long-running server — confirm intended.
        self.jobs: Dict[str, Job] = {}
        self._lock = threading.Lock()

    def submit_job(self, dag: DAG) -> str:
        """Submit a DAG for execution; returns a short job ID immediately."""
        # Truncated UUID keeps IDs readable; collision odds are low but nonzero.
        job_id = str(uuid.uuid4())[:8]
        job = Job(job_id=job_id, dag=dag)

        with self._lock:
            self.jobs[job_id] = job

        # Execute in background thread; daemon so it never blocks exit.
        thread = threading.Thread(target=self._execute_job, args=(job_id,))
        thread.daemon = True
        thread.start()

        return job_id

    def _execute_job(self, job_id: str):
        """Worker: run the job's DAG and record the outcome on the Job."""
        with self._lock:
            job = self.jobs.get(job_id)
            if not job:
                return
            job.status = "running"

        try:
            # Execution happens outside the lock so other requests proceed.
            result = self.engine.execute(job.dag)
            with self._lock:
                job.result = result
                job.status = "completed" if result.success else "failed"
                if not result.success:
                    job.error = result.error
        except Exception as e:
            logger.exception(f"Job {job_id} failed")
            with self._lock:
                job.status = "failed"
                job.error = str(e)

    def get_job(self, job_id: str) -> Optional[Job]:
        """Get job by ID (thread-safe); None if unknown."""
        with self._lock:
            return self.jobs.get(job_id)

    def _create_handler(server_instance):
        """Create a BaseHTTPRequestHandler subclass bound to this server.

        Note: this is an instance method — `server_instance` plays the
        role of `self`; the nested handler class captures it as a class
        attribute so every request can reach the server's state.
        """

        class RequestHandler(BaseHTTPRequestHandler):
            server_ref = server_instance

            def log_message(self, format, *args):
                # Route http.server's default stderr logging to our logger.
                logger.debug(format % args)

            def _send_json(self, data: Any, status: int = 200):
                """Serialize `data` as a JSON response with the given status."""
                self.send_response(status)
                self.send_header("Content-Type", "application/json")
                self.end_headers()
                self.wfile.write(json.dumps(data).encode())

            def _send_error(self, message: str, status: int = 400):
                """Send a JSON error body: {"error": message}."""
                self._send_json({"error": message}, status)

            def do_GET(self):
                parsed = urlparse(self.path)
                path = parsed.path

                if path.startswith("/status/"):
                    job_id = path[8:]  # strip the "/status/" prefix
                    job = self.server_ref.get_job(job_id)
                    if not job:
                        self._send_error("Job not found", 404)
                        return
                    self._send_json({
                        "job_id": job.job_id,
                        "status": job.status,
                        "error": job.error,
                    })

                elif path.startswith("/result/"):
                    job_id = path[8:]  # strip the "/result/" prefix
                    job = self.server_ref.get_job(job_id)
                    if not job:
                        self._send_error("Job not found", 404)
                        return
                    # Not finished yet: report "ready": False rather than 404.
                    if job.status == "pending" or job.status == "running":
                        self._send_json({
                            "job_id": job.job_id,
                            "status": job.status,
                            "ready": False,
                        })
                        return

                    result = job.result
                    self._send_json({
                        "job_id": job.job_id,
                        "status": job.status,
                        "ready": True,
                        "success": result.success if result else False,
                        "output_path": str(result.output_path) if result and result.output_path else None,
                        "error": job.error,
                        "execution_time": result.execution_time if result else 0,
                        "nodes_executed": result.nodes_executed if result else 0,
                        "nodes_cached": result.nodes_cached if result else 0,
                    })

                elif path == "/cache/stats":
                    stats = self.server_ref.engine.get_cache_stats()
                    self._send_json({
                        "total_entries": stats.total_entries,
                        "total_size_bytes": stats.total_size_bytes,
                        "hits": stats.hits,
                        "misses": stats.misses,
                        "hit_rate": stats.hit_rate,
                    })

                elif path == "/health":
                    # Liveness probe.
                    self._send_json({"status": "ok"})

                else:
                    self._send_error("Not found", 404)

            def do_POST(self):
                if self.path == "/execute":
                    try:
                        content_length = int(self.headers.get("Content-Length", 0))
                        body = self.rfile.read(content_length).decode()
                        data = json.loads(body)

                        dag = DAG.from_dict(data)
                        job_id = self.server_ref.submit_job(dag)

                        # Execution is asynchronous: respond before the job runs.
                        self._send_json({
                            "job_id": job_id,
                            "status": "pending",
                        })
                    except json.JSONDecodeError as e:
                        self._send_error(f"Invalid JSON: {e}")
                    except Exception as e:
                        self._send_error(str(e), 500)
                else:
                    self._send_error("Not found", 404)

            def do_DELETE(self):
                if self.path == "/cache":
                    self.server_ref.engine.clear_cache()
                    self._send_json({"status": "cleared"})
                else:
                    self._send_error("Not found", 404)

        return RequestHandler

    def start(self):
        """Start the HTTP server (blocking until KeyboardInterrupt)."""
        handler = self._create_handler()
        server = HTTPServer((self.host, self.port), handler)
        logger.info(f"Primitive server starting on {self.host}:{self.port}")
        print(f"Primitive server running on http://{self.host}:{self.port}")
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("\nShutting down...")
            server.shutdown()

    def start_background(self) -> threading.Thread:
        """Start the server in a background daemon thread and return it."""
        thread = threading.Thread(target=self.start)
        thread.daemon = True
        thread.start()
        return thread
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse flags, configure logging, run the server."""
    import argparse

    arg_parser = argparse.ArgumentParser(description="Primitive execution server")
    arg_parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
    arg_parser.add_argument("--port", type=int, default=8080, help="Port to bind to")
    arg_parser.add_argument("--cache-dir", default="/tmp/primitive_cache", help="Cache directory")
    arg_parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging")

    opts = arg_parser.parse_args()

    log_level = logging.DEBUG if opts.verbose else logging.INFO
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    PrimitiveServer(
        cache_dir=opts.cache_dir,
        host=opts.host,
        port=opts.port,
    ).start()
|
||||
|
||||
|
||||
# Script entry point: running this module directly starts the server CLI.
if __name__ == "__main__":
    main()
|
||||
75
core/artdag/sexp/__init__.py
Normal file
75
core/artdag/sexp/__init__.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
S-expression parsing, compilation, and planning for ArtDAG.
|
||||
|
||||
This module provides:
|
||||
- parser: Parse S-expression text into Python data structures
|
||||
- compiler: Compile recipe S-expressions into DAG format
|
||||
- planner: Generate execution plans from recipes
|
||||
"""
|
||||
|
||||
from .parser import (
|
||||
parse,
|
||||
parse_all,
|
||||
serialize,
|
||||
Symbol,
|
||||
Keyword,
|
||||
ParseError,
|
||||
)
|
||||
|
||||
from .compiler import (
|
||||
compile_recipe,
|
||||
compile_string,
|
||||
CompiledRecipe,
|
||||
CompileError,
|
||||
ParamDef,
|
||||
_parse_params,
|
||||
)
|
||||
|
||||
from .planner import (
|
||||
create_plan,
|
||||
ExecutionPlanSexp,
|
||||
PlanStep,
|
||||
step_to_task_sexp,
|
||||
task_cache_id,
|
||||
)
|
||||
|
||||
from .scheduler import (
|
||||
PlanScheduler,
|
||||
PlanResult,
|
||||
StepResult,
|
||||
schedule_plan,
|
||||
step_to_sexp,
|
||||
step_sexp_to_string,
|
||||
verify_step_cache_id,
|
||||
)
|
||||
|
||||
# Explicit public API for `from artdag.sexp import *`; mirrors the
# re-exports above. Note `_parse_params` is deliberately exported despite
# the leading underscore.
__all__ = [
    # Parser
    'parse',
    'parse_all',
    'serialize',
    'Symbol',
    'Keyword',
    'ParseError',
    # Compiler
    'compile_recipe',
    'compile_string',
    'CompiledRecipe',
    'CompileError',
    'ParamDef',
    '_parse_params',
    # Planner
    'create_plan',
    'ExecutionPlanSexp',
    'PlanStep',
    'step_to_task_sexp',
    'task_cache_id',
    # Scheduler
    'PlanScheduler',
    'PlanResult',
    'StepResult',
    'schedule_plan',
    'step_to_sexp',
    'step_sexp_to_string',
    'verify_step_cache_id',
]
|
||||
2463
core/artdag/sexp/compiler.py
Normal file
2463
core/artdag/sexp/compiler.py
Normal file
File diff suppressed because it is too large
Load Diff
337
core/artdag/sexp/effect_loader.py
Normal file
337
core/artdag/sexp/effect_loader.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Sexp effect loader.
|
||||
|
||||
Loads sexp effect definitions (define-effect forms) and creates
|
||||
frame processors that evaluate the sexp body with primitives.
|
||||
|
||||
Effects must use :params syntax:
|
||||
|
||||
(define-effect name
|
||||
:params (
|
||||
(param1 :type int :default 8 :range [4 32] :desc "description")
|
||||
(param2 :type string :default "value" :desc "description")
|
||||
)
|
||||
body)
|
||||
|
||||
For effects with no parameters, use empty :params ():
|
||||
|
||||
(define-effect name
|
||||
:params ()
|
||||
body)
|
||||
|
||||
Unknown parameters passed to effects will raise an error.
|
||||
|
||||
Usage:
|
||||
loader = SexpEffectLoader()
|
||||
effect_fn = loader.load_effect_file(Path("effects/ascii_art.sexp"))
|
||||
output = effect_fn(input_path, output_path, config)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .parser import parse_all, Symbol, Keyword
|
||||
from .evaluator import evaluate
|
||||
from .primitives import PRIMITIVES
|
||||
from .compiler import ParamDef, _parse_params, CompileError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_define_effect(sexp) -> tuple:
    """
    Parse a define-effect form.

    Required syntax:
        (define-effect name
          :params (
            (param1 :type int :default 8 :range [4 32] :desc "description")
          )
          body)

    Effects MUST use :params syntax. Legacy ((param default) ...) syntax is
    not supported.

    Returns:
        (name, params_with_defaults, param_defs, body) where param_defs is
        a list of ParamDef objects.

    Raises:
        ValueError: on malformed forms, legacy parameter syntax, or a
            missing :params block. Errors from _parse_params are chained
            (``from e``) so the root cause is preserved.
    """
    if not isinstance(sexp, list) or len(sexp) < 3:
        raise ValueError(f"Invalid define-effect form: {sexp}")

    head = sexp[0]
    if not (isinstance(head, Symbol) and head.name == "define-effect"):
        raise ValueError(f"Expected define-effect, got {head}")

    # The effect name may arrive as a Symbol; normalize to str.
    name = sexp[1]
    if isinstance(name, Symbol):
        name = name.name

    params_with_defaults = {}
    param_defs: List[ParamDef] = []
    body = None
    found_params = False

    # Walk the remaining items: a keyword consumes itself plus one value;
    # the first non-keyword item is taken as the effect body.
    i = 2
    while i < len(sexp):
        item = sexp[i]
        if isinstance(item, Keyword) and item.name == "params":
            # :params syntax
            if i + 1 >= len(sexp):
                raise ValueError(f"Effect '{name}': Missing params list after :params keyword")
            try:
                param_defs = _parse_params(sexp[i + 1])
                # Build params_with_defaults from ParamDef objects
                for pd in param_defs:
                    params_with_defaults[pd.name] = pd.default
            except CompileError as e:
                # Chain the original CompileError so the root cause and its
                # traceback are not lost (PEP 3134).
                raise ValueError(f"Effect '{name}': Error parsing :params: {e}") from e
            found_params = True
            i += 2
        elif isinstance(item, Keyword):
            # Skip other keywords we don't recognize (and their value).
            i += 2
        elif body is None:
            # First non-keyword item is the body.
            if isinstance(item, list) and item:
                first_elem = item[0]
                # Check for legacy syntax and reject it.
                # NOTE(review): this heuristic rejects any body whose first
                # element is a list of length >= 2 — confirm no legitimate
                # body shape is caught by it.
                if isinstance(first_elem, list) and len(first_elem) >= 2:
                    raise ValueError(
                        f"Effect '{name}': Legacy parameter syntax ((name default) ...) is not supported. "
                        f"Use :params block instead:\n"
                        f"  :params (\n"
                        f"    (param_name :type int :default 0 :desc \"description\")\n"
                        f"  )"
                    )
            body = item
            i += 1
        else:
            i += 1

    if body is None:
        raise ValueError(f"Effect '{name}': No body found")

    if not found_params:
        raise ValueError(
            f"Effect '{name}': Missing :params block. Effects must declare parameters.\n"
            f"For effects with no parameters, use empty :params ():\n"
            f"  (define-effect {name}\n"
            f"    :params ()\n"
            f"    body)"
        )

    return name, params_with_defaults, param_defs, body
|
||||
|
||||
|
||||
def _create_process_frame(
    effect_name: str,
    params_with_defaults: Dict[str, Any],
    param_defs: List[ParamDef],
    body: Any,
) -> Callable:
    """
    Create a process_frame function that evaluates the sexp body.

    The returned function's signature is:
        (frame, params, state) -> (frame, state)

    The closure captures the effect's body and declared parameters; a fresh
    evaluation environment is built for every frame.
    """
    import math

    def process_frame(frame: np.ndarray, params: Dict[str, Any], state: Any):
        """Evaluate sexp effect body on a frame.

        Raises:
            ValueError: if `params` contains a name not declared in the
                effect's :params block.
        """
        # Build environment with primitives (fresh per frame).
        env = dict(PRIMITIVES)

        # Add math functions (floor/ceil/round coerced to int for the DSL).
        env["floor"] = lambda x: int(math.floor(x))
        env["ceil"] = lambda x: int(math.ceil(x))
        env["round"] = lambda x: int(round(x))
        env["abs"] = abs
        env["min"] = min
        env["max"] = max
        env["sqrt"] = math.sqrt
        env["sin"] = math.sin
        env["cos"] = math.cos

        # Add list operations.
        # NOTE(review): this nth returns None for an empty/falsy coll but
        # raises IndexError for an out-of-range index — confirm intended.
        env["list"] = lambda *args: tuple(args)
        env["nth"] = lambda coll, i: coll[int(i)] if coll else None

        # Bind frame (the current video frame as an ndarray).
        env["frame"] = frame

        # Validate that all provided params are known.
        known_params = set(params_with_defaults.keys())
        for k in params.keys():
            if k not in known_params:
                raise ValueError(
                    f"Effect '{effect_name}': Unknown parameter '{k}'. "
                    f"Valid parameters are: {', '.join(sorted(known_params)) if known_params else '(none)'}"
                )

        # Bind parameters (defaults + overrides from config).
        # A declared parameter whose default is None and which the caller
        # did not supply is left unbound in the environment.
        for param_name, default in params_with_defaults.items():
            # Use config value if provided, otherwise default
            if param_name in params:
                env[param_name] = params[param_name]
            elif default is not None:
                env[param_name] = default

        # Evaluate the body; a non-ndarray result is logged and the input
        # frame is passed through unchanged.
        try:
            result = evaluate(body, env)
            if isinstance(result, np.ndarray):
                return result, state
            else:
                logger.warning(f"Effect {effect_name} returned {type(result)}, expected ndarray")
                return frame, state
        except Exception as e:
            logger.error(f"Error evaluating effect {effect_name}: {e}")
            raise

    return process_frame
|
||||
|
||||
|
||||
def load_sexp_effect(source: str, base_path: Optional[Path] = None) -> tuple:
    """
    Load a sexp effect from source code.

    Args:
        source: Sexp source code
        base_path: Base path for resolving relative imports.
            NOTE(review): currently unused here — confirm whether it is
            still needed by callers.

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection

    Raises:
        ValueError: if no define-effect form is present.
    """
    exprs = parse_all(source)

    # Scan the top-level expressions for the first define-effect form.
    # (A previous `elif isinstance(exprs, list) ...` fallback branch was
    # unreachable — its guard repeated the `if` condition — and has been
    # removed.)
    define_effect = None
    if isinstance(exprs, list):
        for expr in exprs:
            if isinstance(expr, list) and expr and isinstance(expr[0], Symbol):
                if expr[0].name == "define-effect":
                    define_effect = expr
                    break

    if not define_effect:
        raise ValueError("No define-effect form found in sexp effect")

    name, params_with_defaults, param_defs, body = _parse_define_effect(define_effect)
    process_frame = _create_process_frame(name, params_with_defaults, param_defs, body)

    return name, process_frame, params_with_defaults, param_defs
|
||||
|
||||
|
||||
def load_sexp_effect_file(path: Path) -> tuple:
    """
    Load a sexp effect definition from a file on disk.

    Returns:
        (effect_name, process_frame_fn, params_with_defaults, param_defs)
        where param_defs is a list of ParamDef objects for introspection
    """
    # Delegate to the string loader, anchoring imports at the file's folder.
    return load_sexp_effect(path.read_text(), base_path=path.parent)
|
||||
|
||||
|
||||
class SexpEffectLoader:
    """
    Loader for sexp effect definitions.

    Creates effect functions compatible with the EffectExecutor and caches
    parsed parameter definitions for introspection.
    """

    def __init__(self, recipe_dir: Optional[Path] = None):
        """
        Initialize loader.

        Args:
            recipe_dir: Base directory for resolving relative effect paths
                (defaults to the current working directory)
        """
        self.recipe_dir = recipe_dir or Path.cwd()
        # Cache loaded effects with their param_defs for introspection:
        # effect_path -> (name, params_with_defaults, param_defs).
        self._loaded_effects: Dict[str, tuple] = {}

    def load_effect_path(self, effect_path: str) -> Callable:
        """
        Load a sexp effect from a relative path.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            Effect function (input_path, output_path, config) -> output_path

        Raises:
            FileNotFoundError: if the effect file does not exist.
        """
        # Imported lazily to avoid a circular/heavy import at module load.
        from ..effects.frame_processor import process_video

        full_path = self.recipe_dir / effect_path
        if not full_path.exists():
            raise FileNotFoundError(f"Sexp effect not found: {full_path}")

        name, process_frame_fn, params_defaults, param_defs = load_sexp_effect_file(full_path)
        logger.info(f"Loaded sexp effect: {name} from {effect_path}")

        # Cache for introspection
        self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        def effect_fn(input_path: Path, output_path: Path, config: Dict[str, Any]) -> Path:
            """Run sexp effect via frame processor."""
            # Extract params (excluding internal keys)
            params = dict(params_defaults)  # Start with defaults
            for k, v in config.items():
                if k not in ("effect", "cid", "hash", "effect_path", "_binding"):
                    params[k] = v

            # Get bindings if present (config values carrying
            # "_resolved_values" payloads).
            bindings = {}
            for key, value in config.items():
                if isinstance(value, dict) and value.get("_resolved_values"):
                    bindings[key] = value["_resolved_values"]

            output_path.parent.mkdir(parents=True, exist_ok=True)
            # Output is always written with an .mp4 suffix, regardless of
            # the suffix the caller asked for.
            actual_output = output_path.with_suffix(".mp4")

            process_video(
                input_path=input_path,
                output_path=actual_output,
                process_frame=process_frame_fn,
                params=params,
                bindings=bindings,
            )

            logger.info(f"Processed sexp effect '{name}' from {effect_path}")
            return actual_output

        return effect_fn

    def get_effect_params(self, effect_path: str) -> List[ParamDef]:
        """
        Get parameter definitions for an effect.

        Loads (and caches) the effect on first use.

        Args:
            effect_path: Relative path to effect .sexp file

        Returns:
            List of ParamDef objects describing the effect's parameters

        Raises:
            FileNotFoundError: if the effect file does not exist.
        """
        if effect_path not in self._loaded_effects:
            # Load the effect to get its params
            full_path = self.recipe_dir / effect_path
            if not full_path.exists():
                raise FileNotFoundError(f"Sexp effect not found: {full_path}")
            name, _, params_defaults, param_defs = load_sexp_effect_file(full_path)
            self._loaded_effects[effect_path] = (name, params_defaults, param_defs)

        # Tuple layout: (name, params_with_defaults, param_defs).
        return self._loaded_effects[effect_path][2]
|
||||
|
||||
|
||||
def get_sexp_effect_loader(recipe_dir: Optional[Path] = None) -> SexpEffectLoader:
    """Construct a sexp effect loader rooted at *recipe_dir* (cwd when None)."""
    loader = SexpEffectLoader(recipe_dir)
    return loader
|
||||
869
core/artdag/sexp/evaluator.py
Normal file
869
core/artdag/sexp/evaluator.py
Normal file
@@ -0,0 +1,869 @@
|
||||
"""
|
||||
Expression evaluator for S-expression DSL.
|
||||
|
||||
Supports:
|
||||
- Arithmetic: +, -, *, /, mod, sqrt, pow, abs, floor, ceil, round, min, max, clamp
|
||||
- Comparison: =, <, >, <=, >=
|
||||
- Logic: and, or, not
|
||||
- Predicates: odd?, even?, zero?, nil?
|
||||
- Conditionals: if, cond, case
|
||||
- Data: list, dict/map construction, get
|
||||
- Lambda calls
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Callable
|
||||
from .parser import Symbol, Keyword, Lambda, Binding
|
||||
|
||||
|
||||
class EvalError(Exception):
    """Raised when evaluation of an S-expression fails."""
|
||||
|
||||
|
||||
# Registry of built-in functions, keyed by their DSL-visible name.
BUILTINS: Dict[str, Callable] = {}


def builtin(name: str):
    """Return a decorator that registers a function in BUILTINS under *name*."""
    def register(fn):
        BUILTINS[name] = fn
        return fn
    return register
|
||||
|
||||
|
||||
@builtin("+")
def add(*args):
    """Variadic addition; (+) with no args yields 0."""
    return sum(args)


@builtin("-")
def sub(a, b=None):
    """(- a b) subtracts; (- a) is unary negation."""
    if b is None:
        return -a
    return a - b


@builtin("*")
def mul(*args):
    """Variadic multiplication; (*) with no args yields 1."""
    result = 1
    for a in args:
        result *= a
    return result


@builtin("/")
def div(a, b):
    """Division; true division per Python semantics (ints yield float)."""
    return a / b


@builtin("mod")
def mod(a, b):
    """Modulo; follows Python's sign-of-divisor convention."""
    return a % b


@builtin("sqrt")
def sqrt(x):
    """Square root (returns float)."""
    return x ** 0.5


@builtin("pow")
def power(x, n):
    """x raised to the n-th power."""
    return x ** n
|
||||
|
||||
|
||||
@builtin("abs")
def absolute(x):
    """Absolute value."""
    return abs(x)


@builtin("floor")
def floor_fn(x):
    """Largest integer <= x (int)."""
    import math
    return math.floor(x)


@builtin("ceil")
def ceil_fn(x):
    """Smallest integer >= x (int)."""
    import math
    return math.ceil(x)


@builtin("round")
def round_fn(x, ndigits=0):
    # NOTE(review): ndigits is always passed through to round(), so
    # round_fn(2.6) returns a float (3.0), not an int — confirm callers
    # expect this.
    """Round x to ndigits decimal places."""
    return round(x, int(ndigits))


@builtin("min")
def min_fn(*args):
    """Minimum of the args, or of a single list/tuple argument."""
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return min(args[0])
    return min(args)


@builtin("max")
def max_fn(*args):
    """Maximum of the args, or of a single list/tuple argument."""
    if len(args) == 1 and isinstance(args[0], (list, tuple)):
        return max(args[0])
    return max(args)


@builtin("clamp")
def clamp(x, lo, hi):
    """Constrain x to the inclusive range [lo, hi]."""
    return max(lo, min(hi, x))
|
||||
|
||||
|
||||
@builtin("=")
def eq(a, b):
    """Value equality (Python ==), not identity."""
    return a == b


@builtin("<")
def lt(a, b):
    """a strictly less than b."""
    return a < b


@builtin(">")
def gt(a, b):
    """a strictly greater than b."""
    return a > b


@builtin("<=")
def lte(a, b):
    """a less than or equal to b."""
    return a <= b


@builtin(">=")
def gte(a, b):
    """a greater than or equal to b."""
    return a >= b
|
||||
|
||||
|
||||
@builtin("odd?")
def is_odd(n):
    # Python's % yields a non-negative result for a positive modulus,
    # so this is correct for negative n as well.
    """True when n is odd."""
    return n % 2 == 1


@builtin("even?")
def is_even(n):
    """True when n is even."""
    return n % 2 == 0


@builtin("zero?")
def is_zero(n):
    """True when n equals 0."""
    return n == 0


@builtin("nil?")
def is_nil(x):
    """True when x is None (nil)."""
    return x is None


@builtin("not")
def not_fn(x):
    """Logical negation (Python truthiness)."""
    return not x


@builtin("inc")
def inc(n):
    """n + 1."""
    return n + 1


@builtin("dec")
def dec(n):
    """n - 1."""
    return n - 1
|
||||
|
||||
|
||||
@builtin("list")
def make_list(*args):
    """Collect the arguments into a Python list."""
    return list(args)


@builtin("assert")
def assert_true(condition, message="Assertion failed"):
    # NOTE(review): raises RuntimeError rather than EvalError, unlike the
    # other builtins in this module — confirm intended.
    """Raise unless condition is truthy; returns True on success."""
    if not condition:
        raise RuntimeError(f"Assertion error: {message}")
    return True
|
||||
|
||||
|
||||
@builtin("get")
def get(coll, key, default=None):
    """
    Look up `key` in a dict (with a Keyword-name fallback) or index a list.

    `default` is returned only when the key/index is truly absent. A key
    that is present with a stored value of None returns None (previously a
    stored None was indistinguishable from a missing key, so the default
    was returned instead — lisp `get` semantics say stored nil wins).

    Raises:
        EvalError: if `coll` is neither a dict nor a list.
    """
    if isinstance(coll, dict):
        # Sentinel distinguishes "missing key" from a stored None value.
        missing = object()
        result = coll.get(key, missing)
        if result is not missing:
            return result
        # If key is a Keyword, also try its string name (for Python dicts
        # with plain string keys).
        if isinstance(key, Keyword):
            result = coll.get(key.name, missing)
            if result is not missing:
                return result
        return default
    elif isinstance(coll, list):
        return coll[key] if 0 <= key < len(coll) else default
    else:
        raise EvalError(f"get: expected dict or list, got {type(coll).__name__}: {str(coll)[:100]}")
|
||||
|
||||
|
||||
@builtin("dict?")
def is_dict(x):
    """True when x is a dict."""
    return isinstance(x, dict)


@builtin("list?")
def is_list(x):
    """True when x is a list."""
    return isinstance(x, list)


@builtin("nil?")
def is_nil(x):
    # NOTE(review): "nil?" is registered twice in this module (an earlier,
    # identical definition exists above); this one overwrites it with the
    # same behavior. Consider removing the duplicate.
    """True when x is None (nil)."""
    return x is None


@builtin("number?")
def is_number(x):
    """True when x is an int or float."""
    return isinstance(x, (int, float))


@builtin("string?")
def is_string(x):
    """True when x is a str."""
    return isinstance(x, str)


@builtin("len")
def length(coll):
    """Length of a collection."""
    return len(coll)
|
||||
|
||||
|
||||
@builtin("first")
def first(coll):
    """First element, or None when the collection is empty/falsy."""
    return coll[0] if coll else None


@builtin("last")
def last(coll):
    """Last element, or None when the collection is empty/falsy."""
    return coll[-1] if coll else None


@builtin("chunk-every")
def chunk_every(coll, n):
    """Split collection into chunks of n elements (last may be shorter)."""
    n = int(n)
    return [coll[i:i+n] for i in range(0, len(coll), n)]


@builtin("rest")
def rest(coll):
    """All elements after the first; [] for an empty/falsy collection."""
    return coll[1:] if coll else []


@builtin("nth")
def nth(coll, n):
    # Out-of-range (including negative) indices yield None, not an error.
    """Element at index n, or None when n is out of range."""
    return coll[n] if 0 <= n < len(coll) else None
|
||||
|
||||
|
||||
@builtin("concat")
def concat(*colls):
    """Concatenate multiple lists/sequences; None arguments are skipped."""
    result = []
    for c in colls:
        if c is not None:
            result.extend(c)
    return result


@builtin("cons")
def cons(x, coll):
    """Prepend x to collection (returns a new list)."""
    return [x] + list(coll) if coll else [x]


@builtin("append")
def append(coll, x):
    """Append x to collection (returns a new list)."""
    return list(coll) + [x] if coll else [x]


@builtin("range")
def make_range(start, end, step=1):
    """Create a list of numbers [start, end) with the given step."""
    return list(range(int(start), int(end), int(step)))


@builtin("zip-pairs")
def zip_pairs(coll):
    """Zip consecutive pairs: [a,b,c,d] -> [[a,b],[b,c],[c,d]]."""
    if not coll or len(coll) < 2:
        return []
    return [[coll[i], coll[i+1]] for i in range(len(coll)-1)]
|
||||
|
||||
|
||||
@builtin("dict")
|
||||
def make_dict(*pairs):
|
||||
"""Create dict from key-value pairs: (dict :a 1 :b 2)."""
|
||||
result = {}
|
||||
i = 0
|
||||
while i < len(pairs) - 1:
|
||||
key = pairs[i]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = pairs[i + 1]
|
||||
i += 2
|
||||
return result
|
||||
|
||||
|
||||
@builtin("keys")
|
||||
def keys(d):
|
||||
"""Get the keys of a dict as a list."""
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"keys: expected dict, got {type(d).__name__}")
|
||||
return list(d.keys())
|
||||
|
||||
|
||||
@builtin("vals")
|
||||
def vals(d):
|
||||
"""Get the values of a dict as a list."""
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"vals: expected dict, got {type(d).__name__}")
|
||||
return list(d.values())
|
||||
|
||||
|
||||
@builtin("merge")
|
||||
def merge(*dicts):
|
||||
"""Merge multiple dicts, later dicts override earlier."""
|
||||
result = {}
|
||||
for d in dicts:
|
||||
if d is not None:
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"merge: expected dict, got {type(d).__name__}")
|
||||
result.update(d)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("assoc")
|
||||
def assoc(d, *pairs):
|
||||
"""Associate keys with values in a dict: (assoc d :a 1 :b 2)."""
|
||||
if d is None:
|
||||
result = {}
|
||||
elif isinstance(d, dict):
|
||||
result = dict(d)
|
||||
else:
|
||||
raise EvalError(f"assoc: expected dict or nil, got {type(d).__name__}")
|
||||
|
||||
i = 0
|
||||
while i < len(pairs) - 1:
|
||||
key = pairs[i]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = pairs[i + 1]
|
||||
i += 2
|
||||
return result
|
||||
|
||||
|
||||
@builtin("dissoc")
|
||||
def dissoc(d, *keys_to_remove):
|
||||
"""Remove keys from a dict: (dissoc d :a :b)."""
|
||||
if d is None:
|
||||
return {}
|
||||
if not isinstance(d, dict):
|
||||
raise EvalError(f"dissoc: expected dict or nil, got {type(d).__name__}")
|
||||
|
||||
result = dict(d)
|
||||
for key in keys_to_remove:
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result.pop(key, None)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("into")
|
||||
def into(target, coll):
|
||||
"""Convert a collection into another collection type.
|
||||
|
||||
(into [] {:a 1 :b 2}) -> [["a" 1] ["b" 2]]
|
||||
(into {} [[:a 1] [:b 2]]) -> {"a": 1, "b": 2}
|
||||
(into [] [1 2 3]) -> [1 2 3]
|
||||
"""
|
||||
if isinstance(target, list):
|
||||
if isinstance(coll, dict):
|
||||
return [[k, v] for k, v in coll.items()]
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
return list(coll)
|
||||
else:
|
||||
raise EvalError(f"into: cannot convert {type(coll).__name__} into list")
|
||||
elif isinstance(target, dict):
|
||||
if isinstance(coll, dict):
|
||||
return dict(coll)
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
result = {}
|
||||
for item in coll:
|
||||
if isinstance(item, (list, tuple)) and len(item) >= 2:
|
||||
key = item[0]
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
result[key] = item[1]
|
||||
else:
|
||||
raise EvalError(f"into: expected [key value] pairs, got {item}")
|
||||
return result
|
||||
else:
|
||||
raise EvalError(f"into: cannot convert {type(coll).__name__} into dict")
|
||||
else:
|
||||
raise EvalError(f"into: unsupported target type {type(target).__name__}")
|
||||
|
||||
|
||||
@builtin("filter")
|
||||
def filter_fn(pred, coll):
|
||||
"""Filter collection by predicate. Pred must be a lambda."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"filter: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
result = []
|
||||
for item in coll:
|
||||
# Evaluate predicate with item
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
# Inline evaluation of pred.body
|
||||
from . import evaluator
|
||||
if evaluator.evaluate(pred.body, local_env):
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
|
||||
@builtin("some")
|
||||
def some(pred, coll):
|
||||
"""Return first truthy value of (pred item) for items in coll, or nil."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"some: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
from . import evaluator
|
||||
result = evaluator.evaluate(pred.body, local_env)
|
||||
if result:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
@builtin("every?")
|
||||
def every(pred, coll):
|
||||
"""Return true if (pred item) is truthy for all items in coll."""
|
||||
if not isinstance(pred, Lambda):
|
||||
raise EvalError(f"every?: expected lambda as predicate, got {type(pred).__name__}")
|
||||
|
||||
for item in coll:
|
||||
local_env = {}
|
||||
if pred.closure:
|
||||
local_env.update(pred.closure)
|
||||
local_env[pred.params[0]] = item
|
||||
|
||||
from . import evaluator
|
||||
if not evaluator.evaluate(pred.body, local_env):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@builtin("empty?")
|
||||
def is_empty(coll):
|
||||
"""Return true if collection is empty."""
|
||||
if coll is None:
|
||||
return True
|
||||
return len(coll) == 0
|
||||
|
||||
|
||||
@builtin("contains?")
|
||||
def contains(coll, key):
|
||||
"""Check if collection contains key (for dicts) or element (for lists)."""
|
||||
if isinstance(coll, dict):
|
||||
if isinstance(key, Keyword):
|
||||
key = key.name
|
||||
return key in coll
|
||||
elif isinstance(coll, (list, tuple)):
|
||||
return key in coll
|
||||
return False
|
||||
|
||||
|
||||
def evaluate(expr: Any, env: Dict[str, Any] = None) -> Any:
    """
    Evaluate an S-expression in the given environment.

    Dispatch order: self-evaluating literals; Symbol lookup (env shadows
    BUILTINS, which shadow the constants true/false/nil); Keyword, Lambda
    and Binding values (returned as-is); dict literals (values evaluated);
    and lists, which are either special forms (if, cond, case, and, or,
    let/let*, lambda/fn, quote, bind, vec/vector, map, map-indexed,
    reduce, for-each), plain data lists, or function calls.

    Args:
        expr: The expression to evaluate
        env: Variable bindings (name -> value)

    Returns:
        The result of evaluation

    Raises:
        EvalError: for undefined symbols, malformed special forms, or a
            call head that is neither callable nor a Lambda.
    """
    if env is None:
        env = {}

    # Literals evaluate to themselves
    if isinstance(expr, (int, float, str, bool)) or expr is None:
        return expr

    # Symbol - variable lookup
    if isinstance(expr, Symbol):
        name = expr.name
        if name in env:
            return env[name]
        if name in BUILTINS:
            return BUILTINS[name]
        if name == "true":
            return True
        if name == "false":
            return False
        if name == "nil":
            return None
        raise EvalError(f"Undefined symbol: {name}")

    # Keyword - return as-is (used as map keys)
    if isinstance(expr, Keyword):
        return expr.name

    # Lambda - return as-is (it's a value)
    if isinstance(expr, Lambda):
        return expr

    # Binding - return as-is (resolved at execution time)
    if isinstance(expr, Binding):
        return expr

    # Dict literal - keys are used verbatim, only values are evaluated
    if isinstance(expr, dict):
        return {k: evaluate(v, env) for k, v in expr.items()}

    # List - function call or special form
    if isinstance(expr, list):
        if not expr:
            return []

        head = expr[0]

        # If head is a string/number/etc (not Symbol), treat as data list
        if not isinstance(head, (Symbol, Lambda, list)):
            return [evaluate(x, env) for x in expr]

        # Special forms
        if isinstance(head, Symbol):
            name = head.name

            # if - conditional; a missing else-branch yields nil
            if name == "if":
                if len(expr) < 3:
                    raise EvalError("if requires condition and then-branch")
                cond_result = evaluate(expr[1], env)
                if cond_result:
                    return evaluate(expr[2], env)
                elif len(expr) > 3:
                    return evaluate(expr[3], env)
                return None

            # cond - multi-way conditional
            # Supports both Clojure style: (cond test1 result1 test2 result2 :else default)
            # and Scheme style: (cond (test1 result1) (test2 result2) (else default))
            if name == "cond":
                clauses = expr[1:]
                # Check if Clojure style (flat list) or Scheme style (nested pairs)
                # Scheme style: first clause is (test result) - exactly 2 elements
                # Clojure style: first clause is a test expression like (= x 0) - 3+ elements
                # Heuristic: a 2-element first clause whose head is a common
                # comparison/boolean operator is still treated as Clojure style.
                first_is_scheme_clause = (
                    clauses and
                    isinstance(clauses[0], list) and
                    len(clauses[0]) == 2 and
                    not (isinstance(clauses[0][0], Symbol) and clauses[0][0].name in ('=', '<', '>', '<=', '>=', '!=', 'not=', 'and', 'or'))
                )
                if first_is_scheme_clause:
                    # Scheme style: ((test result) ...)
                    for clause in clauses:
                        if not isinstance(clause, list) or len(clause) < 2:
                            raise EvalError("cond clause must be (test result)")
                        test = clause[0]
                        # Check for else/default
                        if isinstance(test, Symbol) and test.name in ("else", ":else"):
                            return evaluate(clause[1], env)
                        if isinstance(test, Keyword) and test.name == "else":
                            return evaluate(clause[1], env)
                        if evaluate(test, env):
                            return evaluate(clause[1], env)
                else:
                    # Clojure style: test1 result1 test2 result2 ...
                    i = 0
                    while i < len(clauses) - 1:
                        test = clauses[i]
                        result = clauses[i + 1]
                        # Check for :else
                        if isinstance(test, Keyword) and test.name == "else":
                            return evaluate(result, env)
                        if isinstance(test, Symbol) and test.name == ":else":
                            return evaluate(result, env)
                        if evaluate(test, env):
                            return evaluate(result, env)
                        i += 2
                # No clause matched: cond evaluates to nil
                return None

            # case - switch on value
            # (case expr val1 result1 val2 result2 :else default)
            # Candidate values ARE evaluated before comparison; no match -> nil
            if name == "case":
                if len(expr) < 2:
                    raise EvalError("case requires expression to match")
                match_val = evaluate(expr[1], env)
                clauses = expr[2:]
                i = 0
                while i < len(clauses) - 1:
                    test = clauses[i]
                    result = clauses[i + 1]
                    # Check for :else / else
                    if isinstance(test, Keyword) and test.name == "else":
                        return evaluate(result, env)
                    if isinstance(test, Symbol) and test.name in (":else", "else"):
                        return evaluate(result, env)
                    # Evaluate test value and compare
                    test_val = evaluate(test, env)
                    if match_val == test_val:
                        return evaluate(result, env)
                    i += 2
                return None

            # and - short-circuit; returns last value (or first falsy one);
            # (and) with no args is True
            if name == "and":
                result = True
                for arg in expr[1:]:
                    result = evaluate(arg, env)
                    if not result:
                        return result
                return result

            # or - short-circuit; returns first truthy value (or last value);
            # (or) with no args is False
            if name == "or":
                result = False
                for arg in expr[1:]:
                    result = evaluate(arg, env)
                    if result:
                        return result
                return result

            # let and let* - local bindings (both bind sequentially in this impl)
            if name in ("let", "let*"):
                if len(expr) < 3:
                    raise EvalError(f"{name} requires bindings and body")
                bindings = expr[1]

                local_env = dict(env)

                if isinstance(bindings, list):
                    # Check if it's ((name value) ...) style (Lisp let* style)
                    if bindings and isinstance(bindings[0], list):
                        for binding in bindings:
                            if len(binding) != 2:
                                raise EvalError(f"{name} binding must be (name value)")
                            var_name = binding[0]
                            if isinstance(var_name, Symbol):
                                var_name = var_name.name
                            # Each value sees earlier bindings (sequential)
                            value = evaluate(binding[1], local_env)
                            local_env[var_name] = value
                    # Vector-style [name value ...]
                    elif len(bindings) % 2 == 0:
                        for i in range(0, len(bindings), 2):
                            var_name = bindings[i]
                            if isinstance(var_name, Symbol):
                                var_name = var_name.name
                            value = evaluate(bindings[i + 1], local_env)
                            local_env[var_name] = value
                    else:
                        raise EvalError(f"{name} bindings must be [name value ...] or ((name value) ...)")
                else:
                    raise EvalError(f"{name} bindings must be a list")

                # Body is a single expression
                return evaluate(expr[2], local_env)

            # lambda / fn - create function with closure
            if name in ("lambda", "fn"):
                if len(expr) < 3:
                    raise EvalError("lambda requires params and body")
                params = expr[1]
                if not isinstance(params, list):
                    raise EvalError("lambda params must be a list")
                param_names = []
                for p in params:
                    if isinstance(p, Symbol):
                        param_names.append(p.name)
                    elif isinstance(p, str):
                        param_names.append(p)
                    else:
                        raise EvalError(f"Invalid param: {p}")
                # Capture current environment as closure (shallow copy)
                return Lambda(param_names, expr[2], dict(env))

            # quote - return unevaluated
            if name == "quote":
                return expr[1] if len(expr) > 1 else None

            # bind - create binding to analysis data
            # (bind analysis-var)
            # (bind analysis-var :range [0.3 1.0])
            # (bind analysis-var :range [0 100] :transform sqrt)
            if name == "bind":
                if len(expr) < 2:
                    raise EvalError("bind requires analysis reference")
                analysis_ref = expr[1]
                if isinstance(analysis_ref, Symbol):
                    symbol_name = analysis_ref.name
                    # Look up the symbol in environment
                    if symbol_name in env:
                        resolved = env[symbol_name]
                        # If resolved is actual analysis data (dict with times/values or
                        # S-expression list with Keywords), keep the symbol name as reference
                        # for later lookup at execution time
                        if isinstance(resolved, dict) and ("times" in resolved or "values" in resolved):
                            analysis_ref = symbol_name  # Use name as reference, not the data
                        elif isinstance(resolved, list) and any(isinstance(x, Keyword) for x in resolved):
                            # Parsed S-expression analysis data ([:times [...] :duration ...])
                            analysis_ref = symbol_name
                        else:
                            analysis_ref = resolved
                    else:
                        raise EvalError(f"bind: undefined symbol '{symbol_name}' - must reference analysis data")

                # Parse optional :range [min max] and :transform
                range_min, range_max = 0.0, 1.0
                transform = None
                i = 2
                while i < len(expr):
                    if isinstance(expr[i], Keyword):
                        kw = expr[i].name
                        if kw == "range" and i + 1 < len(expr):
                            range_val = evaluate(expr[i + 1], env)  # Evaluate to get actual value
                            if isinstance(range_val, list) and len(range_val) >= 2:
                                range_min = float(range_val[0])
                                range_max = float(range_val[1])
                            i += 2
                        elif kw == "transform" and i + 1 < len(expr):
                            t = expr[i + 1]
                            # Transform is stored by name, not evaluated here
                            if isinstance(t, Symbol):
                                transform = t.name
                            elif isinstance(t, str):
                                transform = t
                            i += 2
                        else:
                            # Unknown keyword: skipped (its value, if any, is
                            # treated as the next token)
                            i += 1
                    else:
                        i += 1

                return Binding(analysis_ref, range_min=range_min, range_max=range_max, transform=transform)

            # Vector literal [a b c]
            if name == "vec" or name == "vector":
                return [evaluate(e, env) for e in expr[1:]]

            # map - (map fn coll)
            # NOTE: unlike the filter/some builtins, the calling env is
            # overlaid on top of the closure here, so env bindings shadow
            # closure captures of the same name.
            if name == "map":
                if len(expr) != 3:
                    raise EvalError("map requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"map requires lambda, got {type(fn)}")
                result = []
                for item in coll:
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = item
                    result.append(evaluate(fn.body, local_env))
                return result

            # map-indexed - (map-indexed fn coll); lambda receives (i item)
            if name == "map-indexed":
                if len(expr) != 3:
                    raise EvalError("map-indexed requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"map-indexed requires lambda, got {type(fn)}")
                if len(fn.params) < 2:
                    raise EvalError("map-indexed lambda needs (i item) params")
                result = []
                for i, item in enumerate(coll):
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = i
                    local_env[fn.params[1]] = item
                    result.append(evaluate(fn.body, local_env))
                return result

            # reduce - (reduce fn init coll); lambda receives (acc item)
            if name == "reduce":
                if len(expr) != 4:
                    raise EvalError("reduce requires fn, init, and collection")
                fn = evaluate(expr[1], env)
                acc = evaluate(expr[2], env)
                coll = evaluate(expr[3], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"reduce requires lambda, got {type(fn)}")
                if len(fn.params) < 2:
                    raise EvalError("reduce lambda needs (acc item) params")
                for item in coll:
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = acc
                    local_env[fn.params[1]] = item
                    acc = evaluate(fn.body, local_env)
                return acc

            # for-each - (for-each fn coll) - iterate with side effects; returns nil
            if name == "for-each":
                if len(expr) != 3:
                    raise EvalError("for-each requires fn and collection")
                fn = evaluate(expr[1], env)
                coll = evaluate(expr[2], env)
                if not isinstance(fn, Lambda):
                    raise EvalError(f"for-each requires lambda, got {type(fn)}")
                for item in coll:
                    local_env = {}
                    if fn.closure:
                        local_env.update(fn.closure)
                    local_env.update(env)
                    local_env[fn.params[0]] = item
                    evaluate(fn.body, local_env)
                return None

        # Function call - applicative order: head first, then all args
        fn = evaluate(head, env)
        args = [evaluate(arg, env) for arg in expr[1:]]

        # Call builtin
        if callable(fn):
            return fn(*args)

        # Call lambda (strict arity check)
        if isinstance(fn, Lambda):
            if len(args) != len(fn.params):
                raise EvalError(f"Lambda expects {len(fn.params)} args, got {len(args)}")
            # Start with closure (captured env), then overlay calling env, then params
            local_env = {}
            if fn.closure:
                local_env.update(fn.closure)
            local_env.update(env)
            for name, value in zip(fn.params, args):
                local_env[name] = value
            return evaluate(fn.body, local_env)

        raise EvalError(f"Not callable: {fn}")

    raise EvalError(f"Cannot evaluate: {expr!r}")
|
||||
|
||||
|
||||
def make_env(**kwargs) -> Dict[str, Any]:
    """Build an initial evaluation environment from keyword bindings."""
    env: Dict[str, Any] = {}
    env.update(kwargs)
    return env
|
||||
292
core/artdag/sexp/external_tools.py
Normal file
292
core/artdag/sexp/external_tools.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
External tool runners for effects that can't be done in FFmpeg.
|
||||
|
||||
Supports:
|
||||
- datamosh: via ffglitch or datamoshing Python CLI
|
||||
- pixelsort: via Rust pixelsort or Python pixelsort-cli
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
def find_tool(tool_names: List[str]) -> Optional[str]:
    """Return the resolved path of the first candidate found on PATH, or None."""
    for candidate in tool_names:
        located = shutil.which(candidate)
        if located is not None:
            return located
    return None
|
||||
|
||||
|
||||
def check_python_package(package: str) -> bool:
    """Check if a Python package is importable.

    Runs the import in a child interpreter so the package is never loaded
    into this process. Uses ``sys.executable`` rather than a hard-coded
    "python3" so the check targets the interpreter actually running this
    code (correct inside virtualenvs and on Windows, where "python3" may
    be absent or point to a different environment).

    Returns:
        True when the child import succeeds; False on import failure,
        timeout, or a missing interpreter.
    """
    import sys

    try:
        result = subprocess.run(
            [sys.executable, "-c", f"import {package}"],
            capture_output=True,
            timeout=5,
        )
        return result.returncode == 0
    except Exception:
        # Timeout, missing interpreter, etc. — treat as "not installed".
        return False
|
||||
|
||||
|
||||
# Tool detection: candidate executable names probed on PATH, in priority order.
DATAMOSH_TOOLS = ["ffgac", "ffedit"]  # ffglitch tools
PIXELSORT_TOOLS = ["pixelsort"]  # Rust CLI
|
||||
|
||||
|
||||
def get_available_tools() -> Dict[str, Optional[str]]:
    """Map logical tool names to resolved paths/markers (None when absent).

    Keys: "datamosh"/"pixelsort" resolve to executable paths found on
    PATH; "*_python" entries hold the package name when the Python
    fallback is importable.
    """
    tools: Dict[str, Optional[str]] = {}
    tools["datamosh"] = find_tool(DATAMOSH_TOOLS)
    tools["pixelsort"] = find_tool(PIXELSORT_TOOLS)
    tools["datamosh_python"] = "datamoshing" if check_python_package("datamoshing") else None
    tools["pixelsort_python"] = "pixelsort" if check_python_package("pixelsort") else None
    return tools
|
||||
|
||||
|
||||
def run_datamosh(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run datamosh effect using available tool.

    Tries ffglitch (ffgac) with a generated motion-vector-corruption
    script first, then falls back to the Python ``datamoshing`` package.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Effect parameters; only ``corruption`` (default 0.3) is
            used — the probability of corrupting a P-frame's vectors

    Returns:
        (success, error_message)
    """
    tools = get_available_tools()

    corruption = params.get("corruption", 0.3)

    # Try ffglitch first
    if tools.get("datamosh"):
        ffgac = tools["datamosh"]
        script_path = None
        try:
            # ffglitch approach: remove I-frames to create datamosh effect
            # This is a simplified version - full datamosh needs custom scripts
            with tempfile.NamedTemporaryFile(mode='w', suffix='.js', delete=False) as f:
                # Write a simple ffglitch script that corrupts motion vectors
                f.write(f"""
// Datamosh script - corrupt motion vectors
let corruption = {corruption};

export function glitch_frame(frame, stream) {{
    if (frame.pict_type === 'P' && Math.random() < corruption) {{
        // Corrupt motion vectors
        let dominated = frame.mv?.forward?.overflow;
        if (dominated) {{
            for (let i = 0; i < dominated.length; i++) {{
                if (Math.random() < corruption) {{
                    dominated[i] = [
                        Math.floor(Math.random() * 64 - 32),
                        Math.floor(Math.random() * 64 - 32)
                    ];
                }}
            }}
        }}
    }}
    return frame;
}}
""")
                script_path = f.name

            cmd = [
                ffgac,
                "-i", str(input_path),
                "-s", script_path,
                "-o", str(output_path),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]

        except subprocess.TimeoutExpired:
            return False, "Datamosh timeout"
        except Exception as e:
            return False, str(e)
        finally:
            # BUGFIX: the temp script previously leaked when the subprocess
            # timed out or raised (unlink only ran on the success path).
            if script_path:
                Path(script_path).unlink(missing_ok=True)

    # Fall back to Python datamoshing package
    if tools.get("datamosh_python"):
        try:
            cmd = [
                "python3", "-m", "datamoshing",
                str(input_path),
                str(output_path),
                "--mode", "iframe_removal",
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    return False, "No datamosh tool available. Install ffglitch or: pip install datamoshing"
|
||||
|
||||
|
||||
def run_pixelsort(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
) -> Tuple[bool, str]:
    """
    Run pixelsort effect using available tool.

    Tries the Rust ``pixelsort`` CLI first (faster), then falls back to
    the Python pixelsort package.

    Args:
        input_path: Input image/frame file
        output_path: Output image file
        params: Effect parameters (sort_by, threshold_low, angle).
            NOTE(review): ``threshold_high`` was read but never passed to
            either tool, so the dead local has been removed; wire it up
            if a backend gains support for it.

    Returns:
        (success, error_message)
    """
    tools = get_available_tools()

    sort_by = params.get("sort_by", "lightness")
    threshold_low = params.get("threshold_low", 50)
    angle = params.get("angle", 0)

    # Try Rust pixelsort first (faster)
    if tools.get("pixelsort"):
        try:
            cmd = [
                tools["pixelsort"],
                str(input_path),
                "-o", str(output_path),
                "--sort", sort_by,
                "-r", str(angle),
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    # Fall back to Python pixelsort-cli
    if tools.get("pixelsort_python"):
        try:
            cmd = [
                "python3", "-m", "pixelsort",
                "--image_path", str(input_path),
                "--output", str(output_path),
                "--angle", str(angle),
                "--threshold", str(threshold_low / 255),  # Normalize to 0-1
            ]
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
            if result.returncode == 0:
                return True, ""
            return False, result.stderr[:500]
        except Exception as e:
            return False, str(e)

    return False, "No pixelsort tool available. Install: cargo install pixelsort or pip install pixelsort-cli"
|
||||
|
||||
|
||||
def run_pixelsort_video(
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    fps: float = 30,
) -> Tuple[bool, str]:
    """
    Run pixelsort on a video by processing each frame.

    Extracts frames with ffmpeg, pixel-sorts each frame via
    ``run_pixelsort``, then reassembles them into a video, copying the
    original audio track when one exists.

    Args:
        input_path: Input video file
        output_path: Output video file
        params: Per-frame pixelsort parameters (see run_pixelsort)
        fps: Frame rate used for both extraction and reassembly

    Returns:
        (success, error_message)
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        workdir = Path(tmpdir)
        # (removed unused `frames_out` pattern variable from the original)
        frames_in = workdir / "frame_%04d.png"

        # Extract frames
        extract_cmd = [
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-vf", f"fps={fps}",
            str(frames_in),
        ]
        result = subprocess.run(extract_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to extract frames"

        # Process each frame; abort on the first failure
        frame_files = sorted(workdir.glob("frame_*.png"))
        for i, frame in enumerate(frame_files):
            out_frame = workdir / f"sorted_{i:04d}.png"
            success, error = run_pixelsort(frame, out_frame, params)
            if not success:
                return False, f"Frame {i}: {error}"

        # Reassemble; "-map 1:a?" copies audio from the original only when
        # an audio stream exists
        reassemble_cmd = [
            "ffmpeg", "-y",
            "-framerate", str(fps),
            "-i", str(workdir / "sorted_%04d.png"),
            "-i", str(input_path),
            "-map", "0:v", "-map", "1:a?",
            "-c:v", "libx264", "-preset", "fast",
            "-c:a", "copy",
            str(output_path),
        ]
        result = subprocess.run(reassemble_cmd, capture_output=True, timeout=300)
        if result.returncode != 0:
            return False, "Failed to reassemble video"

        return True, ""
|
||||
|
||||
|
||||
def run_external_effect(
    effect_name: str,
    input_path: Path,
    output_path: Path,
    params: Dict[str, Any],
    is_video: bool = True,
) -> Tuple[bool, str]:
    """
    Dispatch an external effect to its runner.

    Args:
        effect_name: Name of effect (datamosh, pixelsort)
        input_path: Input file
        output_path: Output file
        params: Effect parameters
        is_video: Whether input is video (vs single image)

    Returns:
        (success, error_message)
    """
    if effect_name == "datamosh":
        return run_datamosh(input_path, output_path, params)
    if effect_name == "pixelsort":
        runner = run_pixelsort_video if is_video else run_pixelsort
        return runner(input_path, output_path, params)
    return False, f"Unknown external effect: {effect_name}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Print available tools
|
||||
print("Available external tools:")
|
||||
for name, path in get_available_tools().items():
|
||||
status = path if path else "NOT INSTALLED"
|
||||
print(f" {name}: {status}")
|
||||
616
core/artdag/sexp/ffmpeg_compiler.py
Normal file
616
core/artdag/sexp/ffmpeg_compiler.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
FFmpeg filter compiler for sexp effects.
|
||||
|
||||
Compiles sexp effect definitions to FFmpeg filter expressions,
|
||||
with support for dynamic parameters via sendcmd scripts.
|
||||
|
||||
Usage:
|
||||
compiler = FFmpegCompiler()
|
||||
|
||||
# Compile an effect with static params
|
||||
filter_str = compiler.compile_effect("brightness", {"amount": 50})
|
||||
# -> "eq=brightness=0.196"
|
||||
|
||||
# Compile with dynamic binding to analysis data
|
||||
filter_str, sendcmd = compiler.compile_effect_with_binding(
|
||||
"brightness",
|
||||
{"amount": {"_bind": "bass-data", "range_min": 0, "range_max": 100}},
|
||||
analysis_data={"bass-data": {"times": [...], "values": [...]}},
|
||||
segment_start=0.0,
|
||||
segment_duration=5.0,
|
||||
)
|
||||
# -> ("eq=brightness=0.5", "0.0 [eq] brightness 0.5;\n0.05 [eq] brightness 0.6;...")
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
# FFmpeg filter mappings for common effects
# Maps effect name -> {filter: str, params: {param_name: {ffmpeg_param, scale, offset}}}
#
# Entry schema (all keys optional except "filter" and "params"):
#   filter:           FFmpeg filter name, or None when no FFmpeg equivalent exists
#   static:           fixed argument string appended verbatim as "<filter>=<static>"
#   params:           per-parameter translation: sexp value * scale + offset -> ffmpeg_param
#   complex:          True when "static" is a multi-stream filter chain, not plain args
#   time_based:       True when params are compiled to FFmpeg time expressions using 't'
#   external_tool:    name of an external CLI that implements the effect
#   python_primitive: name of a Python frame-processing fallback
EFFECT_MAPPINGS = {
    "invert": {
        "filter": "negate",
        "params": {},
    },
    "grayscale": {
        "filter": "colorchannelmixer",
        "static": "0.3:0.4:0.3:0:0.3:0.4:0.3:0:0.3:0.4:0.3",
        "params": {},
    },
    "sepia": {
        "filter": "colorchannelmixer",
        "static": "0.393:0.769:0.189:0:0.349:0.686:0.168:0:0.272:0.534:0.131",
        "params": {},
    },
    "brightness": {
        "filter": "eq",
        "params": {
            # sexp amount is 0-255; eq brightness expects roughly -1..1
            "amount": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
        },
    },
    "contrast": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
    "saturation": {
        "filter": "eq",
        "params": {
            "amount": {"ffmpeg_param": "saturation", "scale": 1.0, "offset": 0},
        },
    },
    "hue_shift": {
        "filter": "hue",
        "params": {
            "degrees": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "blur": {
        "filter": "gblur",
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "sharpen": {
        "filter": "unsharp",
        "params": {
            "amount": {"ffmpeg_param": "la", "scale": 1.0, "offset": 0},
        },
    },
    "pixelate": {
        # Scale down then up to create pixelation effect
        "filter": "scale",
        "static": "iw/8:ih/8:flags=neighbor,scale=iw*8:ih*8:flags=neighbor",
        "params": {},
    },
    "vignette": {
        "filter": "vignette",
        "params": {
            "strength": {"ffmpeg_param": "a", "scale": 1.0, "offset": 0},
        },
    },
    "noise": {
        "filter": "noise",
        "params": {
            "amount": {"ffmpeg_param": "alls", "scale": 1.0, "offset": 0},
        },
    },
    "flip": {
        "filter": "hflip",  # Default horizontal
        "params": {},
    },
    "mirror": {
        "filter": "hflip",
        "params": {},
    },
    "rotate": {
        "filter": "rotate",
        "params": {
            "angle": {"ffmpeg_param": "a", "scale": math.pi/180, "offset": 0},  # degrees to radians
        },
    },
    "zoom": {
        "filter": "zoompan",
        "params": {
            "factor": {"ffmpeg_param": "z", "scale": 1.0, "offset": 0},
        },
    },
    "posterize": {
        # Use lutyuv to quantize levels (approximate posterization)
        "filter": "lutyuv",
        "static": "y=floor(val/32)*32:u=floor(val/32)*32:v=floor(val/32)*32",
        "params": {},
    },
    "threshold": {
        # Use geq for thresholding
        "filter": "geq",
        "static": "lum='if(gt(lum(X,Y),128),255,0)':cb=128:cr=128",
        "params": {},
    },
    "edge_detect": {
        "filter": "edgedetect",
        "params": {
            "low": {"ffmpeg_param": "low", "scale": 1/255, "offset": 0},
            "high": {"ffmpeg_param": "high", "scale": 1/255, "offset": 0},
        },
    },
    "swirl": {
        "filter": "lenscorrection",  # Approximate with lens distortion
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 0.1, "offset": 0},
        },
    },
    "fisheye": {
        "filter": "lenscorrection",
        "params": {
            "strength": {"ffmpeg_param": "k1", "scale": 1.0, "offset": 0},
        },
    },
    "wave": {
        # Wave displacement using geq - need r/g/b for RGB mode
        "filter": "geq",
        "static": "r='r(X+10*sin(Y/20),Y)':g='g(X+10*sin(Y/20),Y)':b='b(X+10*sin(Y/20),Y)'",
        "params": {},
    },
    "rgb_split": {
        # Chromatic aberration using geq
        "filter": "geq",
        "static": "r='p(X+5,Y)':g='p(X,Y)':b='p(X-5,Y)'",
        "params": {},
    },
    "scanlines": {
        "filter": "drawgrid",
        "params": {
            "spacing": {"ffmpeg_param": "h", "scale": 1.0, "offset": 0},
        },
    },
    "film_grain": {
        "filter": "noise",
        "params": {
            "intensity": {"ffmpeg_param": "alls", "scale": 100, "offset": 0},
        },
    },
    "crt": {
        "filter": "vignette",  # Simplified - just vignette for CRT look
        "params": {},
    },
    "bloom": {
        "filter": "gblur",  # Simplified bloom = blur overlay
        "params": {
            "radius": {"ffmpeg_param": "sigma", "scale": 1.0, "offset": 0},
        },
    },
    "color_cycle": {
        "filter": "hue",
        "params": {
            "speed": {"ffmpeg_param": "h", "scale": 360.0, "offset": 0, "time_expr": True},
        },
        "time_based": True,  # Uses time expression
    },
    "strobe": {
        # Strobe using select to drop frames
        "filter": "select",
        "static": "'mod(n,4)'",
        "params": {},
    },
    "echo": {
        # Echo using tmix
        "filter": "tmix",
        "static": "frames=4:weights='1 0.5 0.25 0.125'",
        "params": {},
    },
    "trails": {
        # Trails using tblend
        "filter": "tblend",
        "static": "all_mode=average",
        "params": {},
    },
    "kaleidoscope": {
        # 4-way mirror kaleidoscope using FFmpeg filter chain
        # Crops top-left quadrant, mirrors horizontally, then vertically
        "filter": "crop",
        "complex": True,
        "static": "iw/2:ih/2:0:0[q];[q]split[q1][q2];[q1]hflip[qr];[q2][qr]hstack[top];[top]split[t1][t2];[t2]vflip[bot];[t1][bot]vstack",
        "params": {},
    },
    "emboss": {
        "filter": "convolution",
        "static": "-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2:-2 -1 0 -1 1 1 0 1 2",
        "params": {},
    },
    "neon_glow": {
        # Edge detect + negate for neon-like effect
        "filter": "edgedetect",
        "static": "mode=colormix:high=0.1",
        "params": {},
    },
    "ascii_art": {
        # Requires Python frame processing - no FFmpeg equivalent
        "filter": None,
        "python_primitive": "ascii_art_frame",
        "params": {
            "char_size": {"default": 8},
            "alphabet": {"default": "standard"},
            "color_mode": {"default": "color"},
        },
    },
    "ascii_zones": {
        # Requires Python frame processing - zone-based ASCII
        "filter": None,
        "python_primitive": "ascii_zones_frame",
        "params": {
            "char_size": {"default": 8},
            "zone_threshold": {"default": 128},
        },
    },
    "datamosh": {
        # External tool: ffglitch or datamoshing CLI, falls back to Python
        "filter": None,
        "external_tool": "datamosh",
        "python_primitive": "datamosh_frame",
        "params": {
            "block_size": {"default": 32},
            "corruption": {"default": 0.3},
        },
    },
    "pixelsort": {
        # External tool: pixelsort CLI (Rust or Python), falls back to Python
        "filter": None,
        "external_tool": "pixelsort",
        "python_primitive": "pixelsort_frame",
        "params": {
            "sort_by": {"default": "lightness"},
            "threshold_low": {"default": 50},
            "threshold_high": {"default": 200},
            "angle": {"default": 0},
        },
    },
    "ripple": {
        # Use geq for ripple displacement
        "filter": "geq",
        "static": "lum='lum(X+5*sin(hypot(X-W/2,Y-H/2)/10),Y+5*cos(hypot(X-W/2,Y-H/2)/10))'",
        "params": {},
    },
    "tile_grid": {
        # Use tile filter for grid
        "filter": "tile",
        "static": "2x2",
        "params": {},
    },
    "outline": {
        "filter": "edgedetect",
        "params": {},
    },
    "color-adjust": {
        "filter": "eq",
        "params": {
            "brightness": {"ffmpeg_param": "brightness", "scale": 1/255, "offset": 0},
            "contrast": {"ffmpeg_param": "contrast", "scale": 1.0, "offset": 0},
        },
    },
}
|
||||
|
||||
|
||||
class FFmpegCompiler:
    """Compiles sexp effects to FFmpeg filters with sendcmd support.

    Effects are looked up in a mapping table (``EFFECT_MAPPINGS`` by
    default).  Each entry names an FFmpeg filter and describes how sexp
    parameter values translate to filter options (value * scale + offset).
    Entries whose ``filter`` is None are implemented via an external tool
    or a Python frame primitive and therefore yield no FFmpeg filter.
    """

    def __init__(self, effect_mappings: Optional[Dict] = None):
        """Create a compiler.

        Args:
            effect_mappings: Optional mapping table keyed by effect name;
                falls back to the module-level EFFECT_MAPPINGS when omitted.
        """
        self.mappings = effect_mappings or EFFECT_MAPPINGS

    def get_full_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get full mapping for an effect (including external tools and python primitives)."""
        mapping = self.mappings.get(effect_name)
        if not mapping:
            # Effect names may arrive hyphenated/spaced/mixed-case; retry
            # with a normalized snake_case key.
            normalized = effect_name.replace("-", "_").replace(" ", "_").lower()
            mapping = self.mappings.get(normalized)
        return mapping

    def get_mapping(self, effect_name: str) -> Optional[Dict]:
        """Get FFmpeg filter mapping for an effect (returns None for non-FFmpeg effects)."""
        mapping = self.get_full_mapping(effect_name)
        # Entries with filter=None are external-tool / Python-only effects.
        if mapping and mapping.get("filter") is None:
            return None
        return mapping

    def has_external_tool(self, effect_name: str) -> Optional[str]:
        """Check if effect uses an external tool. Returns tool name or None."""
        mapping = self.get_full_mapping(effect_name)
        return mapping.get("external_tool") if mapping else None

    def has_python_primitive(self, effect_name: str) -> Optional[str]:
        """Check if effect uses a Python primitive. Returns primitive name or None."""
        mapping = self.get_full_mapping(effect_name)
        return mapping.get("python_primitive") if mapping else None

    def is_complex_filter(self, effect_name: str) -> bool:
        """Check if effect uses a complex (multi-stream) filter chain."""
        mapping = self.get_full_mapping(effect_name)
        return bool(mapping and mapping.get("complex"))

    def compile_effect(
        self,
        effect_name: str,
        params: Dict[str, Any],
    ) -> Optional[str]:
        """
        Compile an effect to an FFmpeg filter string with static params.

        Returns None if effect has no FFmpeg mapping.
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None

        filter_name = mapping["filter"]

        # Static filters carry their full argument string verbatim.
        if "static" in mapping:
            return f"{filter_name}={mapping['static']}"

        if not mapping.get("params"):
            return filter_name

        # Build param string: value * scale + offset per mapping entry.
        filter_params = []
        for param_name, param_config in mapping["params"].items():
            if param_name not in params:
                continue
            value = params[param_name]
            # Bindings (dynamic params) are handled by
            # compile_effect_with_bindings, not here.
            if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                continue
            ffmpeg_param = param_config["ffmpeg_param"]
            scale = param_config.get("scale", 1.0)
            offset = param_config.get("offset", 0)
            if isinstance(value, (int, float)):
                ffmpeg_value = value * scale + offset
                filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
            elif isinstance(value, str):
                # Strings pass through unchanged (e.g. direction names).
                filter_params.append(f"{ffmpeg_param}={value}")
            elif isinstance(value, list) and value and isinstance(value[0], (int, float)):
                # Lists: use the first numeric element.
                ffmpeg_value = value[0] * scale + offset
                filter_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")

        if filter_params:
            return f"{filter_name}={':'.join(filter_params)}"
        return filter_name

    def compile_effect_with_bindings(
        self,
        effect_name: str,
        params: Dict[str, Any],
        analysis_data: Dict[str, Dict],
        segment_start: float,
        segment_duration: float,
        sample_interval: float = 0.04,  # ~25 fps
    ) -> Tuple[Optional[str], Optional[str], List[str]]:
        """
        Compile an effect with dynamic bindings to a filter + sendcmd script.

        Args:
            effect_name: Effect to compile.
            params: Effect parameters; values may be plain scalars or
                binding dicts ({"_bind": name, "range_min": .., "range_max": ..,
                "transform": ..}).
            analysis_data: Analysis curves keyed by name, each with parallel
                "times" and "values" lists.
            segment_start: Segment start time in the source timeline.
            segment_duration: Segment length in seconds.
            sample_interval: Spacing of sendcmd samples.

        Returns:
            (filter_string, sendcmd_script, bound_param_names)
            - filter_string: Initial FFmpeg filter (may have placeholder values)
            - sendcmd_script: Script content for sendcmd filter
            - bound_param_names: List of params that have bindings
        """
        mapping = self.get_mapping(effect_name)
        if not mapping:
            return None, None, []

        filter_name = mapping["filter"]
        static_params = []
        bound_params = []
        sendcmd_lines = []

        # Time-based effects are expressed directly with FFmpeg's 't'
        # variable instead of sampled sendcmd entries.
        if mapping.get("time_based"):
            for param_name, param_config in mapping.get("params", {}).items():
                if param_name not in params:
                    continue
                value = params[param_name]
                ffmpeg_param = param_config["ffmpeg_param"]
                scale = param_config.get("scale", 1.0)
                if isinstance(value, (int, float)):
                    # Create time expression: h='t*speed*scale'
                    static_params.append(f"{ffmpeg_param}='t*{value}*{scale}'")
                else:
                    static_params.append(f"{ffmpeg_param}='t*{scale}'")
            if static_params:
                filter_str = f"{filter_name}={':'.join(static_params)}"
            else:
                filter_str = f"{filter_name}=h='t*360'"  # Default rotation
            return filter_str, None, []

        # Process each param
        for param_name, param_config in mapping.get("params", {}).items():
            if param_name not in params:
                continue

            value = params[param_name]
            ffmpeg_param = param_config["ffmpeg_param"]
            scale = param_config.get("scale", 1.0)
            offset = param_config.get("offset", 0)

            if isinstance(value, dict) and ("_bind" in value or "_binding" in value):
                # Dynamic binding: sample the analysis curve over the segment.
                bind_ref = value.get("_bind") or value.get("_binding")
                range_min = value.get("range_min", 0.0)
                range_max = value.get("range_max", 1.0)
                transform = value.get("transform")

                analysis = analysis_data.get(bind_ref)
                if not analysis:
                    # Try without -data suffix
                    analysis = analysis_data.get(bind_ref.replace("-data", ""))

                if analysis and "times" in analysis and "values" in analysis:
                    times = analysis["times"]
                    values = analysis["values"]

                    # Emit one sendcmd entry per sample_interval across the
                    # segment (times in the script are segment-relative).
                    t = 0.0
                    while t < segment_duration:
                        abs_time = segment_start + t
                        raw_value = self._interpolate_value(times, values, abs_time)

                        # Optional non-linear shaping of the analysis curve;
                        # guards keep sqrt/log inputs in their domains.
                        if transform == "sqrt":
                            raw_value = math.sqrt(max(0, raw_value))
                        elif transform == "pow2":
                            raw_value = raw_value ** 2
                        elif transform == "log":
                            raw_value = math.log(max(0.001, raw_value))

                        # Map analysis value into the requested range, then
                        # into FFmpeg's parameter space.
                        mapped_value = range_min + raw_value * (range_max - range_min)
                        ffmpeg_value = mapped_value * scale + offset

                        sendcmd_lines.append(
                            f"{t:.3f} [{filter_name}] {ffmpeg_param} {ffmpeg_value:.4f};"
                        )
                        t += sample_interval

                    bound_params.append(param_name)
                    # Seed the filter string with the value at segment start.
                    initial_value = self._interpolate_value(times, values, segment_start)
                    initial_mapped = range_min + initial_value * (range_max - range_min)
                    initial_ffmpeg = initial_mapped * scale + offset
                    static_params.append(f"{ffmpeg_param}={initial_ffmpeg:.4f}")
                else:
                    # No analysis data, use range midpoint
                    mid_value = (range_min + range_max) / 2
                    ffmpeg_value = mid_value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
            else:
                # Static value - handle various types
                if isinstance(value, (int, float)):
                    ffmpeg_value = value * scale + offset
                    static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                elif isinstance(value, str):
                    # String value - use as-is (e.g., for direction parameters)
                    static_params.append(f"{ffmpeg_param}={value}")
                elif isinstance(value, list) and value:
                    # List - try to use first numeric element
                    first = value[0]
                    if isinstance(first, (int, float)):
                        ffmpeg_value = first * scale + offset
                        static_params.append(f"{ffmpeg_param}={ffmpeg_value:.4f}")
                # Skip other types

        # Handle static filters
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        elif static_params:
            filter_str = f"{filter_name}={':'.join(static_params)}"
        else:
            filter_str = filter_name

        sendcmd_script = "\n".join(sendcmd_lines) if sendcmd_lines else None
        return filter_str, sendcmd_script, bound_params

    def _interpolate_value(
        self,
        times: List[float],
        values: List[float],
        target_time: float,
    ) -> float:
        """Linearly interpolate a value from analysis data at a given time.

        Returns 0.5 when no data is available; clamps to the endpoints
        outside the sampled time range.
        """
        if not times or not values:
            return 0.5

        # Clamp outside the sampled range.
        if target_time <= times[0]:
            return values[0]
        if target_time >= times[-1]:
            return values[-1]

        # Binary search for the surrounding sample pair.
        lo, hi = 0, len(times) - 1
        while lo < hi - 1:
            mid = (lo + hi) // 2
            if times[mid] <= target_time:
                lo = mid
            else:
                hi = mid

        # Linear interpolation between the two samples.
        t0, t1 = times[lo], times[hi]
        v0, v1 = values[lo], values[hi]

        if t1 == t0:
            return v0

        alpha = (target_time - t0) / (t1 - t0)
        return v0 + alpha * (v1 - v0)
|
||||
|
||||
|
||||
def generate_sendcmd_filter(
    effects: List[Dict],
    analysis_data: Dict[str, Dict],
    segment_start: float,
    segment_duration: float,
) -> Tuple[str, Optional[Path]]:
    """
    Generate FFmpeg filter chain with sendcmd for dynamic effects.

    Args:
        effects: List of effect configs with name and params
        analysis_data: Analysis data keyed by name
        segment_start: Segment start time in source
        segment_duration: Segment duration

    Returns:
        (filter_chain_string, sendcmd_file_path or None)
    """
    compiler = FFmpegCompiler()
    filters = []
    # Collected but currently unused: retained for when sendcmd is
    # re-enabled (see NOTE below).
    all_sendcmd_lines = []

    for effect in effects:
        effect_name = effect.get("effect")
        # Everything except the effect name is treated as a parameter.
        params = {k: v for k, v in effect.items() if k != "effect"}

        filter_str, sendcmd, _ = compiler.compile_effect_with_bindings(
            effect_name,
            params,
            analysis_data,
            segment_start,
            segment_duration,
        )

        if filter_str:
            filters.append(filter_str)
        if sendcmd:
            all_sendcmd_lines.append(sendcmd)

    if not filters:
        return "", None

    filter_chain = ",".join(filters)

    # NOTE: sendcmd disabled - FFmpeg's sendcmd filter has compatibility issues.
    # Bindings use their initial value (sampled at segment start time).
    # This is acceptable since each segment is only ~8 seconds.
    # The binding value is still music-reactive (varies per segment).
    sendcmd_path = None

    return filter_chain, sendcmd_path
|
||||
425
core/artdag/sexp/parser.py
Normal file
425
core/artdag/sexp/parser.py
Normal file
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
S-expression parser for ArtDAG recipes and plans.
|
||||
|
||||
Supports:
|
||||
- Lists: (a b c)
|
||||
- Symbols: foo, bar-baz, ->
|
||||
- Keywords: :key
|
||||
- Strings: "hello world"
|
||||
- Numbers: 42, 3.14, -1.5
|
||||
- Comments: ; to end of line
|
||||
- Vectors: [a b c] (syntactic sugar for lists)
|
||||
- Maps: {:key1 val1 :key2 val2} (parsed as Python dicts)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Union
|
||||
import re
|
||||
|
||||
|
||||
@dataclass
class Symbol:
    """An unquoted identifier; compares equal to its bare name string."""
    name: str

    def __repr__(self):
        return f"Symbol({self.name!r})"

    def __eq__(self, other):
        # A Symbol matches both other Symbols and plain strings with the
        # same name; every other type is unequal.
        if isinstance(other, Symbol):
            other = other.name
        if isinstance(other, str):
            return self.name == other
        return False

    def __hash__(self):
        # Hash like the bare string so Symbol("x") and "x" are
        # interchangeable as dict/set keys.
        return hash(self.name)
|
||||
|
||||
|
||||
@dataclass
class Keyword:
    """A ``:name`` keyword token."""
    name: str

    def __repr__(self):
        return f"Keyword({self.name!r})"

    def __eq__(self, other):
        # Unlike Symbol, a Keyword never compares equal to a bare string.
        return isinstance(other, Keyword) and self.name == other.name

    def __hash__(self):
        # Tag the hash with ':' so a Keyword never shares a hash bucket
        # by construction with the equally-named Symbol/str.
        return hash((':', self.name))
|
||||
|
||||
|
||||
@dataclass
class Lambda:
    """A lambda/anonymous function with closure.

    Produced by the evaluator for ``(fn [args] body)`` forms; the closure
    dict, when present, captures the environment at definition time.
    """
    params: List[str]  # Parameter names
    body: Any  # Expression body
    closure: Dict = None  # Captured environment (optional for backwards compat)

    def __repr__(self):
        # Closure is intentionally omitted from the repr: it can be large
        # and self-referential.
        return f"Lambda({self.params}, {self.body!r})"
|
||||
|
||||
|
||||
@dataclass
class Binding:
    """Links an effect parameter to analysis data so it can vary over time."""
    analysis_ref: str          # Name of analysis variable
    track: str = None          # Optional track name (e.g., "bass", "energy")
    range_min: float = 0.0     # Output range minimum
    range_max: float = 1.0     # Output range maximum
    transform: str = None      # Optional transform: "sqrt", "pow2", "log", etc.

    def __repr__(self):
        # Only mention the transform when one is set (truthy).
        transform_part = f", transform={self.transform!r}" if self.transform else ""
        return (
            f"Binding({self.analysis_ref!r}, track={self.track!r}, "
            f"range=[{self.range_min}, {self.range_max}]{transform_part})"
        )
|
||||
|
||||
|
||||
class ParseError(Exception):
    """Error during S-expression parsing.

    Carries the source location (character offset plus 1-based line and
    column) so callers can point at the offending input.
    """
    def __init__(self, message: str, position: int = 0, line: int = 1, col: int = 1):
        # Keep location details accessible as attributes for error reporting.
        self.position = position
        self.line = line
        self.col = col
        super().__init__(f"{message} at line {line}, column {col}")
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""Tokenize S-expression text into tokens."""
|
||||
|
||||
# Token patterns
|
||||
WHITESPACE = re.compile(r'\s+')
|
||||
COMMENT = re.compile(r';[^\n]*')
|
||||
STRING = re.compile(r'"(?:[^"\\]|\\.)*"')
|
||||
NUMBER = re.compile(r'-?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
|
||||
KEYWORD = re.compile(r':[a-zA-Z_][a-zA-Z0-9_-]*')
|
||||
SYMBOL = re.compile(r'[a-zA-Z_*+\-><=/!?][a-zA-Z0-9_*+\-><=/!?.:]*')
|
||||
|
||||
def __init__(self, text: str):
|
||||
self.text = text
|
||||
self.pos = 0
|
||||
self.line = 1
|
||||
self.col = 1
|
||||
|
||||
def _advance(self, count: int = 1):
|
||||
"""Advance position, tracking line/column."""
|
||||
for _ in range(count):
|
||||
if self.pos < len(self.text):
|
||||
if self.text[self.pos] == '\n':
|
||||
self.line += 1
|
||||
self.col = 1
|
||||
else:
|
||||
self.col += 1
|
||||
self.pos += 1
|
||||
|
||||
def _skip_whitespace_and_comments(self):
|
||||
"""Skip whitespace and comments."""
|
||||
while self.pos < len(self.text):
|
||||
# Whitespace
|
||||
match = self.WHITESPACE.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
continue
|
||||
|
||||
# Comments
|
||||
match = self.COMMENT.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
def peek(self) -> str | None:
|
||||
"""Peek at current character."""
|
||||
self._skip_whitespace_and_comments()
|
||||
if self.pos >= len(self.text):
|
||||
return None
|
||||
return self.text[self.pos]
|
||||
|
||||
def next_token(self) -> Any:
|
||||
"""Get the next token."""
|
||||
self._skip_whitespace_and_comments()
|
||||
|
||||
if self.pos >= len(self.text):
|
||||
return None
|
||||
|
||||
char = self.text[self.pos]
|
||||
start_line, start_col = self.line, self.col
|
||||
|
||||
# Single-character tokens (parens, brackets, braces)
|
||||
if char in '()[]{}':
|
||||
self._advance()
|
||||
return char
|
||||
|
||||
# String
|
||||
if char == '"':
|
||||
match = self.STRING.match(self.text, self.pos)
|
||||
if not match:
|
||||
raise ParseError("Unterminated string", self.pos, self.line, self.col)
|
||||
self._advance(match.end() - self.pos)
|
||||
# Parse escape sequences
|
||||
content = match.group()[1:-1]
|
||||
content = content.replace('\\n', '\n')
|
||||
content = content.replace('\\t', '\t')
|
||||
content = content.replace('\\"', '"')
|
||||
content = content.replace('\\\\', '\\')
|
||||
return content
|
||||
|
||||
# Keyword
|
||||
if char == ':':
|
||||
match = self.KEYWORD.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
return Keyword(match.group()[1:]) # Strip leading colon
|
||||
raise ParseError(f"Invalid keyword", self.pos, self.line, self.col)
|
||||
|
||||
# Number (must check before symbol due to - prefix)
|
||||
if char.isdigit() or (char == '-' and self.pos + 1 < len(self.text) and
|
||||
(self.text[self.pos + 1].isdigit() or self.text[self.pos + 1] == '.')):
|
||||
match = self.NUMBER.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
num_str = match.group()
|
||||
if '.' in num_str or 'e' in num_str or 'E' in num_str:
|
||||
return float(num_str)
|
||||
return int(num_str)
|
||||
|
||||
# Symbol
|
||||
match = self.SYMBOL.match(self.text, self.pos)
|
||||
if match:
|
||||
self._advance(match.end() - self.pos)
|
||||
return Symbol(match.group())
|
||||
|
||||
raise ParseError(f"Unexpected character: {char!r}", self.pos, self.line, self.col)
|
||||
|
||||
|
||||
def parse(text: str) -> Any:
    """
    Parse a single S-expression string into Python data structures.

    Returns:
        Parsed S-expression as nested Python structures:
        - Lists become Python lists
        - Symbols become Symbol objects
        - Keywords become Keyword objects
        - Strings become Python strings
        - Numbers become int/float

    Example:
        >>> parse('(recipe "test" :version "1.0")')
        [Symbol('recipe'), 'test', Keyword('version'), '1.0']
    """
    tok = Tokenizer(text)
    expr = _parse_expr(tok)

    # Exactly one expression is allowed; anything left over is an error.
    if tok.peek() is None:
        return expr
    raise ParseError("Unexpected content after expression",
                     tok.pos, tok.line, tok.col)
|
||||
|
||||
|
||||
def parse_all(text: str) -> List[Any]:
    """
    Parse multiple S-expressions from a string.

    Returns list of parsed expressions.
    """
    tok = Tokenizer(text)
    exprs: List[Any] = []
    # Keep consuming expressions until the input is exhausted.
    while tok.peek() is not None:
        exprs.append(_parse_expr(tok))
    return exprs
|
||||
|
||||
|
||||
def _parse_expr(tokenizer: Tokenizer) -> Any:
    """Parse one expression (atom or compound) from the token stream."""
    token = tokenizer.next_token()

    if token is None:
        raise ParseError("Unexpected end of input", tokenizer.pos, tokenizer.line, tokenizer.col)

    # Compound openers: lists, vectors (list sugar), and maps.
    if token in ('(', '['):
        return _parse_list(tokenizer, ')' if token == '(' else ']')
    if token == '{':
        return _parse_map(tokenizer)

    # A closer with no matching opener is malformed input.
    if token in (')', ']', '}'):
        raise ParseError(f"Unexpected {token!r}", tokenizer.pos, tokenizer.line, tokenizer.col)

    # Anything else is already a finished atom (number, string, Symbol, Keyword).
    return token
|
||||
|
||||
|
||||
def _parse_list(tokenizer: Tokenizer, closer: str) -> List[Any]:
    """Collect expressions until the given closing delimiter is reached."""
    elements: List[Any] = []

    while tokenizer.peek() != closer:
        # Running out of input before the closer means an unbalanced form.
        if tokenizer.peek() is None:
            raise ParseError(f"Unterminated list, expected {closer!r}",
                             tokenizer.pos, tokenizer.line, tokenizer.col)
        elements.append(_parse_expr(tokenizer))

    tokenizer.next_token()  # Consume closer
    return elements
|
||||
|
||||
|
||||
def _parse_map(tokenizer: Tokenizer) -> Dict[str, Any]:
    """Parse a map/dict: {:key1 val1 :key2 val2} -> {"key1": val1, "key2": val2}."""
    mapping: Dict[str, Any] = {}

    while True:
        lookahead = tokenizer.peek()

        if lookahead is None:
            raise ParseError("Unterminated map, expected '}'",
                             tokenizer.pos, tokenizer.line, tokenizer.col)

        if lookahead == '}':
            tokenizer.next_token()  # Consume closer
            return mapping

        # Keys must be keywords (:key) or plain strings; both normalize
        # to string dict keys.
        key_expr = _parse_expr(tokenizer)
        if isinstance(key_expr, Keyword):
            key = key_expr.name
        elif isinstance(key_expr, str):
            key = key_expr
        else:
            raise ParseError(f"Map key must be keyword or string, got {type(key_expr).__name__}",
                             tokenizer.pos, tokenizer.line, tokenizer.col)

        # The expression following the key is its value.
        mapping[key] = _parse_expr(tokenizer)
|
||||
|
||||
|
||||
def serialize(expr: Any, indent: int = 0, pretty: bool = False) -> str:
    """
    Serialize a Python data structure back to S-expression format.

    Handles lists, Symbol/Keyword/Lambda/Binding objects, strings (with
    escaping), bools, numbers, None (as ``nil``) and dicts (as property
    lists).  Type checks are order-sensitive: bool must be tested before
    int because bool is a subclass of int.

    Args:
        expr: The expression to serialize
        indent: Current indentation level (for pretty printing)
        pretty: Whether to use pretty printing with newlines

    Returns:
        S-expression string

    Raises:
        ValueError: if *expr* is of a type with no S-expression form.
    """
    if isinstance(expr, list):
        if not expr:
            return "()"

        if pretty:
            return _serialize_pretty(expr, indent)
        else:
            items = [serialize(item, indent, False) for item in expr]
            return "(" + " ".join(items) + ")"

    if isinstance(expr, Symbol):
        # Symbols are emitted bare (unquoted).
        return expr.name

    if isinstance(expr, Keyword):
        return f":{expr.name}"

    if isinstance(expr, Lambda):
        params = " ".join(expr.params)
        body = serialize(expr.body, indent, pretty)
        return f"(fn [{params}] {body})"

    if isinstance(expr, Binding):
        # analysis_ref can be a string, node ID, or dict - serialize it properly
        if isinstance(expr.analysis_ref, str):
            ref_str = f'"{expr.analysis_ref}"'
        else:
            ref_str = serialize(expr.analysis_ref, indent, pretty)
        s = f"(bind {ref_str} :range [{expr.range_min} {expr.range_max}]"
        if expr.transform:
            s += f" :transform {expr.transform}"
        return s + ")"

    if isinstance(expr, str):
        # Escape special characters (backslash first so later escapes
        # are not double-processed).
        escaped = expr.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n').replace('\t', '\\t')
        return f'"{escaped}"'

    if isinstance(expr, bool):
        # Must precede the (int, float) check: bool is a subclass of int.
        return "true" if expr else "false"

    if isinstance(expr, (int, float)):
        return str(expr)

    if expr is None:
        return "nil"

    if isinstance(expr, dict):
        # Serialize dict as property list: {:key1 val1 :key2 val2}
        items = []
        for k, v in expr.items():
            items.append(f":{k}")
            items.append(serialize(v, indent, pretty))
        return "{" + " ".join(items) + "}"

    raise ValueError(f"Cannot serialize {type(expr).__name__}: {expr!r}")
|
||||
|
||||
|
||||
def _serialize_pretty(expr: List, indent: int) -> str:
    """Pretty-print a list expression with smart formatting.

    Short expressions stay on one line; longer ones are broken across
    lines with :keyword/value pairs grouped together.

    Args:
        expr: The list expression to format.
        indent: Current nesting depth (one space of indent per level).

    Returns:
        Formatted S-expression string.
    """
    if not expr:
        return "()"

    # Fixed: removed the unused `prefix` local; only the inner prefix is used.
    inner_prefix = " " * (indent + 1)

    # Fall back to the compact form when it comfortably fits on one line.
    simple = serialize(expr, indent, False)
    if len(simple) < 60 and '\n' not in simple:
        return simple

    # Start building multiline output with the head symbol on the first line.
    head = serialize(expr[0], indent + 1, False)
    parts = [f"({head}"]

    i = 1
    while i < len(expr):
        item = expr[i]

        # Group keyword-value pairs on the same line where possible.
        if isinstance(item, Keyword) and i + 1 < len(expr):
            key = serialize(item, 0, False)
            val = serialize(expr[i + 1], indent + 1, False)

            if len(val) < 50 and '\n' not in val:
                parts.append(f"{inner_prefix}{key} {val}")
            else:
                # Value is complex; recurse with pretty printing.
                val_pretty = serialize(expr[i + 1], indent + 1, True)
                parts.append(f"{inner_prefix}{key} {val_pretty}")
            i += 2
        else:
            # Regular (non-keyword-pair) item.
            item_str = serialize(item, indent + 1, True)
            parts.append(f"{inner_prefix}{item_str}")
            i += 1

    return "\n".join(parts) + ")"
|
||||
2187
core/artdag/sexp/planner.py
Normal file
2187
core/artdag/sexp/planner.py
Normal file
File diff suppressed because it is too large
Load Diff
620
core/artdag/sexp/primitives.py
Normal file
620
core/artdag/sexp/primitives.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""
|
||||
Frame processing primitives for sexp effects.
|
||||
|
||||
These primitives are called by sexp effect definitions and operate on
|
||||
numpy arrays (frames). They're used when falling back to Python rendering
|
||||
instead of FFmpeg.
|
||||
|
||||
Required: numpy, PIL
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
HAS_NUMPY = True
|
||||
except ImportError:
|
||||
HAS_NUMPY = False
|
||||
np = None
|
||||
|
||||
try:
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
HAS_PIL = True
|
||||
except ImportError:
|
||||
HAS_PIL = False
|
||||
|
||||
|
||||
# ASCII character sets for different styles
|
||||
# ASCII character sets for different styles.
# Each string is ordered darkest -> brightest; luminance_to_chars indexes
# into it proportionally to cell brightness (index 0 = darkest cell).
ASCII_ALPHABETS = {
    "standard": " .:-=+*#%@",
    "blocks": " ░▒▓█",
    "simple": " .-:+=xX#",
    "detailed": " .'`^\",:;Il!i><~+_-?][}{1)(|\\/tfjrxnuvczXYUJCLQ0OZmwqpdbkhao*#MW&8%B@$",
    "binary": " █",
}
|
||||
|
||||
|
||||
def check_deps():
    """Raise ImportError unless both numpy and PIL are importable."""
    for available, message in (
        (HAS_NUMPY, "numpy required for frame primitives: pip install numpy"),
        (HAS_PIL, "PIL required for frame primitives: pip install Pillow"),
    ):
        if not available:
            raise ImportError(message)
|
||||
|
||||
|
||||
def frame_to_image(frame: np.ndarray) -> Image.Image:
    """Convert a numpy frame to a PIL Image, clamping non-uint8 data first."""
    check_deps()
    if frame.dtype == np.uint8:
        return Image.fromarray(frame)
    # Clamp to the displayable range before the lossy uint8 cast.
    return Image.fromarray(np.clip(frame, 0, 255).astype(np.uint8))
|
||||
|
||||
|
||||
def image_to_frame(img: Image.Image) -> np.ndarray:
    """Convert a PIL Image back into a numpy frame."""
    check_deps()
    frame = np.array(img)
    return frame
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Art Primitives
|
||||
# ============================================================================
|
||||
|
||||
def cell_sample(frame: np.ndarray, cell_size: int = 8) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sample frame into cells, returning average colors and luminances.

    Vectorized: the frame is cropped to a whole number of cells and viewed
    as a (rows, cell, cols, cell, C) grid so every cell is averaged in one
    numpy pass instead of a Python double loop. The math is identical to
    the per-cell reference (uint8 sums are exact in float64).

    Args:
        frame: Input frame (H, W, C) with C >= 3.
        cell_size: Edge length of each square cell, in pixels.

    Returns:
        (colors, luminances) - colors is (rows, cols, 3) float32,
        luminances is (rows, cols) float32 (ITU-R BT.601 weights).
    """
    check_deps()
    h, w = frame.shape[:2]
    rows = h // cell_size
    cols = w // cell_size
    channels = frame.shape[2]

    # Drop partial cells at the right/bottom edges, then expose the cell grid.
    cells = frame[:rows * cell_size, :cols * cell_size].reshape(
        rows, cell_size, cols, cell_size, channels
    )
    means = cells.mean(axis=(1, 3))  # (rows, cols, C) in float64

    colors = means[:, :, :3].astype(np.float32)

    # Luminance (ITU-R BT.601), computed from the float64 means *before*
    # down-casting so results match the loop-based reference exactly.
    lum = 0.299 * means[:, :, 0] + 0.587 * means[:, :, 1] + 0.114 * means[:, :, 2]
    luminances = lum.astype(np.float32)

    return colors, luminances
|
||||
|
||||
|
||||
def luminance_to_chars(
    luminances: np.ndarray,
    alphabet: str = "standard",
    contrast: float = 1.0,
) -> List[List[str]]:
    """
    Map luminance values (0-255) onto ASCII characters.

    Args:
        luminances: 2D array of luminance values (0-255).
        alphabet: Name of a built-in character set, or a custom string.
        contrast: Contrast multiplier applied around the 128 midpoint.

    Returns:
        2D list of characters, one per cell.
    """
    check_deps()
    palette = ASCII_ALPHABETS.get(alphabet, alphabet)
    count = len(palette)

    def pick(value):
        # Stretch around the midpoint, clamp, then bucket into the palette.
        adjusted = np.clip(128 + (value - 128) * contrast, 0, 255)
        return palette[min(int(adjusted / 256 * count), count - 1)]

    return [[pick(value) for value in row] for row in luminances]
|
||||
|
||||
|
||||
def render_char_grid(
    frame: np.ndarray,
    chars: List[List[str]],
    colors: np.ndarray,
    char_size: int = 8,
    color_mode: str = "color",
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Render a character grid onto a blank canvas the size of *frame*.

    Args:
        frame: Original frame (used only for output dimensions).
        chars: 2D list of characters.
        colors: Per-cell color, shape (rows, cols, 3).
        char_size: Pixel size of each character cell.
        color_mode: "color" (per-cell), "green", or anything else for white.
        background: Background RGB color.

    Returns:
        Rendered frame as a numpy array.
    """
    check_deps()
    height, width = frame.shape[:2]
    total_rows = len(chars)
    total_cols = len(chars[0]) if chars else 0

    canvas = Image.new("RGB", (width, height), background)
    draw = ImageDraw.Draw(canvas)

    # Prefer a real monospace font; fall back to PIL's built-in bitmap font.
    font = None
    for candidate in (
        "/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf",
        "/usr/share/fonts/truetype/liberation/LiberationMono-Regular.ttf",
    ):
        try:
            font = ImageFont.truetype(candidate, char_size)
            break
        except (IOError, OSError):
            continue
    if font is None:
        font = ImageFont.load_default()

    for row_idx in range(total_rows):
        for col_idx in range(total_cols):
            glyph = chars[row_idx][col_idx]
            if glyph == ' ':
                # Blank cells are already covered by the background fill.
                continue

            if color_mode == "color":
                fill = tuple(int(v) for v in colors[row_idx, col_idx])
            elif color_mode == "green":
                fill = (0, 255, 0)
            else:  # white (and any unrecognized mode)
                fill = (255, 255, 255)

            draw.text(
                (col_idx * char_size, row_idx * char_size),
                glyph,
                fill=fill,
                font=font,
            )

    return np.array(canvas)
|
||||
|
||||
|
||||
def ascii_art_frame(
    frame: np.ndarray,
    char_size: int = 8,
    alphabet: str = "standard",
    color_mode: str = "color",
    contrast: float = 1.5,
    background: Tuple[int, int, int] = (0, 0, 0),
) -> np.ndarray:
    """
    Apply the ASCII art effect to a single frame.

    Main entry point for the ascii_art effect: sample cells, map the cell
    luminances to characters, then render the character grid.
    """
    check_deps()
    cell_colors, cell_lums = cell_sample(frame, char_size)
    grid = luminance_to_chars(cell_lums, alphabet, contrast)
    return render_char_grid(frame, grid, cell_colors, char_size, color_mode, background)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Zones Primitives
|
||||
# ============================================================================
|
||||
|
||||
def ascii_zones_frame(
    frame: np.ndarray,
    char_size: int = 8,
    zone_threshold: int = 128,
    dark_chars: str = " .-:",
    light_chars: str = "=+*#",
) -> np.ndarray:
    """
    Apply zone-based ASCII art: different character sets for dark vs light
    regions.

    Args:
        frame: Input frame (H, W, C).
        char_size: Pixel size of each character cell.
        zone_threshold: Luminance cut-off (0-255) between the two zones.
        dark_chars: Character ramp for cells below the threshold.
        light_chars: Character ramp for cells at/above the threshold.

    Returns:
        Rendered frame.
    """
    check_deps()
    colors, luminances = cell_sample(frame, char_size)

    # Guard the zone spans so extreme thresholds (0 or 255) cannot produce
    # a zero denominator (255 previously yielded NaN -> int() ValueError).
    dark_span = max(zone_threshold, 1)
    light_span = max(255 - zone_threshold, 1)

    rows, cols = luminances.shape
    chars = []

    for r in range(rows):
        row_chars = []
        for c in range(cols):
            lum = luminances[r, c]
            if lum < zone_threshold:
                # Dark zone: normalize luminance to 0-1 within the zone.
                charset = dark_chars
                local_lum = lum / dark_span
            else:
                # Light zone.
                charset = light_chars
                local_lum = (lum - zone_threshold) / light_span

            idx = min(int(local_lum * len(charset)), len(charset) - 1)
            row_chars.append(charset[idx])
        chars.append(row_chars)

    return render_char_grid(frame, chars, colors, char_size, "color", (0, 0, 0))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Kaleidoscope Primitives (Python fallback)
|
||||
# ============================================================================
|
||||
|
||||
def kaleidoscope_displace(
    w: int,
    h: int,
    segments: int = 6,
    rotation: float = 0,
    cx: Optional[float] = None,
    cy: Optional[float] = None,
    zoom: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute kaleidoscope displacement coordinates.

    Args:
        w: Output width in pixels.
        h: Output height in pixels.
        segments: Number of mirror segments around the center.
        rotation: Rotation of the segment pattern, in degrees.
        cx: Center x in pixels; defaults to the horizontal center.
        cy: Center y in pixels; defaults to the vertical center.
        zoom: Radial zoom factor (radii are divided by it, so >1 samples
            closer to the center).

    Returns:
        (x_coords, y_coords) float arrays of shape (h, w) for remapping.
    """
    check_deps()
    if cx is None:
        cx = w / 2
    if cy is None:
        cy = h / 2

    # Create coordinate grids (one (h, w) grid per axis).
    y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)

    # Center coordinates on (cx, cy).
    x_centered = x_grid - cx
    y_centered = y_grid - cy

    # Convert to polar; zoom scales the sampling radius.
    r = np.sqrt(x_centered**2 + y_centered**2) / zoom
    theta = np.arctan2(y_centered, x_centered)

    # Apply rotation (degrees -> radians).
    theta = theta - np.radians(rotation)

    # Kaleidoscope: fold the angle into one segment, mirrored about the
    # segment's bisector, so every segment samples the same wedge.
    segment_angle = 2 * np.pi / segments
    theta = np.abs(np.mod(theta, segment_angle) - segment_angle / 2)

    # Convert back to cartesian source coordinates.
    x_new = r * np.cos(theta) + cx
    y_new = r * np.sin(theta) + cy

    return x_new, y_new
|
||||
|
||||
|
||||
def remap(
    frame: np.ndarray,
    x_coords: np.ndarray,
    y_coords: np.ndarray,
) -> np.ndarray:
    """
    Resample *frame* at the given coordinate arrays using bilinear
    interpolation (scipy's order-1 map_coordinates, reflect boundaries).
    """
    check_deps()
    from scipy import ndimage

    h, w = frame.shape[:2]

    # Keep all sample points inside the image bounds.
    xs = np.clip(x_coords, 0, w - 1)
    ys = np.clip(y_coords, 0, h - 1)

    if len(frame.shape) != 3:
        # Single-channel frame: remap directly.
        return ndimage.map_coordinates(frame, [ys, xs], order=1, mode='reflect')

    # Multi-channel: remap each channel independently.
    out = np.zeros_like(frame)
    for channel in range(frame.shape[2]):
        out[:, :, channel] = ndimage.map_coordinates(
            frame[:, :, channel],
            [ys, xs],
            order=1,
            mode='reflect',
        )
    return out
|
||||
|
||||
|
||||
def kaleidoscope_frame(
    frame: np.ndarray,
    segments: int = 6,
    rotation: float = 0,
    center_x: float = 0.5,
    center_y: float = 0.5,
    zoom: float = 1.0,
) -> np.ndarray:
    """
    Apply a kaleidoscope effect to one frame (pure-Python fallback; the
    FFmpeg version is faster).

    center_x / center_y are fractions of the frame size (0-1).
    """
    check_deps()
    height, width = frame.shape[:2]

    # Convert the fractional center into pixel coordinates.
    pixel_cx = width * center_x
    pixel_cy = height * center_y

    coords = kaleidoscope_displace(width, height, segments, rotation, pixel_cx, pixel_cy, zoom)
    return remap(frame, coords[0], coords[1])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Datamosh Primitives (simplified Python version)
|
||||
# ============================================================================
|
||||
|
||||
def datamosh_frame(
    frame: np.ndarray,
    prev_frame: Optional[np.ndarray],
    block_size: int = 32,
    corruption: float = 0.3,
    max_offset: int = 50,
    color_corrupt: bool = True,
) -> np.ndarray:
    """
    Simplified datamosh effect using block displacement.

    This is a basic approximation - real datamosh works on compressed video.

    Args:
        frame: Current frame (H, W, C).
        prev_frame: Previous frame, or None (then *frame* is returned as a copy).
        block_size: Edge length of the square blocks being displaced.
        corruption: Probability (0-1) that any given block is corrupted.
        max_offset: Maximum displacement in pixels along each axis.
        color_corrupt: Whether some corrupted blocks also get their color
            channels rotated.

    Returns:
        Corrupted copy of *frame*.

    NOTE(review): uses numpy's global RNG; output is reproducible only if
    the caller seeds np.random, and the exact draw order below is part of
    that reproducibility.
    """
    check_deps()
    if prev_frame is None:
        return frame.copy()

    h, w = frame.shape[:2]
    result = frame.copy()

    # Process in blocks (partial blocks at the edges are left untouched).
    for y in range(0, h - block_size, block_size):
        for x in range(0, w - block_size, block_size):
            if np.random.random() < corruption:
                # Random offset for this block.
                ox = np.random.randint(-max_offset, max_offset + 1)
                oy = np.random.randint(-max_offset, max_offset + 1)

                # Source from the previous frame at the offset position,
                # clamped so the full block stays in bounds.
                src_y = np.clip(y + oy, 0, h - block_size)
                src_x = np.clip(x + ox, 0, w - block_size)

                block = prev_frame[src_y:src_y+block_size, src_x:src_x+block_size]

                # Color corruption: rotate channels on ~30% of corrupted
                # blocks (assumes a 3-D frame, rolls along axis=2).
                if color_corrupt and np.random.random() < 0.3:
                    # Swap or shift channels
                    block = np.roll(block, np.random.randint(1, 3), axis=2)

                result[y:y+block_size, x:x+block_size] = block

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pixelsort Primitives (Python version)
|
||||
# ============================================================================
|
||||
|
||||
def pixelsort_frame(
    frame: np.ndarray,
    sort_by: str = "lightness",
    threshold_low: float = 50,
    threshold_high: float = 200,
    angle: float = 0,
    reverse: bool = False,
) -> np.ndarray:
    """
    Apply pixel sorting effect to a frame.

    Args:
        frame: Input frame (H, W, 3).
        sort_by: Sort key - "lightness", "hue", "saturation"; anything else
            falls back to the red channel.
        threshold_low: Lower bound of the sortable key range.
        threshold_high: Upper bound of the sortable key range.
        angle: Sort direction in degrees; the frame is rotated, rows are
            sorted, then it is rotated back.
        reverse: Sort each interval in descending order instead.

    Returns:
        Frame with qualifying horizontal pixel runs sorted by the key.
    """
    check_deps()
    from scipy import ndimage

    # Rotate if needed so sorting always runs along image rows.
    if angle != 0:
        frame = ndimage.rotate(frame, -angle, reshape=False, mode='reflect')

    h, w = frame.shape[:2]
    result = frame.copy()

    # Compute the per-pixel sort key.
    if sort_by == "lightness":
        # ITU-R BT.601 luma weights.
        key = 0.299 * frame[:,:,0] + 0.587 * frame[:,:,1] + 0.114 * frame[:,:,2]
    elif sort_by == "hue":
        # Simple hue approximation (angle in the opponent-color plane).
        key = np.arctan2(
            np.sqrt(3) * (frame[:,:,1].astype(float) - frame[:,:,2]),
            2 * frame[:,:,0].astype(float) - frame[:,:,1] - frame[:,:,2]
        )
    elif sort_by == "saturation":
        mx = frame.max(axis=2).astype(float)
        mn = frame.min(axis=2).astype(float)
        key = np.where(mx > 0, (mx - mn) / mx, 0)
    else:
        key = frame[:,:,0]  # Red channel fallback

    # Sort each row. NOTE: `row` is a live view into `result`; the
    # fancy-indexed read `row[start:end][indices]` copies before the
    # in-place write, so the assignment below is safe.
    for y in range(h):
        row = result[y]
        row_key = key[y]

        # Mask of sortable pixels (key within the threshold band).
        mask = (row_key >= threshold_low) & (row_key <= threshold_high)

        # Collect runs of True in the mask (only runs longer than 1 pixel).
        runs = []
        start = None
        for x in range(w):
            if mask[x] and start is None:
                start = x
            elif not mask[x] and start is not None:
                if x - start > 1:
                    runs.append((start, x))
                start = None
        if start is not None and w - start > 1:
            runs.append((start, w))

        # Sort each run by its key values.
        for start, end in runs:
            indices = np.argsort(row_key[start:end])
            if reverse:
                indices = indices[::-1]
            result[y, start:end] = row[start:end][indices]

    # Rotate back to the original orientation.
    if angle != 0:
        result = ndimage.rotate(result, angle, reshape=False, mode='reflect')

    return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Registry
|
||||
# ============================================================================
|
||||
|
||||
def map_char_grid(
    chars,
    luminances,
    fn,
):
    """
    Map a function over each cell of a character grid.

    Args:
        chars: 2D array/list of characters (rows, cols); rows may be
            strings, lists/tuples, or numpy arrays.
        luminances: 2D (or 1D) array of luminance values.
        fn: Python callable or sexp Lambda with signature
            (row, col, char, luminance) -> new_char.

    Returns:
        New character grid with mapped values (list of lists).
    """
    from .parser import Lambda
    from .evaluator import evaluate

    # Handle both list and numpy array inputs.
    if isinstance(chars, np.ndarray):
        rows, cols = chars.shape[:2]
    else:
        rows = len(chars)
        cols = len(chars[0]) if rows > 0 and isinstance(chars[0], (list, tuple, str)) else 1

    # Normalize luminances to a numpy array.
    lum_arr = luminances if isinstance(luminances, np.ndarray) else np.array(luminances)

    # Sexp Lambdas must be evaluated in an environment; plain Python
    # callables are invoked directly.
    is_lambda = isinstance(fn, Lambda)

    result = []
    for r in range(rows):
        row_result = []
        for c in range(cols):
            # Fetch the character. Fixed: the previous str-row branch and
            # the generic else branch were identical, so they are merged —
            # both str and list rows index the same way, padding past the
            # row's end with a space.
            if isinstance(chars, np.ndarray):
                ch = chars[r, c] if len(chars.shape) > 1 else chars[r]
            else:
                ch = chars[r][c] if c < len(chars[r]) else ' '

            # Fetch the matching luminance (2D when available, else 1D).
            lum = lum_arr[r, c] if len(lum_arr.shape) > 1 else lum_arr[r]

            if is_lambda:
                # Bind the Lambda's parameters in a copy of its closure.
                call_env = dict(fn.closure) if fn.closure else {}
                for param, val in zip(fn.params, [r, c, ch, float(lum)]):
                    call_env[param] = val
                new_ch = evaluate(fn.body, call_env)
            else:
                new_ch = fn(r, c, ch, float(lum))

            row_result.append(new_ch)
        result.append(row_result)

    return result
|
||||
|
||||
|
||||
def alphabet_char(alphabet: str, index: int) -> str:
    """
    Get a character from an alphabet at a given index.

    Args:
        alphabet: Alphabet name (from ASCII_ALPHABETS) or a literal,
            non-empty character string.
        index: Index into the alphabet (clamped to the valid range).

    Returns:
        Character at the (clamped) index.

    Raises:
        ValueError: If the resolved alphabet is empty.
    """
    # Resolve named alphabets; unknown names are treated as literal strings.
    chars = ASCII_ALPHABETS.get(alphabet, alphabet)

    # Fixed: an empty alphabet previously surfaced as an obscure IndexError
    # (clamping produced index 0 on an empty string).
    if not chars:
        raise ValueError("alphabet must resolve to a non-empty string")

    # Clamp the index into the valid range.
    clamped = max(0, min(int(index), len(chars) - 1))
    return chars[clamped]
|
||||
|
||||
|
||||
# Registry mapping sexp primitive names to their Python implementations.
# Some effects are registered under two names (kebab-case sexp name and
# snake_case function name) pointing at the same callable.
PRIMITIVES = {
    # ASCII
    "cell-sample": cell_sample,
    "luminance-to-chars": luminance_to_chars,
    "render-char-grid": render_char_grid,
    "map-char-grid": map_char_grid,
    "alphabet-char": alphabet_char,
    "ascii_art_frame": ascii_art_frame,
    "ascii_zones_frame": ascii_zones_frame,

    # Kaleidoscope
    "kaleidoscope-displace": kaleidoscope_displace,
    "remap": remap,
    "kaleidoscope_frame": kaleidoscope_frame,

    # Datamosh ("datamosh" is an alias of "datamosh_frame")
    "datamosh": datamosh_frame,
    "datamosh_frame": datamosh_frame,

    # Pixelsort ("pixelsort" is an alias of "pixelsort_frame")
    "pixelsort": pixelsort_frame,
    "pixelsort_frame": pixelsort_frame,
}
|
||||
|
||||
|
||||
def get_primitive(name: str):
    """Look up a primitive function by name; returns None when unknown."""
    try:
        return PRIMITIVES[name]
    except KeyError:
        return None
|
||||
|
||||
|
||||
def list_primitives() -> List[str]:
    """Return the names of all registered primitives."""
    return [name for name in PRIMITIVES]
|
||||
779
core/artdag/sexp/scheduler.py
Normal file
779
core/artdag/sexp/scheduler.py
Normal file
@@ -0,0 +1,779 @@
|
||||
"""
|
||||
Celery scheduler for S-expression execution plans.
|
||||
|
||||
Distributes plan steps to workers as S-expressions.
|
||||
The S-expression is the canonical format - workers receive
|
||||
serialized S-expressions and can verify cache_ids by hashing them.
|
||||
|
||||
Usage:
|
||||
from artdag.sexp import compile_string, create_plan
|
||||
from artdag.sexp.scheduler import schedule_plan
|
||||
|
||||
recipe = compile_string(sexp_content)
|
||||
plan = create_plan(recipe, inputs={'video': 'abc123...'})
|
||||
result = schedule_plan(plan)
|
||||
"""
|
||||
|
||||
import hashlib
import hmac
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Callable

from .parser import Symbol, Keyword, serialize, parse
from .planner import ExecutionPlanSexp, PlanStep
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class StepResult:
    """Result from executing a step."""
    # Step identifier within the plan.
    step_id: str
    # Content-addressed id of the step's output.
    cache_id: str
    status: str  # 'completed', 'cached', 'failed', 'pending'
    # Local filesystem path of the output, when available.
    output_path: Optional[str] = None
    # Error message when status == 'failed'.
    error: Optional[str] = None
    # IPFS CID of the output, when available.
    ipfs_cid: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class PlanResult:
    """Result from executing a complete plan."""
    # Content-addressed plan identifier.
    plan_id: str
    status: str  # 'completed', 'failed', 'partial'
    # Steps actually executed by workers this run.
    steps_completed: int = 0
    # Steps skipped because their output was already cached.
    steps_cached: int = 0
    steps_failed: int = 0
    # Cache id / local path / IPFS CID of the plan's final output step.
    output_cache_id: Optional[str] = None
    output_path: Optional[str] = None
    output_ipfs_cid: Optional[str] = None
    # Per-step results keyed by step_id.
    step_results: Dict[str, StepResult] = field(default_factory=dict)
    # First error encountered, when status == 'failed'.
    error: Optional[str] = None
|
||||
|
||||
|
||||
def step_to_sexp(step: PlanStep) -> List:
    """
    Convert a PlanStep to the minimal S-expression sent to a worker.

    This is the canonical form: workers can verify a step's cache_id by
    hashing exactly this S-expression.
    """
    sexp: List = [Symbol(step.node_type.lower())]

    # Config entries become :kebab-case keyword/value pairs.
    for key, value in step.config.items():
        sexp.append(Keyword(key.replace('_', '-')))
        sexp.append(value)

    # Inputs are cache IDs (not step IDs), appended as one trailing list.
    if step.inputs:
        sexp.append(Keyword("inputs"))
        sexp.append(step.inputs)

    return sexp
|
||||
|
||||
|
||||
def step_sexp_to_string(step: PlanStep) -> str:
    """Serialize a step to its S-expression string for the Celery task."""
    sexp = step_to_sexp(step)
    return serialize(sexp)
|
||||
|
||||
|
||||
def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: Optional[str] = None) -> bool:
    """
    Verify that a step's cache_id matches its S-expression.

    Workers should call this to verify they're executing the correct task.

    Args:
        step_sexp: The step's serialized S-expression.
        expected_cache_id: The SHA3-256 hex digest claimed for the step.
        cluster_key: Optional cluster secret mixed into the hashed payload.
            (Fixed annotation: was ``str = None``.)

    Returns:
        True when the recomputed digest matches ``expected_cache_id``.
    """
    data = {"sexp": step_sexp}
    if cluster_key:
        # Cluster-keyed plans wrap the payload so ids differ per cluster.
        data = {"_cluster_key": cluster_key, "_data": data}

    # Canonical JSON (sorted keys, no whitespace) keeps the digest stable
    # across processes and Python versions.
    json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
    computed = hashlib.sha3_256(json_str.encode()).hexdigest()
    # Constant-time comparison: avoids leaking match position via timing.
    return hmac.compare_digest(computed, expected_cache_id)
|
||||
|
||||
|
||||
class PlanScheduler:
    """
    Schedules execution of S-expression plans on Celery workers.

    The scheduler:
    1. Groups steps by dependency level
    2. Checks cache for already-computed results
    3. Dispatches uncached steps to workers
    4. Waits for completion before proceeding to next level
    """

    def __init__(
        self,
        cache_manager=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the scheduler.

        Args:
            cache_manager: L1 cache manager for checking cached results
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> PlanResult:
        """
        Schedule and execute a plan.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds. NOTE(review): this is applied to
                each level's group.get() call, so worst-case wall time is
                timeout * number_of_levels, not a whole-plan bound.

        Returns:
            PlanResult with execution results
        """
        # Imported lazily so the module can load without Celery present.
        from celery import group

        logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")

        # Build step lookup and group by dependency level
        steps_by_id = {s.step_id: s for s in plan.steps}
        steps_by_level = self._group_by_level(plan.steps)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track results
        result = PlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Map step_id -> cache_id for resolving inputs
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Execute level by level; steps within a level are independent.
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            logger.info(f"Level {level}: {len(level_steps)} steps")

            # Check cache for each step; cached steps are never dispatched.
            steps_to_run = []
            for step in level_steps:
                if self._is_cached(step.cache_id):
                    result.steps_cached += 1
                    result.step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                logger.info(f"Level {level}: all {len(level_steps)} steps cached")
                continue

            # Dispatch uncached steps to workers
            logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps")

            tasks = []
            for step in steps_to_run:
                # Build task arguments: the canonical S-expression plus
                # resolved input cache ids (unknown names pass through as-is).
                step_sexp = step_sexp_to_string(step)
                input_cache_ids = {
                    inp: cache_ids.get(inp, inp)
                    for inp in step.inputs
                }

                task = self._get_execute_task().s(
                    step_sexp=step_sexp,
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    plan_id=plan.plan_id,
                    input_cache_ids=input_cache_ids,
                )
                tasks.append(task)

            # Execute the level in parallel and block until it finishes.
            job = group(tasks)
            async_result = job.apply_async()

            try:
                step_results = async_result.get(timeout=timeout)
            except Exception as e:
                # Celery raises on worker failure or timeout: abort the plan.
                logger.error(f"Level {level} failed: {e}")
                result.status = "failed"
                result.error = f"Level {level} failed: {e}"
                return result

            # Process results
            for step_result in step_results:
                step_id = step_result.get("step_id")
                status = step_result.get("status")

                result.step_results[step_id] = StepResult(
                    step_id=step_id,
                    cache_id=step_result.get("cache_id"),
                    status=status,
                    output_path=step_result.get("output_path"),
                    error=step_result.get("error"),
                    ipfs_cid=step_result.get("ipfs_cid"),
                )

                if status in ("completed", "cached", "completed_by_other"):
                    result.steps_completed += 1
                elif status == "failed":
                    # Abort on the first failed step; results for later
                    # steps in this level are not recorded.
                    result.steps_failed += 1
                    result.status = "failed"
                    result.error = step_result.get("error")
                    return result

        # Get final output from the plan's designated output step.
        output_step = steps_by_id.get(plan.output_step_id)
        if output_step:
            output_result = result.step_results.get(output_step.step_id)
            if output_result:
                result.output_cache_id = output_step.cache_id
                result.output_path = output_result.output_path
                result.output_ipfs_cid = output_result.ipfs_cid

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.steps_completed} executed, {result.steps_cached} cached"
        )
        return result

    def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]:
        """Group steps by dependency level."""
        by_level = {}
        for step in steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level

    def _is_cached(self, cache_id: str) -> bool:
        """Check if a cache_id exists in the L1 cache."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached item, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _get_execute_task(self):
        """Get the Celery task object used for step execution."""
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        return self.celery_app.tasks[self.execute_task_name]
|
||||
|
||||
|
||||
def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler:
    """
    Build a PlanScheduler.

    Missing dependencies are best-effort imported from art-celery; import
    failures are silently ignored (the scheduler then runs without them).
    """
    if celery_app is None:
        try:
            from celery_app import app as celery_app
        except ImportError:
            pass

    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            pass

    scheduler = PlanScheduler(cache_manager=cache_manager, celery_app=celery_app)
    return scheduler
|
||||
|
||||
|
||||
def schedule_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    celery_app=None,
    timeout: int = 3600,
) -> PlanResult:
    """
    Convenience wrapper: build a scheduler and execute *plan* on it.

    Args:
        plan: The execution plan
        cache_manager: Optional cache manager
        celery_app: Optional Celery app
        timeout: Execution timeout in seconds

    Returns:
        PlanResult
    """
    return create_scheduler(cache_manager, celery_app).schedule(plan, timeout=timeout)
|
||||
|
||||
|
||||
# Stage-aware scheduling
|
||||
|
||||
@dataclass
class StageResult:
    """Result from executing a stage."""
    # Human-readable stage name.
    stage_name: str
    # Content-addressed id of the stage's output.
    cache_id: str
    status: str  # 'completed', 'cached', 'failed', 'pending'
    # Per-step results keyed by step_id.
    step_results: Dict[str, StepResult] = field(default_factory=dict)
    outputs: Dict[str, str] = field(default_factory=dict)  # binding_name -> cache_id
    # Error message when status == 'failed'.
    error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class StagePlanResult:
    """Result from executing a plan with stages."""
    # Content-addressed plan identifier.
    plan_id: str
    status: str  # 'completed', 'failed', 'partial'
    # Stage-level counters.
    stages_completed: int = 0
    stages_cached: int = 0
    stages_failed: int = 0
    # Step-level counters (aggregated across all stages).
    steps_completed: int = 0
    steps_cached: int = 0
    steps_failed: int = 0
    # Per-stage results keyed by stage name.
    stage_results: Dict[str, StageResult] = field(default_factory=dict)
    # Cache id / local path of the plan's final output.
    output_cache_id: Optional[str] = None
    output_path: Optional[str] = None
    # First error encountered, when status == 'failed'.
    error: Optional[str] = None
|
||||
|
||||
|
||||
class StagePlanScheduler:
    """
    Stage-aware scheduler for S-expression plans.

    The scheduler:
    1. Groups stages by level (parallel groups)
    2. For each stage level:
       - Check stage cache, skip entire stage if hit
       - Execute stage steps (grouped by level within stage)
       - Cache stage outputs
    3. Stages at same level can run in parallel
    """

    def __init__(
        self,
        cache_manager=None,
        stage_cache=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the stage-aware scheduler.

        Args:
            cache_manager: L1 cache manager for step-level caching
            stage_cache: StageCache instance for stage-level caching
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.stage_cache = stage_cache
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name

    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> StagePlanResult:
        """
        Schedule and execute a plan with stage awareness.

        If the plan has stages, uses stage-level scheduling.
        Otherwise, falls back to step-level scheduling.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds for the entire plan

        Returns:
            StagePlanResult with execution results
        """
        # If no stages, use regular step-level scheduling and translate the
        # step-level result into the stage-level result shape.
        if not plan.stage_plans:
            logger.info("Plan has no stages, using step-level scheduling")
            regular_scheduler = PlanScheduler(
                cache_manager=self.cache_manager,
                celery_app=self.celery_app,
                execute_task_name=self.execute_task_name,
            )
            step_result = regular_scheduler.schedule(plan, timeout)
            return StagePlanResult(
                plan_id=step_result.plan_id,
                status=step_result.status,
                steps_completed=step_result.steps_completed,
                steps_cached=step_result.steps_cached,
                steps_failed=step_result.steps_failed,
                output_cache_id=step_result.output_cache_id,
                output_path=step_result.output_path,
                error=step_result.error,
            )

        logger.info(
            f"Scheduling plan {plan.plan_id[:16]}... "
            f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)"
        )

        result = StagePlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )

        # Group stages by level
        stages_by_level = self._group_stages_by_level(plan.stage_plans)
        max_level = max(stages_by_level.keys()) if stages_by_level else 0

        # Track stage outputs for data flow
        stage_outputs = {}  # stage_name -> {binding_name -> cache_id}

        # Execute stage by stage level
        for level in range(max_level + 1):
            level_stages = stages_by_level.get(level, [])
            if not level_stages:
                continue

            logger.info(f"Stage level {level}: {len(level_stages)} stages")

            # Check stage cache for each stage.
            # BUGFIX: previously a stage was counted as cached (and skipped)
            # as soon as its cache file existed, even when the entry failed to
            # load (corrupted file -> _load_cached_stage returns None). Such a
            # stage ended up neither recorded nor re-run, leaving its outputs
            # missing for downstream stages. A stage now counts as cached only
            # when its entry actually loads; otherwise it is re-executed.
            stages_to_run = []
            for stage_plan in level_stages:
                cached_entry = None
                if self._is_stage_cached(stage_plan.cache_id):
                    cached_entry = self._load_cached_stage(stage_plan.cache_id)
                if cached_entry:
                    result.stages_cached += 1
                    stage_outputs[stage_plan.stage_name] = {
                        name: out.cache_id
                        for name, out in cached_entry.outputs.items()
                    }
                    result.stage_results[stage_plan.stage_name] = StageResult(
                        stage_name=stage_plan.stage_name,
                        cache_id=stage_plan.cache_id,
                        status="cached",
                        outputs=stage_outputs[stage_plan.stage_name],
                    )
                    logger.info(f"Stage {stage_plan.stage_name}: cached")
                else:
                    stages_to_run.append(stage_plan)

            if not stages_to_run:
                logger.info(f"Stage level {level}: all {len(level_stages)} stages cached")
                continue

            # Execute uncached stages
            # For now, execute sequentially; L1 Celery will add parallel execution
            for stage_plan in stages_to_run:
                logger.info(f"Executing stage: {stage_plan.stage_name}")

                stage_result = self._execute_stage(
                    stage_plan,
                    plan,
                    stage_outputs,
                    timeout,
                )

                result.stage_results[stage_plan.stage_name] = stage_result

                if stage_result.status == "completed":
                    result.stages_completed += 1
                    stage_outputs[stage_plan.stage_name] = stage_result.outputs

                    # Cache the stage result
                    self._cache_stage(stage_plan, stage_result)
                elif stage_result.status == "failed":
                    # Fail fast: abort the whole plan on the first failed stage.
                    result.stages_failed += 1
                    result.status = "failed"
                    result.error = stage_result.error
                    return result

                # Accumulate step counts
                for sr in stage_result.step_results.values():
                    if sr.status == "completed":
                        result.steps_completed += 1
                    elif sr.status == "cached":
                        result.steps_cached += 1
                    elif sr.status == "failed":
                        result.steps_failed += 1

        # Get final output from the last declared stage (when it ran or was
        # cached). The last step with an output_path wins.
        if plan.stage_plans:
            last_stage = plan.stage_plans[-1]
            if last_stage.stage_name in result.stage_results:
                stage_res = result.stage_results[last_stage.stage_name]
                result.output_cache_id = last_stage.cache_id
                # Find the output step's path from step results
                for step_res in stage_res.step_results.values():
                    if step_res.output_path:
                        result.output_path = step_res.output_path

        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.stages_completed} stages executed, "
            f"{result.stages_cached} stages cached"
        )
        return result

    def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]:
        """Group stage plans by their level (preserving declaration order)."""
        by_level = {}
        for stage_plan in stage_plans:
            by_level.setdefault(stage_plan.level, []).append(stage_plan)
        return by_level

    def _is_stage_cached(self, cache_id: str) -> bool:
        """Check if a stage is cached (False when no stage cache configured)."""
        if self.stage_cache is None:
            return False
        return self.stage_cache.has_stage(cache_id)

    def _load_cached_stage(self, cache_id: str):
        """Load a cached stage entry, or None when unavailable/unreadable."""
        if self.stage_cache is None:
            return None
        return self.stage_cache.load_stage(cache_id)

    def _cache_stage(self, stage_plan, stage_result: StageResult) -> None:
        """Persist a completed stage's outputs to the stage cache (no-op without one)."""
        if self.stage_cache is None:
            return

        from .stage_cache import StageCacheEntry, StageOutput

        outputs = {}
        for name, cache_id in stage_result.outputs.items():
            outputs[name] = StageOutput(
                cache_id=cache_id,
                output_type="artifact",
            )

        entry = StageCacheEntry(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            outputs=outputs,
        )
        self.stage_cache.save_stage(entry)

    def _execute_stage(
        self,
        stage_plan,
        plan: ExecutionPlanSexp,
        stage_outputs: Dict,
        timeout: int,
    ) -> StageResult:
        """
        Execute a single stage: run its steps level by level, honoring the
        step-level cache, dispatching to Celery when configured.

        Args:
            stage_plan: The stage to execute
            plan: The full plan (for input hashes and step cache IDs)
            stage_outputs: Outputs of previously completed stages
            timeout: Per-step Celery timeout in seconds

        Returns:
            StageResult ('failed' aborts on the first step error)
        """
        # Create a mini-plan with just this stage's steps
        stage_steps = stage_plan.steps

        # Build step lookup
        steps_by_id = {s.step_id: s for s in stage_steps}
        steps_by_level = {}
        for step in stage_steps:
            steps_by_level.setdefault(step.level, []).append(step)

        max_level = max(steps_by_level.keys()) if steps_by_level else 0

        # Track step results
        step_results = {}
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id

        # Include outputs from previous stages
        for stage_name, outputs in stage_outputs.items():
            for binding_name, binding_cache_id in outputs.items():
                cache_ids[binding_name] = binding_cache_id

        # Execute steps level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue

            # Check cache for each step
            steps_to_run = []
            for step in level_steps:
                if self._is_step_cached(step.cache_id):
                    step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)

            if not steps_to_run:
                continue

            # Execute steps (for now, sequentially - L1 will add Celery dispatch)
            for step in steps_to_run:
                # In a full implementation, this would dispatch to Celery
                # For now, mark as pending
                # NOTE(review): without Celery, steps remain 'pending' yet the
                # stage below still returns 'completed' - confirm intended.
                step_results[step.step_id] = StepResult(
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    status="pending",
                )

                # If Celery is configured, dispatch the task
                if self.celery_app:
                    try:
                        task_result = self._dispatch_step(step, cache_ids, timeout)
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status=task_result.get("status", "completed"),
                            output_path=task_result.get("output_path"),
                            error=task_result.get("error"),
                            ipfs_cid=task_result.get("ipfs_cid"),
                        )
                    except Exception as e:
                        # Abort the stage on the first dispatch/execution error.
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status="failed",
                            error=str(e),
                        )
                        return StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="failed",
                            step_results=step_results,
                            error=str(e),
                        )

        # Build output bindings (fall back to the raw node_id when no cache
        # ID was recorded for it).
        outputs = {}
        for out_name, node_id in stage_plan.output_bindings.items():
            outputs[out_name] = cache_ids.get(node_id, node_id)

        return StageResult(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            status="completed",
            step_results=step_results,
            outputs=outputs,
        )

    def _is_step_cached(self, cache_id: str) -> bool:
        """Check if a step is cached (False when no cache manager configured)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None

    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached step, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None

    def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict:
        """
        Dispatch a step to Celery for execution and block for its result.

        Raises:
            RuntimeError: if no Celery app is configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")

        task = self.celery_app.tasks[self.execute_task_name]

        step_sexp = step_sexp_to_string(step)
        # Map each declared input to its resolved cache ID (identity fallback).
        input_cache_ids = {
            inp: cache_ids.get(inp, inp)
            for inp in step.inputs
        }

        async_result = task.apply_async(
            kwargs={
                "step_sexp": step_sexp,
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "input_cache_ids": input_cache_ids,
            }
        )

        return async_result.get(timeout=timeout)
|
||||
|
||||
|
||||
def create_stage_scheduler(
|
||||
cache_manager=None,
|
||||
stage_cache=None,
|
||||
celery_app=None,
|
||||
) -> StagePlanScheduler:
|
||||
"""
|
||||
Create a stage-aware scheduler with the given dependencies.
|
||||
|
||||
Args:
|
||||
cache_manager: L1 cache manager for step-level caching
|
||||
stage_cache: StageCache instance for stage-level caching
|
||||
celery_app: Celery application instance
|
||||
|
||||
Returns:
|
||||
StagePlanScheduler
|
||||
"""
|
||||
if celery_app is None:
|
||||
try:
|
||||
from celery_app import app as celery_app
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if cache_manager is None:
|
||||
try:
|
||||
from cache_manager import get_cache_manager
|
||||
cache_manager = get_cache_manager()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return StagePlanScheduler(
|
||||
cache_manager=cache_manager,
|
||||
stage_cache=stage_cache,
|
||||
celery_app=celery_app,
|
||||
)
|
||||
|
||||
|
||||
def schedule_staged_plan(
|
||||
plan: ExecutionPlanSexp,
|
||||
cache_manager=None,
|
||||
stage_cache=None,
|
||||
celery_app=None,
|
||||
timeout: int = 3600,
|
||||
) -> StagePlanResult:
|
||||
"""
|
||||
Convenience function to schedule a plan with stage awareness.
|
||||
|
||||
Args:
|
||||
plan: The execution plan
|
||||
cache_manager: Optional step-level cache manager
|
||||
stage_cache: Optional stage-level cache
|
||||
celery_app: Optional Celery app
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
StagePlanResult
|
||||
"""
|
||||
scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app)
|
||||
return scheduler.schedule(plan, timeout=timeout)
|
||||
412
core/artdag/sexp/stage_cache.py
Normal file
412
core/artdag/sexp/stage_cache.py
Normal file
@@ -0,0 +1,412 @@
|
||||
"""
|
||||
Stage-level cache layer using S-expression storage.
|
||||
|
||||
Provides caching for stage outputs, enabling:
|
||||
- Stage-level cache hits (skip entire stage execution)
|
||||
- Analysis result persistence as sexp
|
||||
- Cross-worker stage cache sharing (for L1 Celery integration)
|
||||
|
||||
All cache files use .sexp extension - no JSON in the pipeline.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .parser import Symbol, Keyword, parse, serialize
|
||||
|
||||
|
||||
@dataclass
class StageOutput:
    """A single output from a stage."""
    cache_id: Optional[str] = None  # For artifacts (files, analysis data)
    value: Any = None  # For scalar values
    output_type: str = "artifact"  # "artifact", "analysis", "scalar"

    def to_sexp(self) -> List:
        """Convert to S-expression (a flat keyword/value property list)."""
        plist = []
        if self.cache_id:
            plist += [Keyword("cache-id"), self.cache_id]
        if self.value is not None:
            plist += [Keyword("value"), self.value]
        plist += [Keyword("type"), Keyword(self.output_type)]
        return plist

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageOutput':
        """Parse from S-expression list.

        Walks the property list: a Keyword followed by a value consumes two
        items; anything else (stray value, trailing keyword) is skipped.
        Unknown keywords are ignored.
        """
        parsed = {"cache_id": None, "value": None, "output_type": "artifact"}

        pos = 0
        total = len(sexp)
        while pos < total:
            item = sexp[pos]
            if not isinstance(item, Keyword) or pos + 1 >= total:
                pos += 1
                continue
            val = sexp[pos + 1]
            if item.name == "cache-id":
                parsed["cache_id"] = val
            elif item.name == "value":
                parsed["value"] = val
            elif item.name == "type":
                parsed["output_type"] = val.name if isinstance(val, Keyword) else str(val)
            pos += 2

        return cls(**parsed)
|
||||
|
||||
|
||||
@dataclass
class StageCacheEntry:
    """Cached result of a stage execution."""
    # Stage name as declared in the plan.
    stage_name: str
    # Content-addressed ID; also used as the cache filename stem.
    cache_id: str
    outputs: Dict[str, StageOutput]  # binding_name -> output info
    # Unix timestamp of when the stage finished.
    completed_at: float = field(default_factory=time.time)
    # Free-form extra data; serialized only when non-empty.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_sexp(self) -> List:
        """
        Convert to S-expression for storage.

        Format:
            (stage-result
              :name "analyze-a"
              :cache-id "abc123..."
              :completed-at 1705678900.123
              :outputs
              ((beats-a :cache-id "def456..." :type :analysis)
               (tempo :value 120.5 :type :scalar)))
        """
        sexp = [Symbol("stage-result")]
        sexp.extend([Keyword("name"), self.stage_name])
        sexp.extend([Keyword("cache-id"), self.cache_id])
        sexp.extend([Keyword("completed-at"), self.completed_at])

        # Outputs serialize as a list of (name :key val ...) sublists;
        # omitted entirely when there are none.
        if self.outputs:
            outputs_sexp = []
            for name, output in self.outputs.items():
                output_sexp = [Symbol(name)] + output.to_sexp()
                outputs_sexp.append(output_sexp)
            sexp.extend([Keyword("outputs"), outputs_sexp])

        # Metadata is also omitted when empty.
        if self.metadata:
            sexp.extend([Keyword("metadata"), self.metadata])

        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'StageCacheEntry':
        """Parse from S-expression.

        Raises:
            ValueError: if the sexp is not a (stage-result ...) form or is
                missing the required :name / :cache-id fields.
        """
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "stage-result":
            raise ValueError("Invalid stage-result sexp")

        stage_name = None
        cache_id = None
        completed_at = time.time()
        outputs = {}
        metadata = {}

        # Walk the property list after the head symbol: a Keyword followed by
        # a value consumes two items; anything else is skipped. Unknown
        # keywords are ignored.
        i = 1
        while i < len(sexp):
            if isinstance(sexp[i], Keyword):
                key = sexp[i].name
                if i + 1 < len(sexp):
                    val = sexp[i + 1]
                    if key == "name":
                        stage_name = val
                    elif key == "cache-id":
                        cache_id = val
                    elif key == "completed-at":
                        completed_at = float(val)
                    elif key == "outputs":
                        # Each output is (name :key val ...); the first element
                        # is the binding name, the rest is a StageOutput plist.
                        if isinstance(val, list):
                            for output_sexp in val:
                                if isinstance(output_sexp, list) and output_sexp:
                                    out_name = output_sexp[0]
                                    if isinstance(out_name, Symbol):
                                        out_name = out_name.name
                                    outputs[out_name] = StageOutput.from_sexp(output_sexp[1:])
                    elif key == "metadata":
                        # Non-dict metadata is silently discarded.
                        metadata = val if isinstance(val, dict) else {}
                    i += 2
                else:
                    i += 1
            else:
                i += 1

        if not stage_name or not cache_id:
            raise ValueError("stage-result missing required fields (name, cache-id)")

        return cls(
            stage_name=stage_name,
            cache_id=cache_id,
            outputs=outputs,
            completed_at=completed_at,
            metadata=metadata,
        )

    @classmethod
    def from_string(cls, text: str) -> 'StageCacheEntry':
        """Parse from S-expression string."""
        sexp = parse(text)
        return cls.from_sexp(sexp)
|
||||
|
||||
|
||||
class StageCache:
    """
    Stage-level cache manager using S-expression files.

    Cache structure:
        cache_dir/
            _stages/
                {cache_id}.sexp  <- Stage result files
    """

    def __init__(self, cache_dir: Union[str, Path]):
        """
        Initialize stage cache.

        Args:
            cache_dir: Base cache directory (the _stages subdirectory is
                created if missing)
        """
        self.cache_dir = Path(cache_dir)
        self.stages_dir = self.cache_dir / "_stages"
        self.stages_dir.mkdir(parents=True, exist_ok=True)

    def get_cache_path(self, cache_id: str) -> Path:
        """Get the path for a stage cache file."""
        return self.stages_dir / f"{cache_id}.sexp"

    def has_stage(self, cache_id: str) -> bool:
        """Check if a stage result is cached."""
        return self.get_cache_path(cache_id).exists()

    def load_stage(self, cache_id: str) -> Optional[StageCacheEntry]:
        """
        Load a cached stage result.

        Args:
            cache_id: Stage cache ID

        Returns:
            StageCacheEntry if found and parseable, None otherwise
        """
        entry_path = self.get_cache_path(cache_id)
        if not entry_path.exists():
            return None

        try:
            return StageCacheEntry.from_string(entry_path.read_text())
        except Exception as e:
            # Corrupted cache file - warn on stderr and treat as a miss
            import sys
            print(f"Warning: corrupted stage cache {cache_id}: {e}", file=sys.stderr)
            return None

    def save_stage(self, entry: StageCacheEntry) -> Path:
        """
        Save a stage result to cache.

        Args:
            entry: Stage cache entry to save

        Returns:
            Path to the saved cache file
        """
        entry_path = self.get_cache_path(entry.cache_id)
        entry_path.write_text(entry.to_string(pretty=True))
        return entry_path

    def delete_stage(self, cache_id: str) -> bool:
        """
        Delete a cached stage result.

        Args:
            cache_id: Stage cache ID

        Returns:
            True if deleted, False if not found
        """
        entry_path = self.get_cache_path(cache_id)
        if not entry_path.exists():
            return False
        entry_path.unlink()
        return True

    def list_stages(self) -> List[str]:
        """List all cached stage IDs (filename stems of *.sexp files)."""
        return [entry_path.stem for entry_path in self.stages_dir.glob("*.sexp")]

    def clear(self) -> int:
        """
        Clear all cached stages.

        Returns:
            Number of entries cleared
        """
        stale = list(self.stages_dir.glob("*.sexp"))
        for entry_path in stale:
            entry_path.unlink()
        return len(stale)
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """
    Analysis result stored as S-expression.

    Format:
        (analysis-result
          :analyzer "beats"
          :input-hash "abc123..."
          :duration 120.5
          :tempo 128.0
          :times (0.0 0.468 0.937 1.406 ...)
          :values (0.8 0.9 0.7 0.85 ...))
    """
    # Name of the analyzer that produced this result, e.g. "beats".
    analyzer: str
    # Hash of the analyzed input (empty string when absent in the sexp).
    input_hash: str
    data: Dict[str, Any]  # Analysis data (times, values, duration, etc.)
    # Unix timestamp of when the analysis was computed.
    computed_at: float = field(default_factory=time.time)

    def to_sexp(self) -> List:
        """Convert to S-expression."""
        sexp = [Symbol("analysis-result")]
        sexp.extend([Keyword("analyzer"), self.analyzer])
        sexp.extend([Keyword("input-hash"), self.input_hash])
        sexp.extend([Keyword("computed-at"), self.computed_at])

        # Add all data fields as top-level keywords; snake_case keys become
        # kebab-case keywords (from_sexp reverses this mapping).
        for key, value in self.data.items():
            # Convert key to keyword
            sexp.extend([Keyword(key.replace("_", "-")), value])

        return sexp

    def to_string(self, pretty: bool = True) -> str:
        """Serialize to S-expression string."""
        return serialize(self.to_sexp(), pretty=pretty)

    @classmethod
    def from_sexp(cls, sexp: List) -> 'AnalysisResult':
        """Parse from S-expression.

        Raises:
            ValueError: if the sexp is not an (analysis-result ...) form or
                is missing the :analyzer field.
        """
        if not sexp or not isinstance(sexp[0], Symbol) or sexp[0].name != "analysis-result":
            raise ValueError("Invalid analysis-result sexp")

        analyzer = None
        input_hash = None
        computed_at = time.time()
        data = {}

        # Walk the property list: a Keyword followed by a value consumes two
        # items; anything else is skipped. Unrecognized keywords land in
        # `data` with kebab-case converted back to snake_case.
        i = 1
        while i < len(sexp):
            if isinstance(sexp[i], Keyword):
                key = sexp[i].name
                if i + 1 < len(sexp):
                    val = sexp[i + 1]
                    if key == "analyzer":
                        analyzer = val
                    elif key == "input-hash":
                        input_hash = val
                    elif key == "computed-at":
                        computed_at = float(val)
                    else:
                        # Convert kebab-case back to snake_case
                        data_key = key.replace("-", "_")
                        data[data_key] = val
                    i += 2
                else:
                    i += 1
            else:
                i += 1

        if not analyzer:
            raise ValueError("analysis-result missing analyzer field")

        return cls(
            analyzer=analyzer,
            input_hash=input_hash or "",
            data=data,
            computed_at=computed_at,
        )

    @classmethod
    def from_string(cls, text: str) -> 'AnalysisResult':
        """Parse from S-expression string."""
        sexp = parse(text)
        return cls.from_sexp(sexp)
|
||||
|
||||
|
||||
def save_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
    result: AnalysisResult,
) -> Path:
    """
    Save an analysis result as S-expression.

    Creates cache_dir/node_id/ if needed and writes analysis.sexp inside it.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis
        result: Analysis result to save

    Returns:
        Path to the saved file
    """
    target_dir = Path(cache_dir) / node_id
    target_dir.mkdir(parents=True, exist_ok=True)

    target = target_dir / "analysis.sexp"
    target.write_text(result.to_string(pretty=True))
    return target
|
||||
|
||||
|
||||
def load_analysis_result(
    cache_dir: Union[str, Path],
    node_id: str,
) -> Optional[AnalysisResult]:
    """
    Load an analysis result from cache.

    Args:
        cache_dir: Base cache directory
        node_id: Node ID for the analysis

    Returns:
        AnalysisResult if found and parseable, None otherwise
    """
    target = Path(cache_dir) / node_id / "analysis.sexp"
    if not target.exists():
        return None

    try:
        return AnalysisResult.from_string(target.read_text())
    except Exception as e:
        # Corrupted cache entry: warn on stderr and treat as a miss.
        import sys
        print(f"Warning: corrupted analysis cache {node_id}: {e}", file=sys.stderr)
        return None
|
||||
146
core/artdag/sexp/test_ffmpeg_compiler.py
Normal file
146
core/artdag/sexp/test_ffmpeg_compiler.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
Tests for FFmpeg filter compilation.
|
||||
|
||||
Validates that each filter mapping produces valid FFmpeg commands.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .ffmpeg_compiler import FFmpegCompiler, EFFECT_MAPPINGS
|
||||
|
||||
|
||||
def test_filter_syntax(filter_str: str, duration: float = 0.1, is_complex: bool = False) -> tuple[bool, str]:
    """
    Test if an FFmpeg filter string is valid by running it on a test pattern.

    Generates a tiny lavfi testsrc video plus a sine-tone audio track, applies
    the filter, and reports whether ffmpeg exited successfully.

    Args:
        filter_str: The filter string to test
        duration: Duration of test video
        is_complex: If True, use -filter_complex instead of -vf

    Returns (success, error_message)
    """
    # NamedTemporaryFile(delete=False) so ffmpeg can write to the path; the
    # file is removed in the finally block below.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as f:
        output_path = f.name

    try:
        if is_complex:
            # Complex filter graph needs -filter_complex and explicit output mapping
            cmd = [
                'ffmpeg', '-y',
                '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
                '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
                '-filter_complex', f'[0:v]{filter_str}[out]',
                '-map', '[out]', '-map', '1:a',
                '-c:v', 'libx264', '-preset', 'ultrafast',
                '-c:a', 'aac',
                '-t', str(duration),
                output_path
            ]
        else:
            # Simple filter uses -vf
            cmd = [
                'ffmpeg', '-y',
                '-f', 'lavfi', '-i', f'testsrc=duration={duration}:size=64x64:rate=10',
                '-f', 'lavfi', '-i', f'sine=frequency=440:duration={duration}',
                '-vf', filter_str,
                '-c:v', 'libx264', '-preset', 'ultrafast',
                '-c:a', 'aac',
                '-t', str(duration),
                output_path
            ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)

        if result.returncode == 0:
            return True, ""
        else:
            # Extract relevant error: first stderr line mentioning an error,
            # otherwise the tail of stderr (capped at 500 chars).
            stderr = result.stderr
            for line in stderr.split('\n'):
                if 'Error' in line or 'error' in line or 'Invalid' in line:
                    return False, line.strip()
            return False, stderr[-500:] if len(stderr) > 500 else stderr
    except subprocess.TimeoutExpired:
        return False, "Timeout"
    except Exception as e:
        return False, str(e)
    finally:
        Path(output_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
def run_all_tests():
    """Test all effect mappings.

    Iterates EFFECT_MAPPINGS and validates each mapped FFmpeg filter via
    test_filter_syntax. Mappings without a filter (external tools / Python
    primitives) are skipped.

    Returns:
        List of (effect_name, status, message) tuples; status is one of
        "PASS", "FAIL", "SKIP".
    """
    # Note: the mappings are consulted directly; no FFmpegCompiler instance
    # is needed here (an unused one was previously constructed and removed).
    results = []

    for effect_name, mapping in EFFECT_MAPPINGS.items():
        filter_name = mapping.get("filter")

        # Skip effects with no FFmpeg equivalent (external tools or python primitives)
        if filter_name is None:
            reason = "No FFmpeg equivalent"
            if mapping.get("external_tool"):
                reason = f"External tool: {mapping['external_tool']}"
            elif mapping.get("python_primitive"):
                reason = f"Python primitive: {mapping['python_primitive']}"
            results.append((effect_name, "SKIP", reason))
            continue

        # Check if complex filter
        is_complex = mapping.get("complex", False)

        # Build filter string: static params are baked in as filter=params
        if "static" in mapping:
            filter_str = f"{filter_name}={mapping['static']}"
        else:
            filter_str = filter_name

        # Test it
        success, error = test_filter_syntax(filter_str, is_complex=is_complex)

        if success:
            results.append((effect_name, "PASS", filter_str))
        else:
            results.append((effect_name, "FAIL", f"{filter_str} -> {error}"))

    return results
|
||||
|
||||
|
||||
def print_results(results):
    """Pretty-print filter test results: summary, then failures, passes, skips."""
    tally = {"PASS": 0, "FAIL": 0, "SKIP": 0}
    for _, status, _ in results:
        tally[status] = tally.get(status, 0) + 1

    bar = "=" * 60
    print(f"\n{bar}")
    print(f"FFmpeg Filter Tests: {tally['PASS']} passed, {tally['FAIL']} failed, {tally['SKIP']} skipped")
    print(f"{bar}\n")

    # Print failures first
    if tally["FAIL"] > 0:
        print("FAILURES:")
        for name, status, msg in results:
            if status == "FAIL":
                print(f" {name}: {msg}")
        print()

    # Print passes
    print("PASSED:")
    for name, status, msg in results:
        if status == "PASS":
            print(f" {name}: {msg}")

    # Print skips
    if tally["SKIP"] > 0:
        print("\nSKIPPED (Python fallback):")
        for name, status, _ in results:
            if status == "SKIP":
                print(f" {name}")
|
||||
|
||||
|
||||
# Script entry point: validate every EFFECT_MAPPINGS filter against a real
# ffmpeg invocation and print a grouped report.
if __name__ == "__main__":
    results = run_all_tests()
    print_results(results)
|
||||
201
core/artdag/sexp/test_primitives.py
Normal file
201
core/artdag/sexp/test_primitives.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Tests for Python primitive effects.
|
||||
|
||||
Tests that ascii_art, ascii_zones, and other Python primitives
|
||||
can be executed via the EffectExecutor.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
HAS_DEPS = True
|
||||
except ImportError:
|
||||
HAS_DEPS = False
|
||||
|
||||
from .primitives import (
|
||||
ascii_art_frame,
|
||||
ascii_zones_frame,
|
||||
get_primitive,
|
||||
list_primitives,
|
||||
)
|
||||
from .ffmpeg_compiler import FFmpegCompiler
|
||||
|
||||
|
||||
def create_test_video(path: Path, duration: float = 0.5, size: str = "64x64") -> bool:
    """Create a short test video using ffmpeg.

    Args:
        path: Output file path
        duration: Clip length in seconds
        size: Frame size, e.g. "64x64"

    Returns:
        True when ffmpeg exited successfully.
    """
    source = f"testsrc=duration={duration}:size={size}:rate=10"
    cmd = [
        "ffmpeg", "-y",
        "-f", "lavfi", "-i", source,
        "-c:v", "libx264", "-preset", "ultrafast",
        str(path),
    ]
    return subprocess.run(cmd, capture_output=True).returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestPrimitives:
    """Test primitive functions directly."""

    def test_ascii_art_frame_basic(self):
        """Test ascii_art_frame produces output of same shape."""
        # Random RGB frame; only shape/dtype preservation is asserted.
        frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        result = ascii_art_frame(frame, char_size=8)
        assert result.shape == frame.shape
        assert result.dtype == np.uint8

    def test_ascii_zones_frame_basic(self):
        """Test ascii_zones_frame produces output of same shape."""
        frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
        result = ascii_zones_frame(frame, char_size=8)
        assert result.shape == frame.shape
        assert result.dtype == np.uint8

    def test_get_primitive(self):
        """Test primitive lookup returns the registered function (or None)."""
        assert get_primitive("ascii_art_frame") is ascii_art_frame
        assert get_primitive("ascii_zones_frame") is ascii_zones_frame
        assert get_primitive("nonexistent") is None

    def test_list_primitives(self):
        """Test listing primitives includes the ascii helpers."""
        primitives = list_primitives()
        assert "ascii_art_frame" in primitives
        assert "ascii_zones_frame" in primitives
        assert len(primitives) > 5
|
||||
|
||||
|
||||
class TestFFmpegCompilerPrimitives:
    """FFmpegCompiler's python_primitive effect mappings."""

    def test_has_python_primitive_ascii_art(self):
        """ascii_art maps to the ascii_art_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_art") == "ascii_art_frame"

    def test_has_python_primitive_ascii_zones(self):
        """ascii_zones maps to the ascii_zones_frame primitive."""
        assert FFmpegCompiler().has_python_primitive("ascii_zones") == "ascii_zones_frame"

    def test_has_python_primitive_ffmpeg_effect(self):
        """Pure-FFmpeg effects report no Python primitive."""
        compiler = FFmpegCompiler()
        for effect in ("brightness", "blur"):
            assert compiler.has_python_primitive(effect) is None

    def test_compile_effect_returns_none_for_primitives(self):
        """compile_effect declines primitive-backed effects."""
        compiler = FFmpegCompiler()
        for effect in ("ascii_art", "ascii_zones"):
            assert compiler.compile_effect(effect, {}) is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_DEPS, reason="numpy/PIL not available")
class TestEffectExecutorPrimitives:
    """Integration of Python primitives with the EffectExecutor."""

    def test_executor_loads_primitive(self):
        """Known primitive effects resolve to callables."""
        from ..nodes.effect import _get_python_primitive_effect

        for name in ("ascii_art", "ascii_zones"):
            assert _get_python_primitive_effect(name) is not None

    def test_executor_rejects_unknown_effect(self):
        """Unknown effect names resolve to None."""
        from ..nodes.effect import _get_python_primitive_effect

        assert _get_python_primitive_effect("nonexistent_effect") is None

    def test_execute_ascii_art_effect(self, tmp_path):
        """End to end: apply ascii_art to a tiny generated video."""
        from ..nodes.effect import EffectExecutor

        # Create test video (skip when ffmpeg is unavailable)
        source = tmp_path / "input.mp4"
        if not create_test_video(source):
            pytest.skip("Could not create test video")

        destination = tmp_path / "output.mkv"
        result = EffectExecutor().execute(
            config={"effect": "ascii_art", "char_size": 8},
            inputs=[source],
            output_path=destination,
        )

        assert result.exists()
        assert result.stat().st_size > 0
|
||||
|
||||
|
||||
def run_all_tests():
    """Run the primitive checks manually, outside pytest.

    Prints a PASS marker per check and raises AssertionError on the
    first failure. Returns early (with a SKIP message) when numpy/PIL
    are unavailable. The executor-lookup section is best-effort: an
    ImportError there is reported as SKIP rather than failing the run.
    """
    # NOTE: removed an unused `import sys` that the original carried.

    # Check dependencies
    if not HAS_DEPS:
        print("SKIP: numpy/PIL not available")
        return

    print("Testing primitives...")

    # Test primitive functions on a small random RGB frame
    frame = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)

    print(" ascii_art_frame...", end=" ")
    result = ascii_art_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")

    print(" ascii_zones_frame...", end=" ")
    result = ascii_zones_frame(frame, char_size=8)
    assert result.shape == frame.shape
    print("PASS")

    # Test FFmpegCompiler mappings
    print("\nTesting FFmpegCompiler mappings...")
    compiler = FFmpegCompiler()

    print(" ascii_art python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_art") == "ascii_art_frame"
    print("PASS")

    print(" ascii_zones python_primitive...", end=" ")
    assert compiler.has_python_primitive("ascii_zones") == "ascii_zones_frame"
    print("PASS")

    # Test executor lookup (best-effort: relative import may fail when
    # this module is run outside the package)
    print("\nTesting EffectExecutor...")
    try:
        from ..nodes.effect import _get_python_primitive_effect

        print(" _get_python_primitive_effect(ascii_art)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_art")
        assert effect_fn is not None
        print("PASS")

        print(" _get_python_primitive_effect(ascii_zones)...", end=" ")
        effect_fn = _get_python_primitive_effect("ascii_zones")
        assert effect_fn is not None
        print("PASS")

    except ImportError as e:
        print(f"SKIP: {e}")

    print("\n=== All tests passed ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly, without pytest.
    run_all_tests()
|
||||
324
core/artdag/sexp/test_stage_cache.py
Normal file
324
core/artdag/sexp/test_stage_cache.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
Tests for stage cache layer.
|
||||
|
||||
Tests S-expression storage for stage results and analysis data.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .stage_cache import (
|
||||
StageCache,
|
||||
StageCacheEntry,
|
||||
StageOutput,
|
||||
AnalysisResult,
|
||||
save_analysis_result,
|
||||
load_analysis_result,
|
||||
)
|
||||
from .parser import parse, serialize
|
||||
|
||||
|
||||
class TestStageOutput:
    """StageOutput construction and sexp (de)serialization."""

    def test_stage_output_artifact(self):
        """An artifact output carries its cache id."""
        out = StageOutput(cache_id="abc123", output_type="artifact")
        assert out.cache_id == "abc123"
        assert out.output_type == "artifact"

    def test_stage_output_scalar(self):
        """A scalar output carries a plain value."""
        out = StageOutput(value=120.5, output_type="scalar")
        assert out.value == 120.5
        assert out.output_type == "scalar"

    def test_stage_output_to_sexp(self):
        """to_sexp emits the cache-id and type fields."""
        out = StageOutput(cache_id="abc123", output_type="artifact")
        text = serialize(out.to_sexp())
        for token in ("cache-id", "abc123", "type", "artifact"):
            assert token in text

    def test_stage_output_from_sexp(self):
        """from_sexp reads back the id and the type."""
        parsed = StageOutput.from_sexp(parse('(:cache-id "def456" :type :analysis)'))
        assert parsed.cache_id == "def456"
        assert parsed.output_type == "analysis"
|
||||
|
||||
|
||||
class TestStageCacheEntry:
    """StageCacheEntry sexp serialization round-trips."""

    def test_stage_cache_entry_to_sexp(self):
        """Serialized entry names the stage, cache id and outputs."""
        entry = StageCacheEntry(
            stage_name="analyze-a",
            cache_id="stage_abc123",
            outputs={
                "beats": StageOutput(cache_id="beats_def456", output_type="analysis"),
                "tempo": StageOutput(value=120.5, output_type="scalar"),
            },
            completed_at=1705678900.123,
        )

        text = serialize(entry.to_sexp())
        for token in ("stage-result", "analyze-a", "stage_abc123", "outputs", "beats"):
            assert token in text

    def test_stage_cache_entry_roundtrip(self):
        """to_string -> from_string preserves the entry's data."""
        entry = StageCacheEntry(
            stage_name="analyze-b",
            cache_id="stage_xyz789",
            outputs={
                "segments": StageOutput(cache_id="seg_123", output_type="artifact"),
            },
            completed_at=1705678900.0,
        )

        restored = StageCacheEntry.from_string(entry.to_string())

        assert restored.stage_name == entry.stage_name
        assert restored.cache_id == entry.cache_id
        assert "segments" in restored.outputs
        assert restored.outputs["segments"].cache_id == "seg_123"

    def test_stage_cache_entry_from_sexp(self):
        """A hand-written stage-result sexp parses correctly."""
        sexp_str = '''
        (stage-result
          :name "test-stage"
          :cache-id "cache123"
          :completed-at 1705678900.0
          :outputs ((beats :cache-id "beats123" :type :analysis)))
        '''
        entry = StageCacheEntry.from_string(sexp_str)

        assert entry.stage_name == "test-stage"
        assert entry.cache_id == "cache123"
        assert "beats" in entry.outputs
        assert entry.outputs["beats"].cache_id == "beats123"
|
||||
|
||||
|
||||
class TestStageCache:
    """StageCache persistence on disk."""

    @staticmethod
    def _entry(stage_name, cache_id, outputs=None):
        # Small factory to cut construction boilerplate in the tests below.
        return StageCacheEntry(
            stage_name=stage_name,
            cache_id=cache_id,
            outputs=outputs if outputs is not None else {},
        )

    def test_save_and_load_stage(self):
        """A saved stage can be read back intact."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            entry = self._entry(
                "analyze",
                "test_cache_id",
                {"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
            )

            path = cache.save_stage(entry)
            assert path.exists()
            assert path.suffix == ".sexp"

            restored = cache.load_stage("test_cache_id")
            assert restored is not None
            assert restored.stage_name == "analyze"
            assert "beats" in restored.outputs

    def test_has_stage(self):
        """has_stage reflects exactly what was saved."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)

            assert not cache.has_stage("nonexistent")

            cache.save_stage(self._entry("test", "exists_cache_id"))
            assert cache.has_stage("exists_cache_id")

    def test_delete_stage(self):
        """delete_stage removes the entry and reports success."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            cache.save_stage(self._entry("test", "to_delete"))

            assert cache.has_stage("to_delete")
            assert cache.delete_stage("to_delete") is True
            assert not cache.has_stage("to_delete")

    def test_list_stages(self):
        """list_stages returns every cached id."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            for i in range(3):
                cache.save_stage(self._entry(f"stage{i}", f"cache_{i}"))

            stages = cache.list_stages()
            assert len(stages) == 3
            for i in range(3):
                assert f"cache_{i}" in stages

    def test_clear(self):
        """clear removes everything and returns the count removed."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            for i in range(3):
                cache.save_stage(self._entry(f"stage{i}", f"cache_{i}"))

            assert cache.clear() == 3
            assert len(cache.list_stages()) == 0

    def test_cache_file_extension(self):
        """Cache paths use the .sexp extension."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            assert cache.get_cache_path("test_id").suffix == ".sexp"

    def test_invalid_sexp_error_handling(self):
        """A corrupted cache file loads as None rather than raising."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)

            # Write corrupted content
            corrupt_path = cache.get_cache_path("corrupted")
            corrupt_path.write_text("this is not valid sexp )()(")

            # Should return None, not raise
            assert cache.load_stage("corrupted") is None
|
||||
|
||||
|
||||
class TestAnalysisResult:
    """AnalysisResult serialization and on-disk cache helpers."""

    def test_analysis_result_to_sexp(self):
        """Serialized result contains the analyzer name and data keys."""
        result = AnalysisResult(
            analyzer="beats",
            input_hash="input_abc123",
            data={
                "duration": 120.5,
                "tempo": 128.0,
                "times": [0.0, 0.468, 0.937, 1.406],
                "values": [0.8, 0.9, 0.7, 0.85],
            },
        )

        text = serialize(result.to_sexp())
        for token in ("analysis-result", "beats", "duration", "tempo", "times"):
            assert token in text

    def test_analysis_result_roundtrip(self):
        """to_string -> from_string preserves analyzer, hash and data."""
        original = AnalysisResult(
            analyzer="scenes",
            input_hash="video_xyz",
            data={
                "scene_count": 5,
                "scene_times": [0.0, 10.5, 25.0, 45.2, 60.0],
            },
        )

        restored = AnalysisResult.from_string(original.to_string())

        assert restored.analyzer == original.analyzer
        assert restored.input_hash == original.input_hash
        assert restored.data["scene_count"] == 5

    def test_save_and_load_analysis_result(self):
        """save_analysis_result / load_analysis_result round-trip on disk."""
        with tempfile.TemporaryDirectory() as tmpdir:
            result = AnalysisResult(
                analyzer="beats",
                input_hash="audio_123",
                data={
                    "tempo": 120.0,
                    "times": [0.0, 0.5, 1.0],
                },
            )

            path = save_analysis_result(tmpdir, "node_abc", result)
            assert path.exists()
            assert path.name == "analysis.sexp"

            restored = load_analysis_result(tmpdir, "node_abc")
            assert restored is not None
            assert restored.analyzer == "beats"
            assert restored.data["tempo"] == 120.0

    def test_analysis_result_kebab_case(self):
        """Data keys are kebab-case in sexp, snake_case back in Python."""
        result = AnalysisResult(
            analyzer="test",
            input_hash="hash",
            data={
                "scene_count": 5,
                "beat_times": [1, 2, 3],
            },
        )

        text = result.to_string()
        # Kebab case in sexp
        assert "scene-count" in text
        assert "beat-times" in text

        # Back to snake_case after parsing
        restored = AnalysisResult.from_string(text)
        assert "scene_count" in restored.data
        assert "beat_times" in restored.data
|
||||
286
core/artdag/sexp/test_stage_compiler.py
Normal file
286
core/artdag/sexp/test_stage_compiler.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Tests for stage compilation and scoping.
|
||||
|
||||
Tests the CompiledStage dataclass, stage form parsing,
|
||||
variable scoping, and dependency validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from .parser import parse, Symbol, Keyword
|
||||
from .compiler import (
|
||||
compile_recipe,
|
||||
CompileError,
|
||||
CompiledStage,
|
||||
CompilerContext,
|
||||
_topological_sort_stages,
|
||||
)
|
||||
|
||||
|
||||
class TestStageCompilation:
    """Compilation of (stage ...) forms."""

    def test_parse_stage_form_basic(self):
        """A single stage with a name and outputs compiles."""
        recipe = '''
        (recipe "test-stage"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 1
        stage = compiled.stages[0]
        assert stage.name == "analyze"
        assert "beats" in stage.outputs
        assert len(stage.node_ids) > 0

    def test_parse_stage_with_requires(self):
        """:requires and :inputs are recorded on the compiled stage."""
        recipe = '''
        (recipe "test-requires"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [segments]
            (def segments (-> audio (segment :times beats)))
            (-> segments (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 2
        process = next(s for s in compiled.stages if s.name == "process")
        assert process.requires == ["analyze"]
        assert "beats" in process.inputs
        assert "segments" in process.outputs

    def test_stage_outputs_recorded(self):
        """Declared outputs appear in outputs and output_bindings."""
        recipe = '''
        (recipe "test-outputs"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo)))
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        stage = compiled.stages[0]
        for name in ("beats", "tempo"):
            assert name in stage.outputs
            assert name in stage.output_bindings

    def test_stage_order_topological(self):
        """stage_order respects :requires dependencies."""
        recipe = '''
        (recipe "test-order"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        # analyze should come before output
        assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output")
|
||||
|
||||
|
||||
class TestStageValidation:
    """Dependency and input validation errors from compile_recipe."""

    def test_stage_requires_validation(self):
        """Requiring an unknown stage raises CompileError."""
        recipe = '''
        (recipe "test-bad-require"
          (def audio (source :path "test.mp3"))

          (stage :process
            :requires [:nonexistent]
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(recipe))

    def test_stage_inputs_validation(self):
        """An input not produced by any required stage raises."""
        recipe = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [nonexistent]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(recipe))

    def test_undeclared_output_error(self):
        """Declaring an output that the body never defines raises."""
        recipe = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(recipe))

    def test_forward_reference_detection(self):
        """Requiring a stage that is only defined later is rejected."""
        # Forward references are not allowed - stages must be defined
        # before they can be required
        recipe = '''
        (recipe "test-forward"
          (def audio (source :path "test.mp3"))

          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio))

          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(recipe))
|
||||
|
||||
|
||||
class TestStageScoping:
    """Variable visibility across stage boundaries."""

    def test_pre_stage_bindings_accessible(self):
        """Bindings made before any stage are visible in every stage."""
        recipe = '''
        (recipe "test-pre-stage"
          (def audio (source :path "test.mp3"))
          (def video (source :path "test.mp4"))

          (stage :analyze-audio
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :analyze-video
            :outputs [scenes]
            (def scenes (-> video (analyze scenes)))
            (-> video (segment :times scenes) (sequence))))
        '''
        # Should compile without error - audio and video accessible to both stages
        compiled = compile_recipe(parse(recipe))
        assert len(compiled.stages) == 2

    def test_stage_bindings_flow_through_requires(self):
        """:inputs expose upstream stage outputs to dependent stages."""
        recipe = '''
        (recipe "test-binding-flow"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires [:analyze]
            :inputs [beats]
            :outputs [result]
            (def result (-> audio (segment :times beats)))
            (-> result (sequence))))
        '''
        # Should compile without error - beats flows from analyze to process
        compiled = compile_recipe(parse(recipe))
        assert len(compiled.stages) == 2
|
||||
|
||||
|
||||
class TestTopologicalSort:
    """_topological_sort_stages ordering behaviour."""

    @staticmethod
    def _make(name, requires, inputs, out_name, node):
        # Minimal CompiledStage with a single output binding.
        return CompiledStage(
            name=name,
            requires=requires,
            inputs=inputs,
            outputs=[out_name],
            node_ids=[node],
            output_bindings={out_name: node},
        )

    def test_empty_stages(self):
        """No stages sorts to an empty list."""
        assert _topological_sort_stages({}) == []

    def test_single_stage(self):
        """One independent stage is returned alone."""
        stages = {"a": self._make("a", [], [], "out", "n1")}
        assert _topological_sort_stages(stages) == ["a"]

    def test_linear_chain(self):
        """a -> b -> c keeps chain order."""
        stages = {
            "a": self._make("a", [], [], "x", "n1"),
            "b": self._make("b", ["a"], ["x"], "y", "n2"),
            "c": self._make("c", ["b"], ["y"], "z", "n3"),
        }
        order = _topological_sort_stages(stages)
        assert order.index("a") < order.index("b") < order.index("c")

    def test_parallel_stages_same_level(self):
        """Independent stages may come out in either order."""
        stages = {
            "a": self._make("a", [], [], "x", "n1"),
            "b": self._make("b", [], [], "y", "n2"),
        }
        # Both a and b should be in result (order doesn't matter)
        assert set(_topological_sort_stages(stages)) == {"a", "b"}

    def test_diamond_dependency(self):
        """Diamond pattern: A -> B, A -> C, B+C -> D."""
        stages = {
            "a": self._make("a", [], [], "x", "n1"),
            "b": self._make("b", ["a"], ["x"], "y", "n2"),
            "c": self._make("c", ["a"], ["x"], "z", "n3"),
            "d": CompiledStage(
                name="d",
                requires=["b", "c"],
                inputs=["y", "z"],
                outputs=["out"],
                node_ids=["n4"],
                output_bindings={"out": "n4"},
            ),
        }
        order = _topological_sort_stages(stages)
        # a must be first, d must be last
        assert order[0] == "a"
        assert order[-1] == "d"
        # b and c must be before d
        assert order.index("b") < order.index("d")
        assert order.index("c") < order.index("d")
|
||||
739
core/artdag/sexp/test_stage_integration.py
Normal file
739
core/artdag/sexp/test_stage_integration.py
Normal file
@@ -0,0 +1,739 @@
|
||||
"""
|
||||
End-to-end integration tests for staged recipes.
|
||||
|
||||
Tests the complete flow: compile -> plan -> execute
|
||||
for recipes with stages.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from .parser import parse, serialize
|
||||
from .compiler import compile_recipe, CompileError
|
||||
from .planner import ExecutionPlanSexp, StagePlan
|
||||
from .stage_cache import StageCache, StageCacheEntry, StageOutput
|
||||
from .scheduler import StagePlanScheduler, StagePlanResult
|
||||
|
||||
|
||||
class TestSimpleTwoStageRecipe:
    """Compile a minimal analyze -> output pipeline."""

    def test_compile_two_stage_recipe(self):
        """Two dependent stages compile and come out in order."""
        recipe = '''
        (recipe "test-two-stages"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 2
        assert compiled.stage_order == ["analyze", "output"]

        first, second = compiled.stages
        assert first.name == "analyze"
        assert "beats" in first.outputs

        assert second.name == "output"
        assert second.requires == ["analyze"]
        assert "beats" in second.inputs
|
||||
|
||||
|
||||
class TestParallelAnalysisStages:
    """Independent analysis stages feeding one combiner."""

    def test_compile_parallel_stages(self):
        """Two root stages plus a combiner compile correctly."""
        recipe = '''
        (recipe "test-parallel"
          (def audio-a (source :path "a.mp3"))
          (def audio-b (source :path "b.mp3"))

          (stage :analyze-a
            :outputs [beats-a]
            (def beats-a (-> audio-a (analyze beats))))

          (stage :analyze-b
            :outputs [beats-b]
            (def beats-b (-> audio-b (analyze beats))))

          (stage :combine
            :requires [:analyze-a :analyze-b]
            :inputs [beats-a beats-b]
            (-> audio-a (segment :times beats-a) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 3

        # analyze-a and analyze-b should both be at level 0 (parallel)
        by_name = {s.name: s for s in compiled.stages}
        assert by_name["analyze-a"].requires == []
        assert by_name["analyze-b"].requires == []
        assert set(by_name["combine"].requires) == {"analyze-a", "analyze-b"}
|
||||
|
||||
|
||||
class TestDiamondDependency:
    """Diamond dependency pattern: A -> B, A -> C, B+C -> D."""

    def test_compile_diamond_pattern(self):
        """Four stages in a diamond compile with a valid ordering."""
        recipe = '''
        (recipe "test-diamond"
          (def audio (source :path "test.mp3"))

          (stage :source-stage
            :outputs [audio-ref]
            (def audio-ref audio))

          (stage :branch-b
            :requires [:source-stage]
            :inputs [audio-ref]
            :outputs [result-b]
            (def result-b (-> audio-ref (effect gain :amount 0.5))))

          (stage :branch-c
            :requires [:source-stage]
            :inputs [audio-ref]
            :outputs [result-c]
            (def result-c (-> audio-ref (effect gain :amount 0.8))))

          (stage :merge
            :requires [:branch-b :branch-c]
            :inputs [result-b result-c]
            (-> result-b (blend result-c :mode "mix"))))
        '''
        compiled = compile_recipe(parse(recipe))

        assert len(compiled.stages) == 4

        # Check dependencies
        by_name = {s.name: s for s in compiled.stages}
        assert by_name["source-stage"].requires == []
        assert by_name["branch-b"].requires == ["source-stage"]
        assert by_name["branch-c"].requires == ["source-stage"]
        assert set(by_name["merge"].requires) == {"branch-b", "branch-c"}

        # source-stage should come first in order
        position = compiled.stage_order.index
        assert position("source-stage") < position("branch-b")
        assert position("source-stage") < position("branch-c")
        # merge should come last
        assert position("branch-b") < position("merge")
        assert position("branch-c") < position("merge")
|
||||
|
||||
|
||||
class TestStageReuseOnRerun:
    """Cached stage results survive across runs."""

    def test_stage_reuse(self):
        """A stage saved on the first run is found on the second."""
        with tempfile.TemporaryDirectory() as tmpdir:
            stage_cache = StageCache(tmpdir)

            # Simulate first run by caching a stage
            stage_cache.save_stage(StageCacheEntry(
                stage_name="analyze",
                cache_id="fixed_cache_id",
                outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
            ))

            # Verify cache exists
            assert stage_cache.has_stage("fixed_cache_id")

            # Second run should find cache
            cached = stage_cache.load_stage("fixed_cache_id")
            assert cached is not None
            assert cached.stage_name == "analyze"
|
||||
|
||||
class TestExplicitDataFlowEndToEnd:
    """Analysis values flow between stages via :inputs/:outputs."""

    def test_data_flow_declaration(self):
        """Outputs of analyze match the declared inputs of process."""
        recipe = '''
        (recipe "test-data-flow"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats tempo]
            (def beats (-> audio (analyze beats)))
            (def tempo (-> audio (analyze tempo))))

          (stage :process
            :requires [:analyze]
            :inputs [beats tempo]
            :outputs [result]
            (def result (-> audio (segment :times beats) (effect speed :factor tempo)))
            (-> result (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))

        by_name = {s.name: s for s in compiled.stages}
        analyze = by_name["analyze"]
        process = by_name["process"]

        # Analyze outputs
        assert set(analyze.outputs) == {"beats", "tempo"}
        assert "beats" in analyze.output_bindings
        assert "tempo" in analyze.output_bindings

        # Process inputs
        assert set(process.inputs) == {"beats", "tempo"}
        assert process.requires == ["analyze"]
|
||||
|
||||
|
||||
class TestRecipeFixtures:
    """Test compiling recipes supplied via pytest fixtures.

    NOTE: the fixtures deliberately do NOT use a ``test_`` name prefix.
    A fixture whose name matches pytest's collection pattern triggers a
    ``PytestCollectionWarning`` ("cannot collect ... because it is a
    fixture"); the previous ``test_recipe_*`` names did exactly that.
    """

    @pytest.fixture
    def recipe_two_stages(self):
        """Recipe source: an :analyze stage feeding an :output stage."""
        return '''
        (recipe "test-two-stages"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''

    @pytest.fixture
    def recipe_parallel_stages(self):
        """Recipe source: two independent analyze stages plus a combiner."""
        return '''
        (recipe "test-parallel"
          (def audio-a (source :path "a.mp3"))
          (def audio-b (source :path "b.mp3"))

          (stage :analyze-a
            :outputs [beats-a]
            (def beats-a (-> audio-a (analyze beats))))

          (stage :analyze-b
            :outputs [beats-b]
            (def beats-b (-> audio-b (analyze beats))))

          (stage :combine
            :requires [:analyze-a :analyze-b]
            :inputs [beats-a beats-b]
            (-> audio-a (blend audio-b :mode "mix"))))
        '''

    def test_two_stages_fixture(self, recipe_two_stages):
        """Two-stage recipe fixture compiles to exactly two stages."""
        compiled = compile_recipe(parse(recipe_two_stages))
        assert len(compiled.stages) == 2

    def test_parallel_stages_fixture(self, recipe_parallel_stages):
        """Parallel-stages recipe fixture compiles to three stages."""
        compiled = compile_recipe(parse(recipe_parallel_stages))
        assert len(compiled.stages) == 3
|
||||
|
||||
|
||||
class TestStageValidationErrors:
    """Invalid stage declarations must be rejected at compile time."""

    def test_missing_output_declaration(self):
        """Declaring an :outputs name never bound in the body is an error."""
        src = '''
        (recipe "test-missing-output"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats nonexistent]
            (def beats (-> audio (analyze beats)))))
        '''
        with pytest.raises(CompileError, match="not defined in the stage body"):
            compile_recipe(parse(src))

    def test_input_without_requires(self):
        """Consuming an :inputs name without requiring its producer is an error."""
        src = '''
        (recipe "test-bad-input"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :process
            :requires []
            :inputs [beats]
            (def result audio)))
        '''
        with pytest.raises(CompileError, match="not an output of any required stage"):
            compile_recipe(parse(src))

    def test_forward_reference(self):
        """Requiring a stage defined later in the recipe is an error."""
        src = '''
        (recipe "test-forward-ref"
          (def audio (source :path "test.mp3"))

          (stage :a
            :requires [:b]
            :outputs [out-a]
            (def out-a audio)
            audio)

          (stage :b
            :outputs [out-b]
            (def out-b audio)
            audio))
        '''
        with pytest.raises(CompileError, match="requires undefined stage"):
            compile_recipe(parse(src))
|
||||
|
||||
|
||||
class TestBeatSyncDemoRecipe:
    """Test the beat-sync demo recipe from examples."""

    # Two-stage demo recipe: :analyze produces beats/tempo bindings,
    # :process consumes beats and emits the final sequenced output.
    BEAT_SYNC_RECIPE = '''
    ;; Simple staged recipe demo
    (recipe "beat-sync-demo"
      :version "1.0"
      :description "Demo of staged beat-sync workflow"

      ;; Pre-stage definitions (available to all stages)
      (def audio (source :path "input.mp3"))

      ;; Stage 1: Analysis (expensive, cached)
      (stage :analyze
        :outputs [beats tempo]
        (def beats (-> audio (analyze beats)))
        (def tempo (-> audio (analyze tempo))))

      ;; Stage 2: Processing (uses analysis results)
      (stage :process
        :requires [:analyze]
        :inputs [beats]
        :outputs [segments]
        (def segments (-> audio (segment :times beats)))
        (-> segments (sequence))))
    '''

    def test_compile_beat_sync_recipe(self):
        """Beat-sync demo recipe compiles correctly."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        # Recipe-level metadata is carried through compilation.
        assert compiled.name == "beat-sync-demo"
        assert compiled.version == "1.0"
        assert compiled.description == "Demo of staged beat-sync workflow"

    def test_beat_sync_stage_count(self):
        """Beat-sync has 2 stages in correct order."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        assert len(compiled.stages) == 2
        assert compiled.stage_order == ["analyze", "process"]

    def test_beat_sync_analyze_stage(self):
        """Analyze stage has correct outputs."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        analyze = next(s for s in compiled.stages if s.name == "analyze")
        # First stage: no upstream dependencies or inputs.
        assert analyze.requires == []
        assert analyze.inputs == []
        assert set(analyze.outputs) == {"beats", "tempo"}
        # Each declared output must have a node binding.
        assert "beats" in analyze.output_bindings
        assert "tempo" in analyze.output_bindings

    def test_beat_sync_process_stage(self):
        """Process stage has correct dependencies and inputs."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        process = next(s for s in compiled.stages if s.name == "process")
        assert process.requires == ["analyze"]
        assert "beats" in process.inputs
        assert "segments" in process.outputs

    def test_beat_sync_node_count(self):
        """Beat-sync generates expected number of nodes."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        # 1 SOURCE + 2 ANALYZE + 1 SEGMENT + 1 SEQUENCE = 5 nodes
        assert len(compiled.nodes) == 5

    def test_beat_sync_node_types(self):
        """Beat-sync generates correct node types."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        node_types = [n["type"] for n in compiled.nodes]
        assert node_types.count("SOURCE") == 1
        assert node_types.count("ANALYZE") == 2
        assert node_types.count("SEGMENT") == 1
        assert node_types.count("SEQUENCE") == 1

    def test_beat_sync_output_is_sequence(self):
        """Beat-sync output node is the sequence node."""
        compiled = compile_recipe(parse(self.BEAT_SYNC_RECIPE))

        # The recipe's trailing (-> segments (sequence)) form becomes the output node.
        output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id)
        assert output_node["type"] == "SEQUENCE"
|
||||
|
||||
|
||||
class TestAsciiArtStagedRecipe:
    """Test the ASCII art staged recipe."""

    # Three-stage recipe exercising registry declarations (effect/analyzer),
    # pre-stage defs, a (bind ...) parameter driven by analysis data, and a
    # final MUX combining video + audio.
    ASCII_ART_STAGED_RECIPE = '''
    ;; ASCII art effect with staged execution
    (recipe "ascii_art_staged"
      :version "1.0"
      :description "ASCII art effect with staged execution"
      :encoding (:codec "libx264" :crf 20 :preset "medium" :audio-codec "aac" :fps 30)

      ;; Registry
      (effect ascii_art :path "sexp_effects/effects/ascii_art.sexp")
      (analyzer energy :path "../artdag-analyzers/energy/analyzer.py")

      ;; Pre-stage definitions
      (def color_mode "color")
      (def video (source :path "monday.webm"))
      (def audio (source :path "dizzy.mp3"))

      ;; Stage 1: Analysis
      (stage :analyze
        :outputs [energy-data]
        (def audio-clip (-> audio (segment :start 60 :duration 10)))
        (def energy-data (-> audio-clip (analyze energy))))

      ;; Stage 2: Process
      (stage :process
        :requires [:analyze]
        :inputs [energy-data]
        :outputs [result audio-clip]
        (def clip (-> video (segment :start 0 :duration 10)))
        (def audio-clip (-> audio (segment :start 60 :duration 10)))
        (def result (-> clip
                        (effect ascii_art
                          :char_size (bind energy-data values :range [2 32])
                          :color_mode color_mode))))

      ;; Stage 3: Output
      (stage :output
        :requires [:process]
        :inputs [result audio-clip]
        (mux result audio-clip)))
    '''

    def test_compile_ascii_art_staged(self):
        """ASCII art staged recipe compiles correctly."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        assert compiled.name == "ascii_art_staged"
        assert compiled.version == "1.0"

    def test_ascii_art_stage_count(self):
        """ASCII art has 3 stages in correct order."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        assert len(compiled.stages) == 3
        assert compiled.stage_order == ["analyze", "process", "output"]

    def test_ascii_art_analyze_stage(self):
        """Analyze stage outputs energy-data."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        analyze = next(s for s in compiled.stages if s.name == "analyze")
        # Root stage: nothing upstream.
        assert analyze.requires == []
        assert analyze.inputs == []
        assert "energy-data" in analyze.outputs

    def test_ascii_art_process_stage(self):
        """Process stage requires analyze and outputs result."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        process = next(s for s in compiled.stages if s.name == "process")
        assert process.requires == ["analyze"]
        assert "energy-data" in process.inputs
        assert "result" in process.outputs
        assert "audio-clip" in process.outputs

    def test_ascii_art_output_stage(self):
        """Output stage requires process and has mux."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        output = next(s for s in compiled.stages if s.name == "output")
        assert output.requires == ["process"]
        assert "result" in output.inputs
        assert "audio-clip" in output.inputs

    def test_ascii_art_node_count(self):
        """ASCII art generates expected nodes."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        # 2 SOURCE + 2 SEGMENT + 1 ANALYZE + 1 EFFECT + 1 MUX = 7+ nodes
        # (>= because deduplication/expansion details may add nodes).
        assert len(compiled.nodes) >= 7

    def test_ascii_art_has_mux_output(self):
        """ASCII art output is MUX node."""
        compiled = compile_recipe(parse(self.ASCII_ART_STAGED_RECIPE))

        output_node = next(n for n in compiled.nodes if n["id"] == compiled.output_node_id)
        assert output_node["type"] == "MUX"
|
||||
|
||||
|
||||
class TestMixedStagedAndNonStagedRecipes:
    """Recipes without stages, or mixing pre-stage defs with stages, still compile."""

    def test_recipe_without_stages(self):
        """A stage-free recipe compiles with empty stage metadata but real nodes."""
        src = '''
        (recipe "no-stages"
          (-> (source :path "test.mp3")
              (effect gain :amount 0.5)))
        '''
        compiled = compile_recipe(parse(src))

        assert compiled.stages == []
        assert compiled.stage_order == []
        # The node graph itself must still be produced.
        assert len(compiled.nodes) > 0

    def test_mixed_pre_stage_and_stages(self):
        """Definitions made before any stage are visible inside stages."""
        src = '''
        (recipe "mixed"
          ;; Pre-stage definitions
          (def audio (source :path "test.mp3"))
          (def volume 0.8)

          ;; Stage using pre-stage definitions, ending with output expression
          (stage :process
            :outputs [result]
            (def result (-> audio (effect gain :amount volume)))
            result))
        '''
        compiled = compile_recipe(parse(src))

        assert len(compiled.stages) == 1
        # The single stage resolved both `audio` and `volume` from outside itself.
        (process,) = compiled.stages
        assert process.name == "process"
        assert "result" in process.outputs
|
||||
|
||||
|
||||
class TestEffectParamsBlock:
    """Test :params block parsing in effect definitions.

    Covers typed parameter declarations, :choices lists, rejection of the
    legacy ``((name default) ...)`` syntax, and file-based introspection.
    """

    def test_parse_effect_with_params_block(self):
        """Parse effect with new :params syntax."""
        from .effect_loader import load_sexp_effect

        effect_code = '''
        (define-effect test_effect
          :params (
            (size :type int :default 10 :range [1 100] :desc "Size parameter")
            (color :type string :default "red" :desc "Color parameter")
            (enabled :type int :default 1 :range [0 1] :desc "Enable flag")
          )
          frame)
        '''
        name, process_fn, defaults, param_defs = load_sexp_effect(effect_code)

        assert name == "test_effect"
        assert len(param_defs) == 3
        # Defaults dict mirrors the :default of each declaration.
        assert defaults["size"] == 10
        assert defaults["color"] == "red"
        assert defaults["enabled"] == 1

        # Check ParamDef objects preserve type/range/description metadata.
        size_param = param_defs[0]
        assert size_param.name == "size"
        assert size_param.param_type == "int"
        assert size_param.default == 10
        # :range bounds are normalized to floats.
        assert size_param.range_min == 1.0
        assert size_param.range_max == 100.0
        assert size_param.description == "Size parameter"

        color_param = param_defs[1]
        assert color_param.name == "color"
        assert color_param.param_type == "string"
        assert color_param.default == "red"

    def test_parse_effect_with_choices(self):
        """Parse effect with choices in :params."""
        from .effect_loader import load_sexp_effect

        effect_code = '''
        (define-effect mode_effect
          :params (
            (mode :type string :default "fast"
                  :choices [fast slow medium]
                  :desc "Processing mode")
          )
          frame)
        '''
        name, _, defaults, param_defs = load_sexp_effect(effect_code)

        assert name == "mode_effect"
        assert defaults["mode"] == "fast"

        # Bare symbols in the :choices vector become strings, order preserved.
        mode_param = param_defs[0]
        assert mode_param.choices == ["fast", "slow", "medium"]

    def test_legacy_effect_syntax_rejected(self):
        """Legacy effect syntax should be rejected."""
        from .effect_loader import load_sexp_effect

        # Old-style ((name default) ...) parameter list, no :params keyword.
        effect_code = '''
        (define-effect legacy_effect
          ((width 100)
           (height 200)
           (name "default"))
          frame)
        '''
        with pytest.raises(ValueError) as exc_info:
            load_sexp_effect(effect_code)

        # The error must both name the problem and point at the fix.
        assert "Legacy parameter syntax" in str(exc_info.value)
        assert ":params" in str(exc_info.value)

    def test_effect_params_introspection(self):
        """Test that effect params are available for introspection."""
        from .effect_loader import load_sexp_effect_file
        from pathlib import Path

        # Write a minimal effect definition to a temp file so the
        # file-based loader path is exercised.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f:
            f.write('''
            (define-effect introspect_test
              :params (
                (alpha :type float :default 0.5 :range [0 1] :desc "Alpha value")
              )
              frame)
            ''')
            temp_path = Path(f.name)

        try:
            name, _, defaults, param_defs = load_sexp_effect_file(temp_path)
            assert name == "introspect_test"
            assert len(param_defs) == 1
            assert param_defs[0].name == "alpha"
            assert param_defs[0].param_type == "float"
        finally:
            # delete=False above, so clean up explicitly.
            temp_path.unlink()
|
||||
|
||||
|
||||
class TestConstructParamsBlock:
    """Exercise :params block parsing for construct definitions."""

    def test_parse_construct_params_helper(self):
        """_parse_construct_params extracts names and defaults in order."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword

        raw_params = [
            [Symbol("duration"), Keyword("type"), Symbol("float"),
             Keyword("default"), 5.0, Keyword("desc"), "Duration in seconds"],
            [Symbol("count"), Keyword("type"), Symbol("int"),
             Keyword("default"), 10],
        ]

        names, defaults = _parse_construct_params(raw_params)

        # Declaration order is preserved; defaults follow :default.
        assert names == ["duration", "count"]
        assert defaults["duration"] == 5.0
        assert defaults["count"] == 10

    def test_construct_params_with_no_defaults(self):
        """A param without :default is recorded with a None default."""
        from .planner import _parse_construct_params
        from .parser import Symbol, Keyword

        raw_params = [
            [Symbol("required_param"), Keyword("type"), Symbol("string")],
            [Symbol("optional_param"), Keyword("type"), Symbol("int"),
             Keyword("default"), 42],
        ]

        names, defaults = _parse_construct_params(raw_params)

        assert names == ["required_param", "optional_param"]
        # Missing :default is represented as None, not an omitted key.
        assert defaults["required_param"] is None
        assert defaults["optional_param"] == 42
|
||||
|
||||
|
||||
class TestParameterValidation:
    """Test that unknown parameters are rejected.

    pytest is already imported at module level, so the previously redundant
    function-local ``import pytest`` statements have been removed; numpy stays
    local because it is only needed by these two tests.
    """

    def test_effect_rejects_unknown_params(self):
        """Effects should reject unknown parameters."""
        from .effect_loader import load_sexp_effect
        import numpy as np

        effect_code = '''
        (define-effect test_effect
          :params (
            (brightness :type int :default 0 :desc "Brightness")
          )
          frame)
        '''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)

        # Create a test frame
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}

        # Valid param should work
        result, _ = process_frame(frame, {"brightness": 10}, state)
        assert isinstance(result, np.ndarray)

        # Unknown param should raise
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"unknown_param": 42}, state)

        # The message names the offender and lists the valid parameters.
        assert "Unknown parameter 'unknown_param'" in str(exc_info.value)
        assert "brightness" in str(exc_info.value)

    def test_effect_no_params_rejects_all(self):
        """Effects with no params should reject any parameter."""
        from .effect_loader import load_sexp_effect
        import numpy as np

        effect_code = '''
        (define-effect no_params_effect
          :params ()
          frame)
        '''
        name, process_frame, defaults, _ = load_sexp_effect(effect_code)

        # Create a test frame
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        state = {}

        # Empty params should work
        result, _ = process_frame(frame, {}, state)
        assert isinstance(result, np.ndarray)

        # Any param should raise
        with pytest.raises(ValueError) as exc_info:
            process_frame(frame, {"any_param": 42}, state)

        assert "Unknown parameter 'any_param'" in str(exc_info.value)
        # With an empty :params block the valid-parameter list renders as "(none)".
        assert "(none)" in str(exc_info.value)
|
||||
228
core/artdag/sexp/test_stage_planner.py
Normal file
228
core/artdag/sexp/test_stage_planner.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Tests for stage-aware planning.
|
||||
|
||||
Tests stage topological sorting, level computation, cache ID computation,
|
||||
and plan metadata generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from .parser import parse
|
||||
from .compiler import compile_recipe, CompiledStage
|
||||
from .planner import (
|
||||
create_plan,
|
||||
StagePlan,
|
||||
_compute_stage_levels,
|
||||
_compute_stage_cache_id,
|
||||
)
|
||||
|
||||
|
||||
class TestStagePlanning:
    """Test stage-aware plan creation."""

    def test_stage_topological_sort_in_plan(self):
        """Stages sorted by dependencies in plan."""
        recipe = '''
        (recipe "test-sort"
          (def audio (source :path "test.mp3"))

          (stage :analyze
            :outputs [beats]
            (def beats (-> audio (analyze beats))))

          (stage :output
            :requires [:analyze]
            :inputs [beats]
            (-> audio (segment :times beats) (sequence))))
        '''
        compiled = compile_recipe(parse(recipe))
        # Note: create_plan needs recipe_dir for analysis, we'll test the ordering differently
        assert compiled.stage_order.index("analyze") < compiled.stage_order.index("output")

    def test_stage_level_computation(self):
        """Independent stages get same level."""
        # a and b have no requires; c fans in from both.
        stages = [
            CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                          node_ids=["n1"], output_bindings={"x": "n1"}),
            CompiledStage(name="b", requires=[], inputs=[], outputs=["y"],
                          node_ids=["n2"], output_bindings={"y": "n2"}),
            CompiledStage(name="c", requires=["a", "b"], inputs=["x", "y"], outputs=["z"],
                          node_ids=["n3"], output_bindings={"z": "n3"}),
        ]
        levels = _compute_stage_levels(stages)

        assert levels["a"] == 0
        assert levels["b"] == 0
        assert levels["c"] == 1  # Depends on a and b

    def test_stage_level_chain(self):
        """Chain stages get increasing levels."""
        # Linear a -> b -> c dependency chain.
        stages = [
            CompiledStage(name="a", requires=[], inputs=[], outputs=["x"],
                          node_ids=["n1"], output_bindings={"x": "n1"}),
            CompiledStage(name="b", requires=["a"], inputs=["x"], outputs=["y"],
                          node_ids=["n2"], output_bindings={"y": "n2"}),
            CompiledStage(name="c", requires=["b"], inputs=["y"], outputs=["z"],
                          node_ids=["n3"], output_bindings={"z": "n3"}),
        ]
        levels = _compute_stage_levels(stages)

        assert levels["a"] == 0
        assert levels["b"] == 1
        assert levels["c"] == 2

    def test_stage_cache_id_deterministic(self):
        """Same stage = same cache ID."""
        stage = CompiledStage(
            name="analyze",
            requires=[],
            inputs=[],
            outputs=["beats"],
            node_ids=["abc123"],
            output_bindings={"beats": "abc123"},
        )

        # Two computations with identical inputs must agree.
        cache_id_1 = _compute_stage_cache_id(
            stage,
            stage_cache_ids={},
            node_cache_ids={"abc123": "nodeabc"},
            cluster_key=None,
        )
        cache_id_2 = _compute_stage_cache_id(
            stage,
            stage_cache_ids={},
            node_cache_ids={"abc123": "nodeabc"},
            cluster_key=None,
        )

        assert cache_id_1 == cache_id_2

    def test_stage_cache_id_includes_requires(self):
        """Cache ID changes when required stage cache ID changes."""
        stage = CompiledStage(
            name="process",
            requires=["analyze"],
            inputs=["beats"],
            outputs=["result"],
            node_ids=["def456"],
            output_bindings={"result": "def456"},
        )

        # Only the upstream stage's cache ID differs between the two calls.
        cache_id_1 = _compute_stage_cache_id(
            stage,
            stage_cache_ids={"analyze": "req_cache_a"},
            node_cache_ids={"def456": "node_def"},
            cluster_key=None,
        )
        cache_id_2 = _compute_stage_cache_id(
            stage,
            stage_cache_ids={"analyze": "req_cache_b"},
            node_cache_ids={"def456": "node_def"},
            cluster_key=None,
        )

        # Different required stage cache IDs should produce different cache IDs
        assert cache_id_1 != cache_id_2

    def test_stage_cache_id_cluster_key(self):
        """Cache ID changes with cluster key."""
        stage = CompiledStage(
            name="analyze",
            requires=[],
            inputs=[],
            outputs=["beats"],
            node_ids=["abc123"],
            output_bindings={"beats": "abc123"},
        )

        cache_id_no_key = _compute_stage_cache_id(
            stage,
            stage_cache_ids={},
            node_cache_ids={"abc123": "nodeabc"},
            cluster_key=None,
        )
        cache_id_with_key = _compute_stage_cache_id(
            stage,
            stage_cache_ids={},
            node_cache_ids={"abc123": "nodeabc"},
            cluster_key="cluster123",
        )

        # Cluster key should change the cache ID
        assert cache_id_no_key != cache_id_with_key
|
||||
|
||||
|
||||
class TestStagePlanMetadata:
    """Stage metadata carried on execution plans."""

    def test_plan_without_stages(self):
        """A recipe declaring no stages compiles with empty stage fields."""
        src = '''
        (recipe "no-stages"
          (-> (source :path "test.mp3") (effect gain :amount 0.5)))
        '''
        compiled = compile_recipe(parse(src))

        # No (stage ...) forms -> no stage metadata at all.
        assert compiled.stages == []
        assert compiled.stage_order == []
|
||||
|
||||
|
||||
class TestStagePlanDataclass:
    """Construction of the StagePlan dataclass."""

    def test_stage_plan_creation(self):
        """A StagePlan accepts all fields and retains them unchanged."""
        from .planner import PlanStep

        analyze_step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={"analyzer": "beats"},
            inputs=["input1"],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_123",
        )

        plan = StagePlan(
            stage_name="analyze",
            cache_id="stage_cache_123",
            steps=[analyze_step],
            requires=[],
            output_bindings={"beats": "cache123"},
            level=0,
        )

        # Fields round-trip without transformation.
        assert plan.stage_name == "analyze"
        assert plan.cache_id == "stage_cache_123"
        assert len(plan.steps) == 1
        assert plan.level == 0
|
||||
|
||||
|
||||
class TestExplicitDataRouting:
    """The serialized plan must carry explicit stage routing info."""

    def test_plan_step_includes_stage_info(self):
        """A PlanStep's s-expression form exposes stage and stage-cache-id."""
        from .planner import PlanStep
        from .parser import serialize

        step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={},
            inputs=[],
            cache_id="cache123",
            level=0,
            stage="analyze",
            stage_cache_id="stage_cache_abc",
        )

        # Render the step and inspect the textual form for routing metadata.
        rendered = serialize(step.to_sexp())

        for expected in ("stage", "analyze", "stage-cache-id"):
            assert expected in rendered
|
||||
323
core/artdag/sexp/test_stage_scheduler.py
Normal file
323
core/artdag/sexp/test_stage_scheduler.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Tests for stage-aware scheduler.
|
||||
|
||||
Tests stage cache hit/miss, stage execution ordering,
|
||||
and parallel stage support.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from unittest.mock import Mock, MagicMock, patch
|
||||
|
||||
from .scheduler import (
|
||||
StagePlanScheduler,
|
||||
StageResult,
|
||||
StagePlanResult,
|
||||
create_stage_scheduler,
|
||||
schedule_staged_plan,
|
||||
)
|
||||
from .planner import ExecutionPlanSexp, PlanStep, StagePlan
|
||||
from .stage_cache import StageCache, StageCacheEntry, StageOutput
|
||||
|
||||
|
||||
class TestStagePlanScheduler:
|
||||
"""Test stage-aware scheduling."""
|
||||
|
||||
    def test_plan_without_stages_uses_regular_scheduling(self):
        """Plans without stages fall back to regular scheduling."""
        # A minimal plan with no steps and an empty stage_plans list.
        plan = ExecutionPlanSexp(
            plan_id="test_plan",
            recipe_id="test_recipe",
            recipe_hash="abc123",
            steps=[],
            output_step_id="output",
            stage_plans=[],  # No stages
        )

        scheduler = StagePlanScheduler()
        # This will use PlanScheduler internally
        # Without Celery, it just returns completed status
        result = scheduler.schedule(plan)

        # Even the fallback path must wrap its outcome in a StagePlanResult.
        assert isinstance(result, StagePlanResult)
|
||||
|
||||
    def test_stage_cache_hit_skips_execution(self):
        """Cached stage not re-executed."""
        with tempfile.TemporaryDirectory() as tmpdir:
            stage_cache = StageCache(tmpdir)

            # Pre-populate cache with an entry matching the stage's cache_id below.
            entry = StageCacheEntry(
                stage_name="analyze",
                cache_id="stage_cache_123",
                outputs={"beats": StageOutput(cache_id="beats_out", output_type="analysis")},
            )
            stage_cache.save_stage(entry)

            step = PlanStep(
                step_id="step1",
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="step_cache",
                level=0,
                stage="analyze",
                stage_cache_id="stage_cache_123",
            )

            stage_plan = StagePlan(
                stage_name="analyze",
                cache_id="stage_cache_123",
                steps=[step],
                requires=[],
                output_bindings={"beats": "beats_out"},
                level=0,
            )

            plan = ExecutionPlanSexp(
                plan_id="test_plan",
                recipe_id="test_recipe",
                recipe_hash="abc123",
                steps=[step],
                output_step_id="step1",
                stage_plans=[stage_plan],
                stage_order=["analyze"],
                stage_levels={"analyze": 0},
                stage_cache_ids={"analyze": "stage_cache_123"},
            )

            scheduler = StagePlanScheduler(stage_cache=stage_cache)
            result = scheduler.schedule(plan)

            # The only stage hit the cache, so nothing was executed.
            assert result.stages_cached == 1
            assert result.stages_completed == 0
|
||||
|
||||
    def test_stage_inputs_loaded_from_cache(self):
        """Stage receives inputs from required stage cache."""
        with tempfile.TemporaryDirectory() as tmpdir:
            stage_cache = StageCache(tmpdir)

            # Pre-populate upstream stage cache so :analyze is a cache hit.
            upstream_entry = StageCacheEntry(
                stage_name="analyze",
                cache_id="upstream_cache",
                outputs={"beats": StageOutput(cache_id="beats_data", output_type="analysis")},
            )
            stage_cache.save_stage(upstream_entry)

            # Steps for stages
            upstream_step = PlanStep(
                step_id="analyze_step",
                node_type="ANALYZE",
                config={},
                inputs=[],
                cache_id="analyze_cache",
                level=0,
                stage="analyze",
                stage_cache_id="upstream_cache",
            )

            # Downstream step consumes the upstream step's output.
            downstream_step = PlanStep(
                step_id="process_step",
                node_type="SEGMENT",
                config={},
                inputs=["analyze_step"],
                cache_id="process_cache",
                level=1,
                stage="process",
                stage_cache_id="downstream_cache",
            )

            upstream_plan = StagePlan(
                stage_name="analyze",
                cache_id="upstream_cache",
                steps=[upstream_step],
                requires=[],
                output_bindings={"beats": "beats_data"},
                level=0,
            )

            downstream_plan = StagePlan(
                stage_name="process",
                cache_id="downstream_cache",
                steps=[downstream_step],
                requires=["analyze"],
                output_bindings={"result": "process_cache"},
                level=1,
            )

            plan = ExecutionPlanSexp(
                plan_id="test_plan",
                recipe_id="test_recipe",
                recipe_hash="abc123",
                steps=[upstream_step, downstream_step],
                output_step_id="process_step",
                stage_plans=[upstream_plan, downstream_plan],
                stage_order=["analyze", "process"],
                stage_levels={"analyze": 0, "process": 1},
                stage_cache_ids={"analyze": "upstream_cache", "process": "downstream_cache"},
            )

            scheduler = StagePlanScheduler(stage_cache=stage_cache)
            result = scheduler.schedule(plan)

            # Upstream should be cached, downstream executed
            assert result.stages_cached == 1
            assert "analyze" in result.stage_results
            assert result.stage_results["analyze"].status == "cached"
|
||||
|
||||
def test_parallel_stages_same_level(self):
    """Stages at same level can run in parallel."""
    # Two independent ANALYZE steps, each living in its own stage at level 0.
    analyze_a = PlanStep(
        step_id="step_a",
        node_type="ANALYZE",
        config={},
        inputs=[],
        cache_id="cache_a",
        level=0,
        stage="analyze-a",
        stage_cache_id="stage_a",
    )
    analyze_b = PlanStep(
        step_id="step_b",
        node_type="ANALYZE",
        config={},
        inputs=[],
        cache_id="cache_b",
        level=0,
        stage="analyze-b",
        stage_cache_id="stage_b",
    )

    # Wrap each step in a StagePlan with no cross-stage dependencies,
    # so both stages sit at the same dependency level.
    plan_a = StagePlan(
        stage_name="analyze-a",
        cache_id="stage_a",
        steps=[analyze_a],
        requires=[],
        output_bindings={"beats-a": "cache_a"},
        level=0,
    )
    plan_b = StagePlan(
        stage_name="analyze-b",
        cache_id="stage_b",
        steps=[analyze_b],
        requires=[],
        output_bindings={"beats-b": "cache_b"},
        level=0,
    )

    full_plan = ExecutionPlanSexp(
        plan_id="test_plan",
        recipe_id="test_recipe",
        recipe_hash="abc123",
        steps=[analyze_a, analyze_b],
        output_step_id="step_b",
        stage_plans=[plan_a, plan_b],
        stage_order=["analyze-a", "analyze-b"],
        stage_levels={"analyze-a": 0, "analyze-b": 0},
        stage_cache_ids={"analyze-a": "stage_a", "analyze-b": "stage_b"},
    )

    # Grouping by level should place both stages into level 0 together,
    # which is what allows them to be dispatched in parallel.
    grouped = StagePlanScheduler()._group_stages_by_level(full_plan.stage_plans)
    assert len(grouped[0]) == 2
|
||||
|
||||
def test_stage_outputs_cached_after_execution(self):
    """Stage outputs written to cache after completion.

    Runs a single-stage plan against an empty StageCache and verifies
    the scheduler's side effect: the stage's entry must exist in the
    cache once scheduling finishes.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        stage_cache = StageCache(tmpdir)

        step = PlanStep(
            step_id="step1",
            node_type="ANALYZE",
            config={},
            inputs=[],
            cache_id="step_cache",
            level=0,
            stage="analyze",
            stage_cache_id="new_stage_cache",
        )

        stage_plan = StagePlan(
            stage_name="analyze",
            cache_id="new_stage_cache",
            steps=[step],
            requires=[],
            output_bindings={"beats": "step_cache"},
            level=0,
        )

        plan = ExecutionPlanSexp(
            plan_id="test_plan",
            recipe_id="test_recipe",
            recipe_hash="abc123",
            steps=[step],
            output_step_id="step1",
            stage_plans=[stage_plan],
            stage_order=["analyze"],
            stage_levels={"analyze": 0},
            stage_cache_ids={"analyze": "new_stage_cache"},
        )

        scheduler = StagePlanScheduler(stage_cache=stage_cache)
        # The return value is intentionally ignored: this test asserts
        # on the cache side effect, not on the scheduling result.
        scheduler.schedule(plan)

        # Stage should now be cached
        assert stage_cache.has_stage("new_stage_cache")
|
||||
|
||||
|
||||
class TestStageResult:
    """Tests for the StageResult dataclass."""

    def test_stage_result_creation(self):
        """StageResult can be created with all fields."""
        created = StageResult(
            stage_name="test",
            cache_id="cache123",
            status="completed",
            step_results={},
            outputs={"out": "out_cache"},
        )
        # The fields we set must survive construction unchanged.
        assert created.stage_name == "test"
        assert created.status == "completed"
        assert created.outputs["out"] == "out_cache"
|
||||
|
||||
|
||||
class TestStagePlanResult:
    """Tests for the StagePlanResult dataclass."""

    def test_stage_plan_result_creation(self):
        """StagePlanResult can be created with all fields."""
        summary = StagePlanResult(
            plan_id="plan123",
            status="completed",
            stages_completed=2,
            stages_cached=1,
            stages_failed=0,
        )
        # Constructor arguments round-trip through the dataclass fields.
        assert summary.plan_id == "plan123"
        assert summary.stages_completed == 2
        assert summary.stages_cached == 1
|
||||
|
||||
|
||||
class TestSchedulerFactory:
    """Tests for scheduler factory functions."""

    def test_create_stage_scheduler(self):
        """create_stage_scheduler returns StagePlanScheduler."""
        built = create_stage_scheduler()
        assert isinstance(built, StagePlanScheduler)

    def test_create_stage_scheduler_with_cache(self):
        """create_stage_scheduler accepts stage_cache."""
        with tempfile.TemporaryDirectory() as tmpdir:
            cache = StageCache(tmpdir)
            built = create_stage_scheduler(stage_cache=cache)
            # The factory must wire the provided cache through unchanged
            # (identity, not just equality).
            assert built.stage_cache is cache
|
||||
384
core/docs/EXECUTION_MODEL.md
Normal file
384
core/docs/EXECUTION_MODEL.md
Normal file
@@ -0,0 +1,384 @@
|
||||
# Art DAG 3-Phase Execution Model
|
||||
|
||||
## Overview
|
||||
|
||||
The execution model separates DAG processing into three distinct phases:
|
||||
|
||||
```
|
||||
Recipe + Inputs → ANALYZE → Analysis Results
|
||||
↓
|
||||
Analysis + Recipe → PLAN → Execution Plan (with cache IDs)
|
||||
↓
|
||||
Execution Plan → EXECUTE → Cached Results
|
||||
```
|
||||
|
||||
This separation enables:
|
||||
1. **Incremental development** - Re-run recipes without reprocessing unchanged steps
|
||||
2. **Parallel execution** - Independent steps run concurrently via Celery
|
||||
3. **Deterministic caching** - Same inputs always produce same cache IDs
|
||||
4. **Cost estimation** - Plan phase can estimate work before executing
|
||||
|
||||
## Phase 1: Analysis
|
||||
|
||||
### Purpose
|
||||
Extract features from input media that inform downstream processing decisions.
|
||||
|
||||
### Inputs
|
||||
- Recipe YAML with input references
|
||||
- Input media files (by content hash)
|
||||
|
||||
### Outputs
|
||||
Analysis results stored as JSON, keyed by input hash:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class AnalysisResult:
|
||||
input_hash: str
|
||||
features: Dict[str, Any]
|
||||
# Audio features
|
||||
beats: Optional[List[float]] # Beat times in seconds
|
||||
downbeats: Optional[List[float]] # Bar-start times
|
||||
tempo: Optional[float] # BPM
|
||||
energy: Optional[List[Tuple[float, float]]] # (time, value) envelope
|
||||
spectrum: Optional[Dict[str, List[Tuple[float, float]]]] # band envelopes
|
||||
# Video features
|
||||
duration: float
|
||||
frame_rate: float
|
||||
dimensions: Tuple[int, int]
|
||||
motion_tempo: Optional[float] # Estimated BPM from motion
|
||||
```
|
||||
|
||||
### Implementation
|
||||
```python
|
||||
class Analyzer:
|
||||
def analyze(self, input_hash: str, features: List[str]) -> AnalysisResult:
|
||||
"""Extract requested features from input."""
|
||||
|
||||
def analyze_audio(self, path: Path) -> AudioFeatures:
|
||||
"""Extract all audio features using librosa/essentia."""
|
||||
|
||||
def analyze_video(self, path: Path) -> VideoFeatures:
|
||||
"""Extract video metadata and motion analysis."""
|
||||
```
|
||||
|
||||
### Caching
|
||||
Analysis results are cached by:
|
||||
```
|
||||
analysis_cache_id = SHA3-256(input_hash + sorted(feature_names))
|
||||
```
|
||||
|
||||
## Phase 2: Planning
|
||||
|
||||
### Purpose
|
||||
Convert recipe + analysis into a complete execution plan with pre-computed cache IDs.
|
||||
|
||||
### Inputs
|
||||
- Recipe YAML (parsed)
|
||||
- Analysis results for all inputs
|
||||
- Recipe parameters (user-supplied values)
|
||||
|
||||
### Outputs
|
||||
An ExecutionPlan containing ordered steps, each with a pre-computed cache ID:
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class ExecutionStep:
|
||||
step_id: str # Unique identifier
|
||||
node_type: str # Primitive type (SOURCE, SEQUENCE, etc.)
|
||||
config: Dict[str, Any] # Node configuration
|
||||
input_steps: List[str] # IDs of steps this depends on
|
||||
cache_id: str # Pre-computed: hash(inputs + config)
|
||||
estimated_duration: float # Optional: for progress reporting
|
||||
|
||||
@dataclass
|
||||
class ExecutionPlan:
|
||||
plan_id: str # Hash of entire plan
|
||||
recipe_id: str # Source recipe
|
||||
steps: List[ExecutionStep] # Topologically sorted
|
||||
analysis: Dict[str, AnalysisResult]
|
||||
output_step: str # Final step ID
|
||||
|
||||
def compute_cache_ids(self):
|
||||
"""Compute all cache IDs in dependency order."""
|
||||
```
|
||||
|
||||
### Cache ID Computation
|
||||
|
||||
Cache IDs are computed in topological order so each step's cache ID
|
||||
incorporates its inputs' cache IDs:
|
||||
|
||||
```python
|
||||
def compute_cache_id(step: ExecutionStep, resolved_inputs: Dict[str, str]) -> str:
|
||||
"""
|
||||
Cache ID = SHA3-256(
|
||||
node_type +
|
||||
canonical_json(config) +
|
||||
sorted([input_cache_ids])
|
||||
)
|
||||
"""
|
||||
components = [
|
||||
step.node_type,
|
||||
json.dumps(step.config, sort_keys=True),
|
||||
*sorted(resolved_inputs[s] for s in step.input_steps)
|
||||
]
|
||||
return sha3_256('|'.join(components))
|
||||
```
|
||||
|
||||
### Plan Generation
|
||||
|
||||
The planner expands recipe nodes into concrete steps:
|
||||
|
||||
1. **SOURCE nodes** → Direct step with input hash as cache ID
|
||||
2. **ANALYZE nodes** → Step that references analysis results
|
||||
3. **TRANSFORM nodes** → Step with static config
|
||||
4. **TRANSFORM_DYNAMIC nodes** → Expanded to per-frame steps (or use BIND output)
|
||||
5. **SEQUENCE nodes** → Tree reduction for parallel composition
|
||||
6. **MAP nodes** → Expanded to N parallel steps + reduction
|
||||
|
||||
### Tree Reduction for Composition
|
||||
|
||||
Instead of sequential pairwise composition:
|
||||
```
|
||||
A → B → C → D (3 sequential steps)
|
||||
```
|
||||
|
||||
Use parallel tree reduction:
|
||||
```
|
||||
A ─┬─ AB ─┬─ ABCD
|
||||
B ─┘ │
|
||||
C ─┬─ CD ─┘
|
||||
D ─┘
|
||||
|
||||
Level 0: [A, B, C, D] (4 parallel)
|
||||
Level 1: [AB, CD] (2 parallel)
|
||||
Level 2: [ABCD] (1 final)
|
||||
```
|
||||
|
||||
This reduces the number of composition levels from O(N) to O(log N).
|
||||
|
||||
## Phase 3: Execution
|
||||
|
||||
### Purpose
|
||||
Execute the plan, skipping steps with cached results.
|
||||
|
||||
### Inputs
|
||||
- ExecutionPlan with pre-computed cache IDs
|
||||
- Cache state (which IDs already exist)
|
||||
|
||||
### Process
|
||||
|
||||
1. **Claim Check**: For each step, atomically check if result is cached
|
||||
2. **Task Dispatch**: Uncached steps dispatched to Celery workers
|
||||
3. **Parallel Execution**: Independent steps run concurrently
|
||||
4. **Result Storage**: Each step stores result with its cache ID
|
||||
5. **Progress Tracking**: Real-time status updates
|
||||
|
||||
### Hash-Based Task Claiming
|
||||
|
||||
Prevents duplicate work when multiple workers process the same plan:
|
||||
|
||||
```lua
|
||||
-- Redis Lua script for atomic claim
|
||||
local key = KEYS[1]
|
||||
local data = redis.call('GET', key)
|
||||
if data then
|
||||
local status = cjson.decode(data)
|
||||
if status.status == 'running' or
|
||||
status.status == 'completed' or
|
||||
status.status == 'cached' then
|
||||
return 0 -- Already claimed/done
|
||||
end
|
||||
end
|
||||
local claim_data = ARGV[1]
|
||||
local ttl = tonumber(ARGV[2])
|
||||
redis.call('SETEX', key, ttl, claim_data)
|
||||
return 1 -- Successfully claimed
|
||||
```
|
||||
|
||||
### Celery Task Structure
|
||||
|
||||
```python
|
||||
@app.task(bind=True)
|
||||
def execute_step(self, step_json: str, plan_id: str) -> dict:
|
||||
"""Execute a single step with caching."""
|
||||
step = ExecutionStep.from_json(step_json)
|
||||
|
||||
# Check cache first
|
||||
if cache.has(step.cache_id):
|
||||
return {'status': 'cached', 'cache_id': step.cache_id}
|
||||
|
||||
# Try to claim this work
|
||||
if not claim_task(step.cache_id, self.request.id):
|
||||
# Another worker is handling it, wait for result
|
||||
return wait_for_result(step.cache_id)
|
||||
|
||||
# Do the work
|
||||
executor = get_executor(step.node_type)
|
||||
input_paths = [cache.get(s) for s in step.input_steps]
|
||||
output_path = cache.get_output_path(step.cache_id)
|
||||
|
||||
result_path = executor.execute(step.config, input_paths, output_path)
|
||||
cache.put(step.cache_id, result_path)
|
||||
|
||||
return {'status': 'completed', 'cache_id': step.cache_id}
|
||||
```
|
||||
|
||||
### Execution Orchestration
|
||||
|
||||
```python
|
||||
class PlanExecutor:
|
||||
def execute(self, plan: ExecutionPlan) -> ExecutionResult:
|
||||
"""Execute plan with parallel Celery tasks."""
|
||||
|
||||
# Group steps by level (steps at same level can run in parallel)
|
||||
levels = self.compute_dependency_levels(plan.steps)
|
||||
|
||||
for level_steps in levels:
|
||||
# Dispatch all steps at this level
|
||||
tasks = [
|
||||
execute_step.delay(step.to_json(), plan.plan_id)
|
||||
for step in level_steps
|
||||
if not self.cache.has(step.cache_id)
|
||||
]
|
||||
|
||||
# Wait for level completion
|
||||
results = [task.get() for task in tasks]
|
||||
|
||||
return self.collect_results(plan)
|
||||
```
|
||||
|
||||
## Data Flow Example
|
||||
|
||||
### Recipe: beat-cuts
|
||||
```yaml
|
||||
nodes:
|
||||
- id: music
|
||||
type: SOURCE
|
||||
config: { input: true }
|
||||
|
||||
- id: beats
|
||||
type: ANALYZE
|
||||
config: { feature: beats }
|
||||
inputs: [music]
|
||||
|
||||
- id: videos
|
||||
type: SOURCE_LIST
|
||||
config: { input: true }
|
||||
|
||||
- id: slices
|
||||
type: MAP
|
||||
config: { operation: RANDOM_SLICE }
|
||||
inputs:
|
||||
items: videos
|
||||
timing: beats
|
||||
|
||||
- id: final
|
||||
type: SEQUENCE
|
||||
inputs: [slices]
|
||||
```
|
||||
|
||||
### Phase 1: Analysis
|
||||
```python
|
||||
# Input: music file with hash abc123
|
||||
analysis = {
|
||||
'abc123': AnalysisResult(
|
||||
beats=[0.0, 0.48, 0.96, 1.44, ...],
|
||||
tempo=125.0,
|
||||
duration=180.0
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Phase 2: Planning
|
||||
```python
|
||||
# Expands MAP into concrete steps
|
||||
plan = ExecutionPlan(
|
||||
steps=[
|
||||
# Source steps
|
||||
ExecutionStep(id='music', cache_id='abc123', ...),
|
||||
ExecutionStep(id='video_0', cache_id='def456', ...),
|
||||
ExecutionStep(id='video_1', cache_id='ghi789', ...),
|
||||
|
||||
# Slice steps (one per beat group)
|
||||
ExecutionStep(id='slice_0', cache_id='hash(video_0+timing)', ...),
|
||||
ExecutionStep(id='slice_1', cache_id='hash(video_1+timing)', ...),
|
||||
...
|
||||
|
||||
# Tree reduction for sequence
|
||||
ExecutionStep(id='seq_0_1', inputs=['slice_0', 'slice_1'], ...),
|
||||
ExecutionStep(id='seq_2_3', inputs=['slice_2', 'slice_3'], ...),
|
||||
ExecutionStep(id='seq_final', inputs=['seq_0_1', 'seq_2_3'], ...),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Phase 3: Execution
|
||||
```
|
||||
Level 0: [music, video_0, video_1] → all cached (SOURCE)
|
||||
Level 1: [slice_0, slice_1, slice_2, slice_3] → 4 parallel tasks
|
||||
Level 2: [seq_0_1, seq_2_3] → 2 parallel SEQUENCE tasks
|
||||
Level 3: [seq_final] → 1 final SEQUENCE task
|
||||
```
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
artdag/
|
||||
├── artdag/
|
||||
│ ├── analysis/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── analyzer.py # Main Analyzer class
|
||||
│ │ ├── audio.py # Audio feature extraction
|
||||
│ │ └── video.py # Video feature extraction
|
||||
│ ├── planning/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── planner.py # RecipePlanner class
|
||||
│ │ ├── schema.py # ExecutionPlan, ExecutionStep
|
||||
│ │ └── tree_reduction.py # Parallel composition optimizer
|
||||
│ └── execution/
|
||||
│ ├── __init__.py
|
||||
│ ├── executor.py # PlanExecutor class
|
||||
│ └── claiming.py # Hash-based task claiming
|
||||
|
||||
art-celery/
|
||||
├── tasks/
|
||||
│ ├── __init__.py
|
||||
│ ├── analyze.py # analyze_inputs task
|
||||
│ ├── plan.py # generate_plan task
|
||||
│ ├── execute.py # execute_step task
|
||||
│ └── orchestrate.py # run_plan (coordinates all)
|
||||
├── claiming.py # Redis Lua scripts
|
||||
└── ...
|
||||
```
|
||||
|
||||
## CLI Interface
|
||||
|
||||
```bash
|
||||
# Full pipeline
|
||||
artdag run-recipe recipes/beat-cuts/recipe.yaml \
|
||||
-i music:abc123 \
|
||||
-i videos:def456,ghi789
|
||||
|
||||
# Phase by phase
|
||||
artdag analyze recipes/beat-cuts/recipe.yaml -i music:abc123
|
||||
# → outputs analysis.json
|
||||
|
||||
artdag plan recipes/beat-cuts/recipe.yaml --analysis analysis.json
|
||||
# → outputs plan.json
|
||||
|
||||
artdag execute plan.json
|
||||
# → runs with caching, skips completed steps
|
||||
|
||||
# Dry run (show what would execute)
|
||||
artdag execute plan.json --dry-run
|
||||
# → shows which steps are cached vs need execution
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Development Speed**: Change recipe, re-run → only affected steps execute
|
||||
2. **Parallelism**: Independent steps run on multiple Celery workers
|
||||
3. **Reproducibility**: Same inputs + recipe = same cache IDs = same output
|
||||
4. **Visibility**: Plan shows exactly what will happen before execution
|
||||
5. **Cost Control**: Estimate compute before committing resources
|
||||
6. **Fault Tolerance**: Failed runs resume from last successful step
|
||||
443
core/docs/IPFS_PRIMARY_ARCHITECTURE.md
Normal file
443
core/docs/IPFS_PRIMARY_ARCHITECTURE.md
Normal file
@@ -0,0 +1,443 @@
|
||||
# IPFS-Primary Architecture (Sketch)
|
||||
|
||||
A simplified L1 architecture for large-scale distributed rendering where IPFS is the primary data store.
|
||||
|
||||
## Current vs Simplified
|
||||
|
||||
| Component | Current | Simplified |
|
||||
|-----------|---------|------------|
|
||||
| Local cache | Custom, per-worker | IPFS node handles it |
|
||||
| Redis content_index | content_hash → node_id | Eliminated |
|
||||
| Redis ipfs_index | content_hash → ipfs_cid | Eliminated |
|
||||
| Step inputs | File paths | IPFS CIDs |
|
||||
| Step outputs | File path + CID | Just CID |
|
||||
| Cache lookup | Local → Redis → IPFS | Just IPFS |
|
||||
|
||||
## Core Principle
|
||||
|
||||
**Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.**
|
||||
|
||||
```
|
||||
Step input: [cid1, cid2, ...]
|
||||
Step output: cid_out
|
||||
```
|
||||
|
||||
## Worker Architecture
|
||||
|
||||
Each worker runs:
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────┐
|
||||
│ Worker Node │
|
||||
│ │
|
||||
│ ┌───────────┐ ┌──────────────┐ │
|
||||
│ │ Celery │────│ IPFS Node │ │
|
||||
│ │ Worker │ │ (local) │ │
|
||||
│ └───────────┘ └──────────────┘ │
|
||||
│ │ │ │
|
||||
│ │ ┌─────┴─────┐ │
|
||||
│ │ │ Local │ │
|
||||
│ │ │ Blockstore│ │
|
||||
│ │ └───────────┘ │
|
||||
│ │ │
|
||||
│ ┌────┴────┐ │
|
||||
│ │ /tmp │ (ephemeral workspace) │
|
||||
│ └─────────┘ │
|
||||
└─────────────────────────────────────┘
|
||||
│
|
||||
│ IPFS libp2p
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Other IPFS │
|
||||
│ Nodes │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
## Execution Flow
|
||||
|
||||
### 1. Plan Generation (unchanged)
|
||||
|
||||
```python
|
||||
plan = planner.plan(recipe, input_hashes)
|
||||
# plan.steps[].cache_id = deterministic hash
|
||||
```
|
||||
|
||||
### 2. Input Registration
|
||||
|
||||
Before execution, register inputs with IPFS:
|
||||
|
||||
```python
|
||||
input_cids = {}
|
||||
for name, path in inputs.items():
|
||||
cid = ipfs.add(path)
|
||||
input_cids[name] = cid
|
||||
|
||||
# Plan now carries CIDs
|
||||
plan.input_cids = input_cids
|
||||
```
|
||||
|
||||
### 3. Step Execution
|
||||
|
||||
```python
|
||||
@celery.task
|
||||
def execute_step(step_json: str, input_cids: dict[str, str]) -> str:
|
||||
"""Execute step, return output CID."""
|
||||
step = ExecutionStep.from_json(step_json)
|
||||
|
||||
# Check if already computed (by cache_id as IPNS key or DHT lookup)
|
||||
existing_cid = ipfs.resolve(f"/ipns/{step.cache_id}")
|
||||
if existing_cid:
|
||||
return existing_cid
|
||||
|
||||
# Fetch inputs from IPFS → local temp files
|
||||
input_paths = []
|
||||
for input_step_id in step.input_steps:
|
||||
cid = input_cids[input_step_id]
|
||||
path = ipfs.get(cid, f"/tmp/{cid}") # IPFS node caches automatically
|
||||
input_paths.append(path)
|
||||
|
||||
# Execute
|
||||
output_path = f"/tmp/{step.cache_id}.mkv"
|
||||
executor = get_executor(step.node_type)
|
||||
executor.execute(step.config, input_paths, output_path)
|
||||
|
||||
# Add output to IPFS
|
||||
output_cid = ipfs.add(output_path)
|
||||
|
||||
# Publish cache_id → CID mapping (optional, for cache hits)
|
||||
ipfs.name_publish(step.cache_id, output_cid)
|
||||
|
||||
# Cleanup temp files
|
||||
cleanup_temp(input_paths + [output_path])
|
||||
|
||||
return output_cid
|
||||
```
|
||||
|
||||
### 4. Orchestration
|
||||
|
||||
```python
|
||||
@celery.task
|
||||
def run_plan(plan_json: str) -> str:
|
||||
"""Execute plan, return final output CID."""
|
||||
plan = ExecutionPlan.from_json(plan_json)
|
||||
|
||||
# CID results accumulate as steps complete
|
||||
cid_results = dict(plan.input_cids)
|
||||
|
||||
for level in plan.get_steps_by_level():
|
||||
# Parallel execution within level
|
||||
tasks = []
|
||||
for step in level:
|
||||
step_input_cids = {
|
||||
sid: cid_results[sid]
|
||||
for sid in step.input_steps
|
||||
}
|
||||
tasks.append(execute_step.s(step.to_json(), step_input_cids))
|
||||
|
||||
# Wait for level to complete
|
||||
results = group(tasks).apply_async().get()
|
||||
|
||||
# Record output CIDs
|
||||
for step, cid in zip(level, results):
|
||||
cid_results[step.step_id] = cid
|
||||
|
||||
return cid_results[plan.output_step]
|
||||
```
|
||||
|
||||
## What's Eliminated
|
||||
|
||||
### No more Redis indexes
|
||||
|
||||
```python
|
||||
# BEFORE: Complex index management
|
||||
self._set_content_index(content_hash, node_id) # Redis + local
|
||||
self._set_ipfs_index(content_hash, ipfs_cid) # Redis + local
|
||||
node_id = self._get_content_index(content_hash) # Check Redis, fallback local
|
||||
|
||||
# AFTER: Just CIDs
|
||||
output_cid = ipfs.add(output_path)
|
||||
return output_cid
|
||||
```
|
||||
|
||||
### No more local cache management
|
||||
|
||||
```python
|
||||
# BEFORE: Custom cache with entries, metadata, cleanup
|
||||
cache.put(node_id, source_path, node_type, execution_time)
|
||||
cache.get(node_id)
|
||||
cache.has(node_id)
|
||||
cache.cleanup_lru()
|
||||
|
||||
# AFTER: IPFS handles it
|
||||
ipfs.add(path) # Store
|
||||
ipfs.get(cid) # Retrieve (cached by IPFS node)
|
||||
ipfs.pin(cid) # Keep permanently
|
||||
ipfs.gc() # Cleanup unpinned
|
||||
```
|
||||
|
||||
### No more content_hash vs node_id confusion
|
||||
|
||||
```python
|
||||
# BEFORE: Two identifiers
|
||||
content_hash = sha3_256(file_bytes) # What the file IS
|
||||
node_id = cache_id # What computation produced it
|
||||
# Need indexes to map between them
|
||||
|
||||
# AFTER: One identifier
|
||||
cid = ipfs.add(file) # Content-addressed, includes hash
|
||||
# CID IS the identifier
|
||||
```
|
||||
|
||||
## Cache Hit Detection
|
||||
|
||||
Three options:
|
||||
|
||||
### Option A: IPNS (mutable names)
|
||||
|
||||
```python
|
||||
# Publish: cache_id → CID
|
||||
ipfs.name_publish(key=cache_id, value=output_cid)
|
||||
|
||||
# Lookup before executing
|
||||
existing = ipfs.name_resolve(cache_id)
|
||||
if existing:
|
||||
return existing # Cache hit
|
||||
```
|
||||
|
||||
### Option B: DHT record
|
||||
|
||||
```python
|
||||
# Store in DHT: cache_id → CID
|
||||
ipfs.dht_put(cache_id, output_cid)
|
||||
|
||||
# Lookup
|
||||
existing = ipfs.dht_get(cache_id)
|
||||
```
|
||||
|
||||
### Option C: Redis (minimal)
|
||||
|
||||
Keep Redis just for cache_id → CID mapping:
|
||||
|
||||
```python
|
||||
# Store
|
||||
redis.hset("artdag:cache", cache_id, output_cid)
|
||||
|
||||
# Lookup
|
||||
existing = redis.hget("artdag:cache", cache_id)
|
||||
```
|
||||
|
||||
This is simpler than current approach - one hash, one mapping, no content_hash/node_id confusion.
|
||||
|
||||
## Claiming (Preventing Duplicate Work)
|
||||
|
||||
Still need Redis for atomic claiming:
|
||||
|
||||
```python
|
||||
# Claim before executing
|
||||
claimed = redis.set(f"artdag:claim:{cache_id}", worker_id, nx=True, ex=300)
|
||||
if not claimed:
|
||||
# Another worker is doing it - wait for result
|
||||
return wait_for_result(cache_id)
|
||||
```
|
||||
|
||||
Or use IPFS pubsub for coordination.
|
||||
|
||||
## Data Flow Diagram
|
||||
|
||||
```
|
||||
┌─────────────┐
|
||||
│ Recipe │
|
||||
│ + Inputs │
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Planner │
|
||||
│ (compute │
|
||||
│ cache_ids) │
|
||||
└──────┬──────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────┐
|
||||
│ ExecutionPlan │
|
||||
│ - steps with cache_ids │
|
||||
│ - input_cids (from ipfs.add) │
|
||||
└─────────────────┬───────────────┘
|
||||
│
|
||||
┌────────────┼────────────┐
|
||||
▼ ▼ ▼
|
||||
┌────────┐ ┌────────┐ ┌────────┐
|
||||
│Worker 1│ │Worker 2│ │Worker 3│
|
||||
│ │ │ │ │ │
|
||||
│ IPFS │◄──│ IPFS │◄──│ IPFS │
|
||||
│ Node │──►│ Node │──►│ Node │
|
||||
└───┬────┘ └───┬────┘ └───┬────┘
|
||||
│ │ │
|
||||
└────────────┼────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Final CID │
|
||||
│ (output) │
|
||||
└─────────────┘
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
1. **Simpler code** - No custom cache, no dual indexes
|
||||
2. **Automatic distribution** - IPFS handles replication
|
||||
3. **Content verification** - CIDs are self-verifying
|
||||
4. **Scalable** - Add workers = add IPFS nodes = more cache capacity
|
||||
5. **Resilient** - Any node can serve any content
|
||||
|
||||
## Tradeoffs
|
||||
|
||||
1. **IPFS dependency** - Every worker needs IPFS node
|
||||
2. **Initial fetch latency** - First fetch may be slower than local disk
|
||||
3. **IPNS latency** - Name resolution can be slow (Option C avoids this)
|
||||
|
||||
## Trust Domains (Cluster Key)
|
||||
|
||||
Systems can share work through IPFS, but how do you trust them?
|
||||
|
||||
**Problem:** A malicious system could return wrong CIDs for computed steps.
|
||||
|
||||
**Solution:** Cluster key creates isolated trust domains:
|
||||
|
||||
```bash
|
||||
export ARTDAG_CLUSTER_KEY="my-secret-shared-key"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- The cluster key is mixed into all cache_id computations
|
||||
- Systems with the same key produce the same cache_ids
|
||||
- Systems with different keys have separate cache namespaces
|
||||
- Only share the key with trusted partners
|
||||
|
||||
```
|
||||
cache_id = SHA3-256(cluster_key + node_type + config + inputs)
|
||||
```
|
||||
|
||||
**Trust model:**
|
||||
| Scenario | Same Key? | Can Share Work? |
|
||||
|----------|-----------|-----------------|
|
||||
| Same organization | Yes | Yes |
|
||||
| Trusted partner | Yes (shared) | Yes |
|
||||
| Unknown system | No | No (different cache_ids) |
|
||||
|
||||
**Configuration:**
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
environment:
|
||||
- ARTDAG_CLUSTER_KEY=your-secret-key-here
|
||||
```
|
||||
|
||||
**Programmatic:**
|
||||
```python
|
||||
from artdag.planning.schema import set_cluster_key
|
||||
set_cluster_key("my-secret-key")
|
||||
```
|
||||
|
||||
## Implementation
|
||||
|
||||
The simplified architecture is implemented in `art-celery/`:
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `hybrid_state.py` | Hybrid state manager (Redis + IPNS) |
|
||||
| `tasks/execute_cid.py` | Step execution with CIDs |
|
||||
| `tasks/analyze_cid.py` | Analysis with CIDs |
|
||||
| `tasks/orchestrate_cid.py` | Full pipeline orchestration |
|
||||
|
||||
### Key Functions
|
||||
|
||||
**Registration (local → IPFS):**
|
||||
- `register_input_cid(path)` → `{cid, content_hash}`
|
||||
- `register_recipe_cid(path)` → `{cid, name, version}`
|
||||
|
||||
**Analysis:**
|
||||
- `analyze_input_cid(input_cid, input_hash, features)` → `{analysis_cid}`
|
||||
|
||||
**Planning:**
|
||||
- `generate_plan_cid(recipe_cid, input_cids, input_hashes, analysis_cids)` → `{plan_cid}`
|
||||
|
||||
**Execution:**
|
||||
- `execute_step_cid(step_json, input_cids)` → `{cid}`
|
||||
- `execute_plan_from_cid(plan_cid, input_cids)` → `{output_cid}`
|
||||
|
||||
**Full Pipeline:**
|
||||
- `run_recipe_cid(recipe_cid, input_cids, input_hashes)` → `{output_cid, all_cids}`
|
||||
- `run_from_local(recipe_path, input_paths)` → registers + runs
|
||||
|
||||
### Hybrid State Manager
|
||||
|
||||
For distributed L1 coordination, use the `HybridStateManager` which provides:
|
||||
|
||||
**Fast path (local Redis):**
|
||||
- `get_cached_cid(cache_id)` / `set_cached_cid(cache_id, cid)` - microsecond lookups
|
||||
- `try_claim(cache_id, worker_id)` / `release_claim(cache_id)` - atomic claiming
|
||||
- `get_analysis_cid()` / `set_analysis_cid()` - analysis cache
|
||||
- `get_plan_cid()` / `set_plan_cid()` - plan cache
|
||||
- `get_run_cid()` / `set_run_cid()` - run cache
|
||||
|
||||
**Slow path (background IPNS sync):**
|
||||
- Periodically syncs local state with global IPNS state (default: every 30s)
|
||||
- Pulls new entries from remote nodes
|
||||
- Pushes local updates to IPNS
|
||||
|
||||
**Configuration:**
|
||||
```bash
|
||||
# Enable IPNS sync
|
||||
export ARTDAG_IPNS_SYNC=true
|
||||
export ARTDAG_IPNS_SYNC_INTERVAL=30 # seconds
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from hybrid_state import get_state_manager
|
||||
|
||||
state = get_state_manager()
|
||||
|
||||
# Fast local lookup
|
||||
cid = state.get_cached_cid(cache_id)
|
||||
|
||||
# Fast local write (synced in background)
|
||||
state.set_cached_cid(cache_id, output_cid)
|
||||
|
||||
# Atomic claim
|
||||
if state.try_claim(cache_id, worker_id):
|
||||
# We have the lock
|
||||
...
|
||||
```
|
||||
|
||||
**Trade-offs:**
|
||||
- Local Redis: Fast (microseconds), single node
|
||||
- IPNS sync: Slow (seconds), eventually consistent across nodes
|
||||
- Duplicate work: Accepted (idempotent - same inputs → same CID)
|
||||
|
||||
### Redis Usage (minimal)
|
||||
|
||||
| Key | Type | Purpose |
|
||||
|-----|------|---------|
|
||||
| `artdag:cid_cache` | Hash | cache_id → output CID |
|
||||
| `artdag:analysis_cache` | Hash | input_hash:features → analysis CID |
|
||||
| `artdag:plan_cache` | Hash | plan_id → plan CID |
|
||||
| `artdag:run_cache` | Hash | run_id → output CID |
|
||||
| `artdag:claim:{cache_id}` | String | worker_id (TTL 5 min) |
|
||||
|
||||
## Migration Path
|
||||
|
||||
1. Keep current system working ✓
|
||||
2. Add CID-based tasks ✓
|
||||
- `execute_cid.py` ✓
|
||||
- `analyze_cid.py` ✓
|
||||
- `orchestrate_cid.py` ✓
|
||||
3. Add `--ipfs-primary` flag to CLI ✓
|
||||
4. Add hybrid state manager for L1 coordination ✓
|
||||
5. Gradually deprecate local cache code
|
||||
6. Remove old tasks when CID versions are stable
|
||||
|
||||
## See Also
|
||||
|
||||
- [L1_STORAGE.md](L1_STORAGE.md) - Current L1 architecture
|
||||
- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase model
|
||||
181
core/docs/L1_STORAGE.md
Normal file
181
core/docs/L1_STORAGE.md
Normal file
@@ -0,0 +1,181 @@
|
||||
# L1 Distributed Storage Architecture
|
||||
|
||||
This document describes how data is stored when running artdag on L1 (the distributed rendering layer).
|
||||
|
||||
## Overview
|
||||
|
||||
L1 uses four storage systems working together:
|
||||
|
||||
| System | Purpose | Data Stored |
|
||||
|--------|---------|-------------|
|
||||
| **Local Cache** | Hot storage (fast access) | Media files, plans, analysis |
|
||||
| **IPFS** | Durable content-addressed storage | All media outputs |
|
||||
| **Redis** | Coordination & indexes | Claims, mappings, run status |
|
||||
| **PostgreSQL** | Metadata & ownership | User data, provenance |
|
||||
|
||||
## Storage Flow
|
||||
|
||||
When a step executes on L1:
|
||||
|
||||
```
|
||||
1. Executor produces output file
|
||||
2. Store in local cache (fast)
|
||||
3. Compute content_hash = SHA3-256(file)
|
||||
4. Upload to IPFS → get ipfs_cid
|
||||
5. Update indexes:
|
||||
- content_hash → node_id (Redis + local)
|
||||
- content_hash → ipfs_cid (Redis + local)
|
||||
```
|
||||
|
||||
Every intermediate step output (SEGMENT, SEQUENCE, etc.) gets its own IPFS CID.
|
||||
|
||||
## Local Cache
|
||||
|
||||
Hot storage on each worker node:
|
||||
|
||||
```
|
||||
cache_dir/
|
||||
index.json # Cache metadata
|
||||
content_index.json # content_hash → node_id
|
||||
ipfs_index.json # content_hash → ipfs_cid
|
||||
plans/
|
||||
{plan_id}.json # Cached execution plans
|
||||
analysis/
|
||||
{hash}.json # Analysis results
|
||||
{node_id}/
|
||||
output.mkv # Media output
|
||||
metadata.json # CacheEntry metadata
|
||||
```
|
||||
|
||||
## IPFS - Durable Media Storage
|
||||
|
||||
All media files are stored in IPFS for durability and content-addressing.
|
||||
|
||||
**Supported pinning providers:**
|
||||
- Pinata
|
||||
- web3.storage
|
||||
- NFT.Storage
|
||||
- Infura IPFS
|
||||
- Filebase (S3-compatible)
|
||||
- Storj (decentralized)
|
||||
- Local IPFS node
|
||||
|
||||
**Configuration:**
|
||||
```bash
|
||||
IPFS_API=/ip4/127.0.0.1/tcp/5001 # Local IPFS daemon
|
||||
```
|
||||
|
||||
## Redis - Coordination
|
||||
|
||||
Redis handles distributed coordination across workers.
|
||||
|
||||
### Key Patterns
|
||||
|
||||
| Key | Type | Purpose |
|
||||
|-----|------|---------|
|
||||
| `artdag:run:{run_id}` | String | Run status, timestamps, celery task ID |
|
||||
| `artdag:content_index` | Hash | content_hash → node_id mapping |
|
||||
| `artdag:ipfs_index` | Hash | content_hash → ipfs_cid mapping |
|
||||
| `artdag:claim:{cache_id}` | String | Task claiming (prevents duplicate work) |
|
||||
|
||||
### Task Claiming
|
||||
|
||||
Lua scripts ensure atomic claiming across workers:
|
||||
|
||||
```
|
||||
Status flow: PENDING → CLAIMED → RUNNING → COMPLETED/CACHED/FAILED
|
||||
TTL: 5 minutes for claims, 1 hour for results
|
||||
```
|
||||
|
||||
This prevents two workers from executing the same step.
|
||||
|
||||
## PostgreSQL - Metadata
|
||||
|
||||
Stores ownership, provenance, and sharing metadata.
|
||||
|
||||
### Tables
|
||||
|
||||
```sql
|
||||
-- Core cache (shared)
|
||||
cache_items (content_hash, ipfs_cid, created_at)
|
||||
|
||||
-- Per-user ownership
|
||||
item_types (content_hash, actor_id, type, metadata)
|
||||
|
||||
-- Run cache (deterministic identity)
|
||||
run_cache (
|
||||
run_id, -- SHA3-256(sorted_inputs + recipe)
|
||||
output_hash,
|
||||
ipfs_cid,
|
||||
provenance_cid,
|
||||
recipe, inputs, actor_id
|
||||
)
|
||||
|
||||
-- Storage backends
|
||||
storage_backends (actor_id, provider_type, config, capacity_gb)
|
||||
|
||||
-- What's stored where
|
||||
storage_pins (content_hash, storage_id, ipfs_cid, pin_type)
|
||||
```
|
||||
|
||||
## Cache Lookup Flow
|
||||
|
||||
When a worker needs a file:
|
||||
|
||||
```
|
||||
1. Check local cache by cache_id (fastest)
|
||||
2. Check Redis content_index: content_hash → node_id
|
||||
3. Check PostgreSQL cache_items
|
||||
4. Retrieve from IPFS by CID
|
||||
5. Store in local cache for next hit
|
||||
```
|
||||
|
||||
## Local vs L1 Comparison
|
||||
|
||||
| Feature | Local Testing | L1 Distributed |
|
||||
|---------|---------------|----------------|
|
||||
| Local cache | Yes | Yes |
|
||||
| IPFS | No | Yes |
|
||||
| Redis | No | Yes |
|
||||
| PostgreSQL | No | Yes |
|
||||
| Multi-worker | No | Yes |
|
||||
| Task claiming | No | Yes (Lua scripts) |
|
||||
| Durability | Filesystem only | IPFS + PostgreSQL |
|
||||
|
||||
## Content Addressing
|
||||
|
||||
All storage uses SHA3-256 (quantum-resistant):
|
||||
|
||||
- **Files:** `content_hash = SHA3-256(file_bytes)`
|
||||
- **Computation:** `cache_id = SHA3-256(type + config + input_hashes)`
|
||||
- **Run identity:** `run_id = SHA3-256(sorted_inputs + recipe)`
|
||||
- **Plans:** `plan_id = SHA3-256(recipe + inputs + analysis)`
|
||||
|
||||
This ensures:
|
||||
- Same inputs → same outputs (reproducibility)
|
||||
- Automatic deduplication across workers
|
||||
- Content verification (tamper detection)
|
||||
|
||||
## Configuration
|
||||
|
||||
Default locations:
|
||||
|
||||
```bash
|
||||
# Local cache
|
||||
~/.artdag/cache # Default
|
||||
/data/cache # Docker
|
||||
|
||||
# Redis
|
||||
redis://localhost:6379/5
|
||||
|
||||
# PostgreSQL
|
||||
postgresql://user:pass@host/artdag
|
||||
|
||||
# IPFS
|
||||
/ip4/127.0.0.1/tcp/5001
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [OFFLINE_TESTING.md](OFFLINE_TESTING.md) - Local testing without L1
|
||||
- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model
|
||||
211
core/docs/OFFLINE_TESTING.md
Normal file
211
core/docs/OFFLINE_TESTING.md
Normal file
@@ -0,0 +1,211 @@
|
||||
# Offline Testing Strategy
|
||||
|
||||
This document describes how to test artdag locally without requiring Redis, IPFS, Celery, or any external distributed infrastructure.
|
||||
|
||||
## Overview
|
||||
|
||||
The artdag system uses a **3-Phase Execution Model** that enables complete offline testing:
|
||||
|
||||
1. **Analysis** - Extract features from input media
|
||||
2. **Planning** - Generate deterministic execution plan with pre-computed cache IDs
|
||||
3. **Execution** - Run plan steps, skipping cached results
|
||||
|
||||
This separation allows testing each phase independently and running full pipelines locally.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run a full offline test with a video file:
|
||||
|
||||
```bash
|
||||
./examples/test_local.sh ../artdag-art-source/dog.mkv
|
||||
```
|
||||
|
||||
This will:
|
||||
1. Compute the SHA3-256 hash of the input video
|
||||
2. Run the `simple_sequence` recipe
|
||||
3. Store all outputs in `test_cache/`
|
||||
|
||||
## Test Scripts
|
||||
|
||||
### `test_local.sh` - Full Pipeline Test
|
||||
|
||||
Location: `./examples/test_local.sh`
|
||||
|
||||
Runs the complete artdag pipeline offline with a real video file.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
./examples/test_local.sh <video_file>
|
||||
```
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
./examples/test_local.sh ../artdag-art-source/dog.mkv
|
||||
```
|
||||
|
||||
**What it does:**
|
||||
- Computes content hash of input video
|
||||
- Runs `artdag run-recipe` with `simple_sequence.yaml`
|
||||
- Stores outputs in `test_cache/` directory
|
||||
- No external services required
|
||||
|
||||
### `test_plan.py` - Planning Phase Test
|
||||
|
||||
Location: `./examples/test_plan.py`
|
||||
|
||||
Tests the planning phase without requiring any media files.
|
||||
|
||||
**Usage:**
|
||||
```bash
|
||||
python3 examples/test_plan.py
|
||||
```
|
||||
|
||||
**What it tests:**
|
||||
- Recipe loading and YAML parsing
|
||||
- Execution plan generation
|
||||
- Cache ID computation (deterministic)
|
||||
- Multi-level parallel step organization
|
||||
- Human-readable step names
|
||||
- Multi-output support
|
||||
|
||||
**Output:**
|
||||
- Prints plan structure to console
|
||||
- Saves full plan to `test_plan_output.json`
|
||||
|
||||
### `simple_sequence.yaml` - Sample Recipe
|
||||
|
||||
Location: `./examples/simple_sequence.yaml`
|
||||
|
||||
A simple recipe for testing that:
|
||||
- Takes a video input
|
||||
- Extracts two segments (0-2s and 5-7s)
|
||||
- Concatenates them with SEQUENCE
|
||||
|
||||
## Test Outputs
|
||||
|
||||
All test outputs are stored locally and git-ignored:
|
||||
|
||||
| Output | Description |
|
||||
|--------|-------------|
|
||||
| `test_cache/` | Cached execution results (media files, analysis, plans) |
|
||||
| `test_cache/plans/` | Cached execution plans by plan_id |
|
||||
| `test_cache/analysis/` | Cached analysis results by input hash |
|
||||
| `test_plan_output.json` | Generated execution plan from `test_plan.py` |
|
||||
|
||||
## Unit Tests
|
||||
|
||||
The project includes a comprehensive pytest test suite in `tests/`:
|
||||
|
||||
```bash
|
||||
# Run all unit tests
|
||||
pytest
|
||||
|
||||
# Run specific test file
|
||||
pytest tests/test_dag.py
|
||||
pytest tests/test_engine.py
|
||||
pytest tests/test_cache.py
|
||||
```
|
||||
|
||||
## Testing Each Phase
|
||||
|
||||
### Phase 1: Analysis Only
|
||||
|
||||
Extract features without full execution:
|
||||
|
||||
```bash
|
||||
python3 -m artdag.cli analyze <recipe> -i <name>:<hash>@<path> --features beats,energy
|
||||
```
|
||||
|
||||
### Phase 2: Planning Only
|
||||
|
||||
Generate an execution plan (no media needed):
|
||||
|
||||
```bash
|
||||
python3 -m artdag.cli plan <recipe> -i <name>:<hash>
|
||||
```
|
||||
|
||||
Or use the test script:
|
||||
|
||||
```bash
|
||||
python3 examples/test_plan.py
|
||||
```
|
||||
|
||||
### Phase 3: Execution Only
|
||||
|
||||
Execute a pre-generated plan:
|
||||
|
||||
```bash
|
||||
python3 -m artdag.cli execute plan.json
|
||||
```
|
||||
|
||||
With dry-run to see what would execute:
|
||||
|
||||
```bash
|
||||
python3 -m artdag.cli execute plan.json --dry-run
|
||||
```
|
||||
|
||||
## Key Testing Features
|
||||
|
||||
### Content Addressing
|
||||
|
||||
All nodes have deterministic IDs computed as:
|
||||
```
|
||||
SHA3-256(type + config + sorted(input_IDs))
|
||||
```
|
||||
|
||||
Same inputs always produce same cache IDs, enabling:
|
||||
- Reproducibility across runs
|
||||
- Automatic deduplication
|
||||
- Incremental execution (only changed steps run)
|
||||
|
||||
### Local Caching
|
||||
|
||||
The `test_cache/` directory stores:
|
||||
- `plans/{plan_id}.json` - Execution plans (deterministic hash of recipe + inputs + analysis)
|
||||
- `analysis/{hash}.json` - Analysis results (audio beats, tempo, energy)
|
||||
- `{cache_id}/output.mkv` - Media outputs from each step
|
||||
|
||||
Subsequent test runs automatically skip cached steps. Plans are cached by their `plan_id`, which is a SHA3-256 hash of the recipe, input hashes, and analysis results - so the same recipe with the same inputs always produces the same plan.
|
||||
|
||||
### No External Dependencies
|
||||
|
||||
Offline testing requires:
|
||||
- Python 3.9+
|
||||
- ffmpeg (for media processing)
|
||||
- No Redis, IPFS, Celery, or network access
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
1. **Check cache contents:**
|
||||
```bash
|
||||
ls -la test_cache/
|
||||
ls -la test_cache/plans/
|
||||
```
|
||||
|
||||
2. **View cached plan:**
|
||||
```bash
|
||||
cat test_cache/plans/*.json | python3 -m json.tool | head -50
|
||||
```
|
||||
|
||||
3. **View execution plan structure:**
|
||||
```bash
|
||||
cat test_plan_output.json | python3 -m json.tool
|
||||
```
|
||||
|
||||
4. **Run with verbose output:**
|
||||
```bash
|
||||
python3 -m artdag.cli run-recipe examples/simple_sequence.yaml \
|
||||
-i "video:HASH@path" \
|
||||
--cache-dir test_cache \
|
||||
-v
|
||||
```
|
||||
|
||||
5. **Dry-run to see what would execute:**
|
||||
```bash
|
||||
python3 -m artdag.cli execute plan.json --dry-run
|
||||
```
|
||||
|
||||
## See Also
|
||||
|
||||
- [L1_STORAGE.md](L1_STORAGE.md) - Distributed storage on L1 (IPFS, Redis, PostgreSQL)
|
||||
- [EXECUTION_MODEL.md](EXECUTION_MODEL.md) - 3-phase execution model
|
||||
35
core/effects/identity/README.md
Normal file
35
core/effects/identity/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# Identity Effect
|
||||
|
||||
The identity effect returns its input unchanged. It serves as the foundational primitive in the effects registry.
|
||||
|
||||
## Purpose
|
||||
|
||||
- **Testing**: Verify the effects pipeline is working correctly
|
||||
- **No-op placeholder**: Use when an effect slot requires a value but no transformation is needed
|
||||
- **Composition base**: The neutral element for effect composition
|
||||
|
||||
## Signature
|
||||
|
||||
```
|
||||
identity(input) → input
|
||||
```
|
||||
|
||||
## Properties
|
||||
|
||||
- **Idempotent**: `identity(identity(x)) = identity(x)`
|
||||
- **Neutral**: For any effect `f`, `identity ∘ f = f ∘ identity = f`
|
||||
|
||||
## Implementation
|
||||
|
||||
```python
|
||||
def identity(input):
|
||||
return input
|
||||
```
|
||||
|
||||
## Content Hash
|
||||
|
||||
The identity effect is content-addressed by its behavior: given any input, the output hash equals the input hash.
|
||||
|
||||
## Owner
|
||||
|
||||
Registered by `@giles@artdag.rose-ash.com`
|
||||
2
core/effects/identity/requirements.txt
Normal file
2
core/effects/identity/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
# Identity effect has no dependencies
|
||||
# It's a pure function: identity(x) = x
|
||||
42
core/examples/simple_sequence.yaml
Normal file
42
core/examples/simple_sequence.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Simple sequence recipe - concatenates segments from a single input video
|
||||
name: simple_sequence
|
||||
version: "1.0"
|
||||
description: "Split input into segments and concatenate them"
|
||||
owner: test@local
|
||||
|
||||
dag:
|
||||
nodes:
|
||||
# Input source - variable (provided at runtime)
|
||||
- id: video
|
||||
type: SOURCE
|
||||
config:
|
||||
input: true
|
||||
name: "Input Video"
|
||||
description: "The video to process"
|
||||
|
||||
# Extract first 2 seconds
|
||||
- id: seg1
|
||||
type: SEGMENT
|
||||
config:
|
||||
start: 0.0
|
||||
end: 2.0
|
||||
inputs:
|
||||
- video
|
||||
|
||||
# Extract seconds 5-7
|
||||
- id: seg2
|
||||
type: SEGMENT
|
||||
config:
|
||||
start: 5.0
|
||||
end: 7.0
|
||||
inputs:
|
||||
- video
|
||||
|
||||
# Concatenate the segments
|
||||
- id: output
|
||||
type: SEQUENCE
|
||||
inputs:
|
||||
- seg1
|
||||
- seg2
|
||||
|
||||
output: output
|
||||
54
core/examples/test_local.sh
Executable file
54
core/examples/test_local.sh
Executable file
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash
# Local testing script for artdag.
# Runs the 3-phase execution pipeline end to end without Redis/IPFS.
#
# Usage: ./test_local.sh <video_file>

set -e

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
ARTDAG_DIR="$(dirname "$SCRIPT_DIR")"
CACHE_DIR="${ARTDAG_DIR}/test_cache"
RECIPE="${SCRIPT_DIR}/simple_sequence.yaml"

# Check for input video
if [ -z "$1" ]; then
    echo "Usage: $0 <video_file>"
    echo ""
    echo "Example:"
    echo "  $0 /path/to/test_video.mp4"
    exit 1
fi

VIDEO_PATH="$1"
if [ ! -f "$VIDEO_PATH" ]; then
    echo "Error: Video file not found: $VIDEO_PATH"
    exit 1
fi

# Compute content hash of input.
# The path is passed to Python as an argument (not interpolated into the
# source text), so paths containing quotes or other special characters
# work correctly, and the file is hashed in chunks instead of being read
# fully into memory.
echo "=== Computing input hash ==="
VIDEO_HASH=$(python3 - "$VIDEO_PATH" <<'PYEOF'
import hashlib
import sys

hasher = hashlib.sha3_256()
with open(sys.argv[1], "rb") as f:
    for chunk in iter(lambda: f.read(65536), b""):
        hasher.update(chunk)
print(hasher.hexdigest())
PYEOF
)
echo "Input hash: ${VIDEO_HASH:0:16}..."

# Change to artdag directory
cd "$ARTDAG_DIR"

# Run the full pipeline
echo ""
echo "=== Running artdag run-recipe ==="
echo "Recipe: $RECIPE"
echo "Input: video:${VIDEO_HASH:0:16}...@$VIDEO_PATH"
echo "Cache: $CACHE_DIR"
echo ""

python3 -m artdag.cli run-recipe "$RECIPE" \
    -i "video:${VIDEO_HASH}@${VIDEO_PATH}" \
    --cache-dir "$CACHE_DIR"

echo ""
echo "=== Done ==="
echo "Cache directory: $CACHE_DIR"
echo "Use 'ls -la $CACHE_DIR' to see cached outputs"
|
||||
93
core/examples/test_plan.py
Executable file
93
core/examples/test_plan.py
Executable file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test the planning phase locally.
|
||||
|
||||
This tests the new human-readable names and multi-output support
|
||||
without requiring actual video files or execution.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add artdag to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from artdag.planning import RecipePlanner, Recipe, ExecutionPlan
|
||||
|
||||
|
||||
def main():
    """Exercise the planning phase against the sample recipe and dump the plan."""
    recipe_path = Path(__file__).parent / "simple_sequence.yaml"
    if not recipe_path.exists():
        print(f"Recipe not found: {recipe_path}")
        return 1

    recipe = Recipe.from_file(recipe_path)
    print(f"Recipe: {recipe.name} v{recipe.version}")
    print(f"Nodes: {len(recipe.nodes)}")
    print()

    # A synthetic content hash stands in for a real video file.
    fake_input_hash = hashlib.sha3_256(b"fake video content").hexdigest()
    print(f"Input: video -> {fake_input_hash[:16]}...")
    print()

    # Generate the deterministic execution plan.
    planner = RecipePlanner(use_tree_reduction=True)
    plan = planner.plan(
        recipe=recipe,
        input_hashes={"video": fake_input_hash},
        seed=42,  # fixed seed keeps the plan reproducible
    )

    print("=== Generated Plan ===")
    print(f"Plan ID: {plan.plan_id[:24]}...")
    print(f"Plan Name: {plan.name}")
    print(f"Recipe Name: {plan.recipe_name}")
    print(f"Output: {plan.output_name}")
    print(f"Steps: {len(plan.steps)}")
    print()

    # Walk the plan level by level; steps within a level can run in parallel.
    by_level = plan.get_steps_by_level()
    for level in sorted(by_level):
        level_steps = by_level[level]
        print(f"Level {level}: {len(level_steps)} step(s)")
        for step in level_steps:
            # Prefer the human-readable name, fall back to a truncated ID.
            label = step.name or step.step_id[:20]
            print(f"  - {label}")
            print(f"    Type: {step.node_type}")
            print(f"    Cache ID: {step.cache_id[:16]}...")
            if step.outputs:
                print(f"    Outputs: {len(step.outputs)}")
                for out in step.outputs:
                    print(f"      - {out.name} ({out.media_type})")
            if step.inputs:
                print(f"    Inputs: {[inp.name for inp in step.inputs]}")
        print()

    # Persist the full plan for manual inspection.
    plan_path = Path(__file__).parent.parent / "test_plan_output.json"
    with open(plan_path, "w") as f:
        f.write(plan.to_json())
    print(f"Plan saved to: {plan_path}")

    # Preview the serialized structure of the first step.
    print()
    print("=== Plan JSON Preview ===")
    plan_dict = json.loads(plan.to_json())
    if plan_dict.get("steps"):
        print(json.dumps(plan_dict["steps"][0], indent=2)[:500] + "...")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
62
core/pyproject.toml
Normal file
62
core/pyproject.toml
Normal file
@@ -0,0 +1,62 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "artdag"
|
||||
version = "0.1.0"
|
||||
description = "Content-addressed DAG execution engine with ActivityPub ownership"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
requires-python = ">=3.10"
|
||||
authors = [
|
||||
{name = "Giles", email = "giles@rose-ash.com"}
|
||||
]
|
||||
keywords = ["dag", "content-addressed", "activitypub", "video", "processing"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = [
|
||||
"cryptography>=41.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
]
|
||||
analysis = [
|
||||
"librosa>=0.10.0",
|
||||
"numpy>=1.24.0",
|
||||
"pyyaml>=6.0",
|
||||
]
|
||||
cv = [
|
||||
"opencv-python>=4.8.0",
|
||||
]
|
||||
all = [
|
||||
"librosa>=0.10.0",
|
||||
"numpy>=1.24.0",
|
||||
"pyyaml>=6.0",
|
||||
"opencv-python>=4.8.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
artdag = "artdag.cli:main"
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://artdag.rose-ash.com"
|
||||
Repository = "https://github.com/giles/artdag"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["artdag*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
67
core/scripts/compute_repo_hash.py
Normal file
67
core/scripts/compute_repo_hash.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Compute content hash of a git repository.
|
||||
|
||||
Hashes all tracked files (respects .gitignore) in sorted order.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def repo_hash(repo_path: Path) -> str:
    """Return the SHA3-256 digest of every git-tracked file under *repo_path*.

    The file list comes from ``git ls-files`` (so .gitignore is honoured),
    is processed in sorted order for determinism, and each file feeds its
    relative path followed by its raw contents into the digest.
    """
    listing = subprocess.run(
        ["git", "ls-files"],
        cwd=repo_path,
        capture_output=True,
        text=True,
        check=True,
    )

    digest = hashlib.sha3_256()

    for tracked in sorted(listing.stdout.strip().split("\n")):
        if not tracked:
            continue  # an empty repo yields a single empty line

        path = repo_path / tracked
        if not path.is_file():
            continue  # tracked but absent on disk (deleted, submodule, etc.)

        # The relative path participates in the hash so renames change it.
        digest.update(tracked.encode())

        with open(path, "rb") as fh:
            while chunk := fh.read(65536):
                digest.update(chunk)

    return digest.hexdigest()
|
||||
|
||||
|
||||
def main():
    """CLI entry point: hash the repo given as argv[1] (default: cwd)."""
    target = Path(sys.argv[1]) if len(sys.argv) > 1 else Path.cwd()
    digest = repo_hash(target)
    print(f"Repository: {target}")
    print(f"Hash: {digest}")
    return digest
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
core/scripts/install-ffglitch.sh
Executable file
82
core/scripts/install-ffglitch.sh
Executable file
@@ -0,0 +1,82 @@
|
||||
#!/bin/bash
# Install ffglitch for datamosh effects.
# Usage: ./install-ffglitch.sh [install_dir]

set -e

FFGLITCH_VERSION="0.10.2"
INSTALL_DIR="${1:-/usr/local/bin}"

# Detect architecture
ARCH=$(uname -m)
case "$ARCH" in
    x86_64)
        URL="https://ffglitch.org/pub/bin/linux64/ffglitch-${FFGLITCH_VERSION}-linux-x86_64.zip"
        ARCHIVE="ffglitch.zip"
        ;;
    aarch64)
        URL="https://ffglitch.org/pub/bin/linux-aarch64/ffglitch-${FFGLITCH_VERSION}-linux-aarch64.7z"
        ARCHIVE="ffglitch.7z"
        ;;
    *)
        echo "Unsupported architecture: $ARCH"
        exit 1
        ;;
esac

echo "Installing ffglitch ${FFGLITCH_VERSION} for ${ARCH}..."

# Create temp directory; the EXIT trap guarantees cleanup even when a
# later step fails (previously the directory leaked on any error).
TMPDIR=$(mktemp -d)
trap 'cd /; rm -rf "$TMPDIR"' EXIT
cd "$TMPDIR"

# Download
echo "Downloading from ${URL}..."
curl -L -o "$ARCHIVE" "$URL"

# Extract
echo "Extracting..."
if [[ "$ARCHIVE" == *.zip ]]; then
    unzip -q "$ARCHIVE"
elif [[ "$ARCHIVE" == *.7z ]]; then
    # Requires p7zip
    if ! command -v 7z &> /dev/null; then
        echo "7z not found. Install with: apt install p7zip-full"
        exit 1
    fi
    7z x "$ARCHIVE" > /dev/null
fi

# Find and install binaries. The name predicates are explicitly grouped
# so the implicit -print applies to both alternatives, and read -r plus
# quoting keeps paths with spaces or backslashes intact.
echo "Installing to ${INSTALL_DIR}..."
find . \( -name "ffgac" -o -name "ffedit" \) | while read -r bin; do
    chmod +x "$bin"
    if [ -w "$INSTALL_DIR" ]; then
        cp "$bin" "$INSTALL_DIR/"
    else
        sudo cp "$bin" "$INSTALL_DIR/"
    fi
    echo "  Installed: $(basename "$bin")"
done

# Verify
echo ""
echo "Verifying installation..."
if command -v ffgac &> /dev/null; then
    echo "ffgac: $(which ffgac)"
else
    echo "Warning: ffgac not in PATH. Add ${INSTALL_DIR} to PATH."
fi

if command -v ffedit &> /dev/null; then
    echo "ffedit: $(which ffedit)"
else
    echo "Warning: ffedit not in PATH. Add ${INSTALL_DIR} to PATH."
fi

echo ""
echo "Done! ffglitch installed."
|
||||
83
core/scripts/register_identity_effect.py
Normal file
83
core/scripts/register_identity_effect.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Register the identity effect owned by giles.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# Add parent to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from artdag.activitypub.ownership import OwnershipManager
|
||||
|
||||
|
||||
def folder_hash(folder: Path) -> str:
    """Return the SHA3-256 digest of an entire directory tree.

    Files are visited in sorted order so the result is deterministic;
    each one contributes its path relative to *folder* followed by its
    raw bytes, which means renames as well as edits change the digest.
    """
    digest = hashlib.sha3_256()

    for entry in sorted(folder.rglob("*")):
        if not entry.is_file():
            continue  # directories contribute nothing by themselves

        # Fold the relative path in first so tree structure is hashed too.
        digest.update(str(entry.relative_to(folder)).encode())

        with open(entry, "rb") as fh:
            while chunk := fh.read(65536):
                digest.update(chunk)

    return digest.hexdigest()
|
||||
|
||||
|
||||
def main():
    """Register the bundled identity effect under the 'giles' actor."""
    # Ownership records live under .cache/ownership next to the package.
    repo_root = Path(__file__).parent.parent
    manager = OwnershipManager(repo_root / ".cache" / "ownership")

    # Reuse the actor when it already exists, otherwise create it.
    actor = manager.get_actor("giles")
    if actor:
        print(f"Using existing actor: {actor.handle}")
    else:
        actor = manager.create_actor("giles", "Giles Bradshaw")
        print(f"Created actor: {actor.handle}")

    # Content-address the identity effect folder and register the claim.
    effect_path = repo_root / "effects" / "identity"
    asset, activity = manager.register_asset(
        actor=actor,
        name="effect:identity",
        cid=folder_hash(effect_path),
        local_path=effect_path,
        tags=["effect", "primitive", "identity"],
        metadata={
            "type": "effect",
            "description": "The identity effect - returns input unchanged",
            "signature": "identity(input) → input",
        },
    )

    print(f"\nRegistered: {asset.name}")
    print(f"  Hash: {asset.cid}")
    print(f"  Path: {asset.local_path}")
    print(f"  Activity: {activity.activity_id}")
    print(f"  Owner: {actor.handle}")

    # Sanity-check that the signed ownership claim verifies.
    verified = manager.verify_ownership(asset.name, actor)
    print(f"  Ownership verified: {verified}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
120
core/scripts/setup_actor.py
Normal file
120
core/scripts/setup_actor.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Set up actor with keypair stored securely.
|
||||
|
||||
Private key: ~/.artdag/keys/{username}.pem
|
||||
Public key: exported for registry
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
# Add artdag to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
|
||||
def create_keypair():
    """Generate and return a fresh RSA-2048 private key."""
    # 65537 (F4) is the conventional RSA public exponent.
    return rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend(),
    )
|
||||
|
||||
|
||||
def save_private_key(private_key, path: Path):
    """Serialize *private_key* to an unencrypted PKCS#8 PEM file at *path*.

    Creates parent directories as needed, restricts the file to mode 600,
    and returns the PEM text.
    """
    path.parent.mkdir(parents=True, exist_ok=True)

    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    path.write_bytes(pem_bytes)

    # Private key material: owner read/write only.
    os.chmod(path, 0o600)

    return pem_bytes.decode()
|
||||
|
||||
|
||||
def get_public_key_pem(private_key) -> str:
    """Return the public half of *private_key* as a SubjectPublicKeyInfo PEM string."""
    encoded = private_key.public_key().public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return encoded.decode()
|
||||
|
||||
|
||||
def create_actor_json(username: str, display_name: str, public_key_pem: str, domain: str = "artdag.rose-ash.com"):
    """Build the ActivityPub actor document for *username* at *domain*.

    Returns a plain dict ready to be serialized as the actor's JSON,
    including inbox/outbox endpoints and the embedded public key.
    """
    actor_id = f"https://{domain}/users/{username}"
    return {
        "@context": [
            "https://www.w3.org/ns/activitystreams",
            "https://w3id.org/security/v1"
        ],
        "type": "Person",
        "id": actor_id,
        "preferredUsername": username,
        "name": display_name,
        "inbox": f"{actor_id}/inbox",
        "outbox": f"{actor_id}/outbox",
        "publicKey": {
            # Key ID fragment convention used by ActivityPub HTTP signatures.
            "id": f"{actor_id}#main-key",
            "owner": actor_id,
            "publicKeyPem": public_key_pem
        }
    }
|
||||
|
||||
|
||||
def main():
    """Create the 'giles' keypair, print the actor JSON, and save it to the registry."""
    username = "giles"
    display_name = "Giles Bradshaw"
    domain = "artdag.rose-ash.com"

    key_file = Path.home() / ".artdag" / "keys" / f"{username}.pem"

    # Refuse to overwrite an existing private key.
    if key_file.exists():
        print(f"Private key already exists: {key_file}")
        print("Delete it first if you want to regenerate.")
        sys.exit(1)

    print(f"Creating new keypair for @{username}@{domain}...")
    private_key = create_keypair()

    # Persist the private half with restrictive permissions.
    save_private_key(private_key, key_file)
    print(f"Private key saved: {key_file}")
    print("  Mode: 600 (owner read/write only)")
    print("  BACK THIS UP!")

    # Build the public actor document.
    actor = create_actor_json(
        username, display_name, get_public_key_pem(private_key), domain
    )
    serialized = json.dumps(actor, indent=2)
    print(f"\nActor JSON (for registry/actors/{username}.json):")
    print(serialized)

    # Mirror the actor document into the local registry checkout.
    registry_file = Path.home() / "artdag-registry" / "actors" / f"{username}.json"
    registry_file.parent.mkdir(parents=True, exist_ok=True)
    registry_file.write_text(serialized)
    print(f"\nSaved to: {registry_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
143
core/scripts/sign_assets.py
Normal file
143
core/scripts/sign_assets.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Sign assets in the registry with giles's private key.
|
||||
|
||||
Creates ActivityPub Create activities with RSA signatures.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
||||
|
||||
def load_private_key(path: Path):
    """Read a PEM-encoded, unencrypted private key from *path*."""
    return serialization.load_pem_private_key(
        path.read_bytes(),
        password=None,  # keys written by setup_actor.py are not passphrase-protected
        backend=default_backend(),
    )
||||
|
||||
def sign_data(private_key, data: str) -> str:
    """Return a base64-encoded RSA PKCS#1 v1.5 / SHA-256 signature over *data*."""
    raw_signature = private_key.sign(data.encode(), padding.PKCS1v15(), hashes.SHA256())
    return base64.b64encode(raw_signature).decode()
||||
|
||||
def create_activity(actor_id: str, asset_name: str, cid: str, asset_type: str, domain: str = "artdag.rose-ash.com"):
    """Build an (unsigned) ActivityPub Create activity describing one asset.

    The object carries the asset's content hash (sha3-256) so the claim is
    bound to the exact bytes, not just the name.
    """
    published_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

    object_data = {
        "type": asset_type_to_ap(asset_type),
        "name": asset_name,
        "id": f"https://{domain}/objects/{cid}",
        "contentHash": {
            "algorithm": "sha3-256",
            "value": cid
        },
        "attributedTo": actor_id
    }

    return {
        "activity_id": str(uuid.uuid4()),  # random per-activity id; not content-addressed
        "activity_type": "Create",
        "actor_id": actor_id,
        "object_data": object_data,
        "published": published_at,
    }
||||
|
||||
def asset_type_to_ap(asset_type: str) -> str:
    """Map an internal asset type onto its ActivityPub object type.

    Unknown types fall back to the generic "Document".
    """
    if asset_type == "image":
        return "Image"
    if asset_type == "video":
        return "Video"
    if asset_type == "audio":
        return "Audio"
    if asset_type in ("effect", "infrastructure"):
        # Both effect plugins and infra tooling publish as Application.
        return "Application"
    return "Document"
|
||||
|
||||
def sign_activity(activity: dict, private_key, actor_id: str, domain: str = "artdag.rose-ash.com") -> dict:
    """Attach an RsaSignature2017-style signature to *activity* in place.

    Only ``object_data`` is signed, serialized canonically (sorted keys,
    no whitespace) so verification is deterministic. Returns the mutated
    activity for convenience.

    NOTE(review): the *domain* parameter is currently unused here; kept
    for call-site compatibility.
    """
    # Canonical form of the signed payload.
    canonical_object = json.dumps(activity["object_data"], sort_keys=True, separators=(",", ":"))
    sig_b64 = sign_data(private_key, canonical_object)

    activity["signature"] = {
        "type": "RsaSignature2017",
        "creator": f"{actor_id}#main-key",
        "created": activity["published"],
        "signatureValue": sig_b64
    }
    return activity
|
||||
|
||||
def main() -> None:
    """Sign every asset in the local registry with giles's private key.

    Reads ``~/artdag-registry/registry.json``, produces one signed Create
    activity per asset, and writes them all to
    ``~/artdag-registry/activities.json``. Exits with status 1 if the
    private key has not been generated yet.
    """
    username = "giles"
    domain = "artdag.rose-ash.com"
    actor_id = f"https://{domain}/users/{username}"

    # Load private key (created by setup_actor.py).
    private_key_path = Path.home() / ".artdag" / "keys" / f"{username}.pem"
    if not private_key_path.exists():
        print(f"Private key not found: {private_key_path}")
        print("Run setup_actor.py first.")
        sys.exit(1)

    private_key = load_private_key(private_key_path)
    print(f"Loaded private key: {private_key_path}")

    # Load registry — assumes a top-level "assets" mapping of
    # name -> {"cid": ..., "asset_type": ...}.
    registry_path = Path.home() / "artdag-registry" / "registry.json"
    with open(registry_path) as f:
        registry = json.load(f)

    # Create signed activities for each asset
    activities = []

    for asset_name, asset_data in registry["assets"].items():
        print(f"\nSigning: {asset_name}")
        print(f" Hash: {asset_data['cid'][:16]}...")

        activity = create_activity(
            actor_id=actor_id,
            asset_name=asset_name,
            cid=asset_data["cid"],
            asset_type=asset_data["asset_type"],
            domain=domain,
        )

        # sign_activity mutates and returns the same dict.
        signed_activity = sign_activity(activity, private_key, actor_id, domain)
        activities.append(signed_activity)

        print(f" Activity ID: {signed_activity['activity_id']}")
        print(f" Signature: {signed_activity['signature']['signatureValue'][:32]}...")

    # Save all signed activities in one versioned document.
    activities_path = Path.home() / "artdag-registry" / "activities.json"
    activities_data = {
        "version": "1.0",
        "activities": activities
    }

    with open(activities_path, "w") as f:
        json.dump(activities_data, f, indent=2)

    print(f"\nSaved {len(activities)} signed activities to: {activities_path}")
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
core/tests/__init__.py
Normal file
1
core/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for new standalone primitive engine
|
||||
613
core/tests/test_activities.py
Normal file
613
core/tests/test_activities.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# tests/test_activities.py
|
||||
"""Tests for the activity tracking and cache deletion system."""
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from artdag import Cache, DAG, Node, NodeType
|
||||
from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn
|
||||
|
||||
|
||||
class MockActivityPubStore:
    """In-memory stand-in for an ActivityPub store, used to drive is_shared checks."""

    def __init__(self):
        # Content hashes that have been "published" to the fake network.
        self._shared_hashes = set()

    def mark_shared(self, cid: str):
        """Record *cid* as published."""
        self._shared_hashes.add(cid)

    def find_by_object_hash(self, cid: str):
        """Return one fake Create activity when *cid* is published, else nothing."""
        return [MockActivity("Create")] if cid in self._shared_hashes else []
|
||||
|
||||
class MockActivity:
    """Minimal fake ActivityPub activity carrying only its type string."""

    def __init__(self, activity_type: str):
        self.activity_type = activity_type
|
||||
|
||||
@pytest.fixture
def temp_dir():
    """Create a temporary directory for tests.

    Yields the directory as a Path; the directory and its contents are
    removed automatically when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        yield Path(tmpdir)
||||
|
||||
|
||||
@pytest.fixture
def cache(temp_dir):
    """Create a Cache instance rooted inside the per-test temp directory."""
    return Cache(temp_dir / "cache")
||||
|
||||
|
||||
@pytest.fixture
def activity_store(temp_dir):
    """Create an ActivityStore instance backed by the per-test temp directory."""
    return ActivityStore(temp_dir / "activities")
||||
|
||||
|
||||
@pytest.fixture
def ap_store():
    """Create a mock ActivityPub store (see MockActivityPubStore above)."""
    return MockActivityPubStore()
||||
|
||||
|
||||
@pytest.fixture
def manager(cache, activity_store, ap_store):
    """Create an ActivityManager wired to the test cache, store, and mock
    ActivityPub backend (via make_is_shared_fn)."""
    return ActivityManager(
        cache=cache,
        activity_store=activity_store,
        is_shared_fn=make_is_shared_fn(ap_store),
    )
||||
|
||||
|
||||
def create_test_file(path: Path, content: str = "test content") -> Path:
    """Write *content* to *path*, creating parent directories, and return the path."""
    parent = path.parent
    parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content)
    return path
|
||||
|
||||
class TestCacheEntryContentHash:
    """Tests for cid (content id) in CacheEntry.

    Covers hash computation on put(), content-addressing properties
    (equal content => equal cid, distinct content => distinct cid),
    lookup by cid, and persistence across cache reloads.
    """

    def test_put_computes_cid(self, cache, temp_dir):
        """put() should compute and store cid."""
        test_file = create_test_file(temp_dir / "input.txt", "hello world")

        cache.put("node1", test_file, "test")
        entry = cache.get_entry("node1")

        assert entry is not None
        assert entry.cid != ""
        assert len(entry.cid) == 64  # SHA-3-256 hex

    def test_same_content_same_hash(self, cache, temp_dir):
        """Same file content should produce same hash."""
        file1 = create_test_file(temp_dir / "file1.txt", "identical content")
        file2 = create_test_file(temp_dir / "file2.txt", "identical content")

        cache.put("node1", file1, "test")
        cache.put("node2", file2, "test")

        entry1 = cache.get_entry("node1")
        entry2 = cache.get_entry("node2")

        assert entry1.cid == entry2.cid

    def test_different_content_different_hash(self, cache, temp_dir):
        """Different file content should produce different hash."""
        file1 = create_test_file(temp_dir / "file1.txt", "content A")
        file2 = create_test_file(temp_dir / "file2.txt", "content B")

        cache.put("node1", file1, "test")
        cache.put("node2", file2, "test")

        entry1 = cache.get_entry("node1")
        entry2 = cache.get_entry("node2")

        assert entry1.cid != entry2.cid

    def test_find_by_cid(self, cache, temp_dir):
        """Should find entry by content hash."""
        test_file = create_test_file(temp_dir / "input.txt", "unique content")
        cache.put("node1", test_file, "test")

        entry = cache.get_entry("node1")
        found = cache.find_by_cid(entry.cid)

        assert found is not None
        assert found.node_id == "node1"

    def test_cid_persists(self, temp_dir):
        """cid should persist across cache reloads."""
        cache1 = Cache(temp_dir / "cache")
        test_file = create_test_file(temp_dir / "input.txt", "persistent")
        cache1.put("node1", test_file, "test")
        original_hash = cache1.get_entry("node1").cid

        # Create new cache instance (reload from disk) — the cid must
        # come back from the persisted index, not be recomputed blank.
        cache2 = Cache(temp_dir / "cache")
        entry = cache2.get_entry("node1")

        assert entry.cid == original_hash
|
||||
|
||||
class TestActivity:
    """Tests for Activity dataclass.

    Verifies node classification from a DAG (sources -> input_ids,
    output -> output_id, everything else -> intermediate_ids),
    round-trip serialization, and the all_node_ids aggregate.
    """

    def test_activity_from_dag(self):
        """Activity.from_dag() should classify nodes correctly."""
        # Build a simple DAG: source -> transform -> output
        dag = DAG()
        source = Node(NodeType.SOURCE, {"path": "/test.mp4"})
        transform = Node(NodeType.TRANSFORM, {"effect": "blur"}, inputs=[source.node_id])
        output = Node(NodeType.RESIZE, {"width": 100}, inputs=[transform.node_id])

        dag.add_node(source)
        dag.add_node(transform)
        dag.add_node(output)
        dag.set_output(output.node_id)

        activity = Activity.from_dag(dag)

        assert source.node_id in activity.input_ids
        assert activity.output_id == output.node_id
        assert transform.node_id in activity.intermediate_ids

    def test_activity_with_multiple_inputs(self):
        """Activity should handle DAGs with multiple source nodes."""
        dag = DAG()
        source1 = Node(NodeType.SOURCE, {"path": "/a.mp4"})
        source2 = Node(NodeType.SOURCE, {"path": "/b.mp4"})
        sequence = Node(NodeType.SEQUENCE, {}, inputs=[source1.node_id, source2.node_id])

        dag.add_node(source1)
        dag.add_node(source2)
        dag.add_node(sequence)
        dag.set_output(sequence.node_id)

        activity = Activity.from_dag(dag)

        assert len(activity.input_ids) == 2
        assert source1.node_id in activity.input_ids
        assert source2.node_id in activity.input_ids
        assert activity.output_id == sequence.node_id
        # The output feeds nothing else, so there are no intermediates.
        assert len(activity.intermediate_ids) == 0

    def test_activity_serialization(self):
        """Activity should serialize and deserialize correctly."""
        dag = DAG()
        source = Node(NodeType.SOURCE, {"path": "/test.mp4"})
        dag.add_node(source)
        dag.set_output(source.node_id)

        activity = Activity.from_dag(dag)
        data = activity.to_dict()
        restored = Activity.from_dict(data)

        assert restored.activity_id == activity.activity_id
        assert restored.input_ids == activity.input_ids
        assert restored.output_id == activity.output_id
        assert restored.intermediate_ids == activity.intermediate_ids

    def test_all_node_ids(self):
        """all_node_ids should return all nodes."""
        activity = Activity(
            activity_id="test",
            input_ids=["a", "b"],
            output_id="c",
            intermediate_ids=["d", "e"],
            created_at=time.time(),
        )

        # Union of inputs, output, and intermediates.
        all_ids = activity.all_node_ids
        assert set(all_ids) == {"a", "b", "c", "d", "e"}
|
||||
|
||||
class TestActivityStore:
|
||||
"""Tests for ActivityStore persistence."""
|
||||
|
||||
def test_add_and_get(self, activity_store):
|
||||
"""Should add and retrieve activities."""
|
||||
activity = Activity(
|
||||
activity_id="test1",
|
||||
input_ids=["input1"],
|
||||
output_id="output1",
|
||||
intermediate_ids=["inter1"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
activity_store.add(activity)
|
||||
retrieved = activity_store.get("test1")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.activity_id == "test1"
|
||||
|
||||
def test_persistence(self, temp_dir):
|
||||
"""Activities should persist across store reloads."""
|
||||
store1 = ActivityStore(temp_dir / "activities")
|
||||
activity = Activity(
|
||||
activity_id="persist",
|
||||
input_ids=["i1"],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
store1.add(activity)
|
||||
|
||||
# Reload
|
||||
store2 = ActivityStore(temp_dir / "activities")
|
||||
retrieved = store2.get("persist")
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.activity_id == "persist"
|
||||
|
||||
def test_find_by_input_ids(self, activity_store):
|
||||
"""Should find activities with matching inputs."""
|
||||
activity1 = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["x", "y"],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity2 = Activity(
|
||||
activity_id="a2",
|
||||
input_ids=["y", "x"], # Same inputs, different order
|
||||
output_id="o2",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity3 = Activity(
|
||||
activity_id="a3",
|
||||
input_ids=["z"], # Different inputs
|
||||
output_id="o3",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
|
||||
activity_store.add(activity1)
|
||||
activity_store.add(activity2)
|
||||
activity_store.add(activity3)
|
||||
|
||||
found = activity_store.find_by_input_ids(["x", "y"])
|
||||
assert len(found) == 2
|
||||
assert {a.activity_id for a in found} == {"a1", "a2"}
|
||||
|
||||
def test_find_using_node(self, activity_store):
|
||||
"""Should find activities referencing a node."""
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input1"],
|
||||
output_id="output1",
|
||||
intermediate_ids=["inter1"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
# Should find by input
|
||||
found = activity_store.find_using_node("input1")
|
||||
assert len(found) == 1
|
||||
|
||||
# Should find by intermediate
|
||||
found = activity_store.find_using_node("inter1")
|
||||
assert len(found) == 1
|
||||
|
||||
# Should find by output
|
||||
found = activity_store.find_using_node("output1")
|
||||
assert len(found) == 1
|
||||
|
||||
# Should not find unknown
|
||||
found = activity_store.find_using_node("unknown")
|
||||
assert len(found) == 0
|
||||
|
||||
def test_remove(self, activity_store):
|
||||
"""Should remove activities."""
|
||||
activity = Activity(
|
||||
activity_id="to_remove",
|
||||
input_ids=["i"],
|
||||
output_id="o",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
assert activity_store.get("to_remove") is not None
|
||||
|
||||
result = activity_store.remove("to_remove")
|
||||
assert result is True
|
||||
assert activity_store.get("to_remove") is None
|
||||
|
||||
|
||||
class TestActivityManager:
|
||||
"""Tests for ActivityManager deletion rules."""
|
||||
|
||||
def test_can_delete_orphaned_entry(self, manager, cache, temp_dir):
|
||||
"""Orphaned entries (not in any activity) can be deleted."""
|
||||
test_file = create_test_file(temp_dir / "orphan.txt", "orphan")
|
||||
cache.put("orphan_node", test_file, "test")
|
||||
|
||||
assert manager.can_delete_cache_entry("orphan_node") is True
|
||||
|
||||
def test_cannot_delete_shared_entry(self, manager, cache, temp_dir, ap_store):
|
||||
"""Shared entries (ActivityPub published) cannot be deleted."""
|
||||
test_file = create_test_file(temp_dir / "shared.txt", "shared content")
|
||||
cache.put("shared_node", test_file, "test")
|
||||
|
||||
# Mark as shared
|
||||
entry = cache.get_entry("shared_node")
|
||||
ap_store.mark_shared(entry.cid)
|
||||
|
||||
assert manager.can_delete_cache_entry("shared_node") is False
|
||||
|
||||
def test_cannot_delete_activity_input(self, manager, cache, activity_store, temp_dir):
|
||||
"""Activity inputs cannot be deleted."""
|
||||
test_file = create_test_file(temp_dir / "input.txt", "input")
|
||||
cache.put("input_node", test_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input_node"],
|
||||
output_id="output_node",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
assert manager.can_delete_cache_entry("input_node") is False
|
||||
|
||||
def test_cannot_delete_activity_output(self, manager, cache, activity_store, temp_dir):
|
||||
"""Activity outputs cannot be deleted."""
|
||||
test_file = create_test_file(temp_dir / "output.txt", "output")
|
||||
cache.put("output_node", test_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input_node"],
|
||||
output_id="output_node",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
assert manager.can_delete_cache_entry("output_node") is False
|
||||
|
||||
def test_can_delete_intermediate(self, manager, cache, activity_store, temp_dir):
|
||||
"""Intermediate entries can be deleted (they're reconstructible)."""
|
||||
test_file = create_test_file(temp_dir / "inter.txt", "intermediate")
|
||||
cache.put("inter_node", test_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input_node"],
|
||||
output_id="output_node",
|
||||
intermediate_ids=["inter_node"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
assert manager.can_delete_cache_entry("inter_node") is True
|
||||
|
||||
def test_can_discard_activity_no_shared(self, manager, activity_store):
|
||||
"""Activity can be discarded if nothing is shared."""
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["i1"],
|
||||
output_id="o1",
|
||||
intermediate_ids=["m1"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
assert manager.can_discard_activity("a1") is True
|
||||
|
||||
def test_cannot_discard_activity_with_shared_output(self, manager, cache, activity_store, temp_dir, ap_store):
|
||||
"""Activity cannot be discarded if output is shared."""
|
||||
test_file = create_test_file(temp_dir / "output.txt", "output content")
|
||||
cache.put("o1", test_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["i1"],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
# Mark output as shared
|
||||
entry = cache.get_entry("o1")
|
||||
ap_store.mark_shared(entry.cid)
|
||||
|
||||
assert manager.can_discard_activity("a1") is False
|
||||
|
||||
def test_cannot_discard_activity_with_shared_input(self, manager, cache, activity_store, temp_dir, ap_store):
|
||||
"""Activity cannot be discarded if input is shared."""
|
||||
test_file = create_test_file(temp_dir / "input.txt", "input content")
|
||||
cache.put("i1", test_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["i1"],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
entry = cache.get_entry("i1")
|
||||
ap_store.mark_shared(entry.cid)
|
||||
|
||||
assert manager.can_discard_activity("a1") is False
|
||||
|
||||
def test_discard_activity_deletes_intermediates(self, manager, cache, activity_store, temp_dir):
|
||||
"""Discarding activity should delete intermediate cache entries."""
|
||||
# Create cache entries
|
||||
input_file = create_test_file(temp_dir / "input.txt", "input")
|
||||
inter_file = create_test_file(temp_dir / "inter.txt", "intermediate")
|
||||
output_file = create_test_file(temp_dir / "output.txt", "output")
|
||||
|
||||
cache.put("i1", input_file, "test")
|
||||
cache.put("m1", inter_file, "test")
|
||||
cache.put("o1", output_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["i1"],
|
||||
output_id="o1",
|
||||
intermediate_ids=["m1"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
# Discard
|
||||
result = manager.discard_activity("a1")
|
||||
|
||||
assert result is True
|
||||
assert cache.has("m1") is False # Intermediate deleted
|
||||
assert activity_store.get("a1") is None # Activity removed
|
||||
|
||||
def test_discard_activity_deletes_orphaned_output(self, manager, cache, activity_store, temp_dir):
|
||||
"""Discarding activity should delete output if orphaned."""
|
||||
output_file = create_test_file(temp_dir / "output.txt", "output")
|
||||
cache.put("o1", output_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=[],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
manager.discard_activity("a1")
|
||||
|
||||
assert cache.has("o1") is False # Orphaned output deleted
|
||||
|
||||
def test_discard_activity_keeps_shared_output(self, manager, cache, activity_store, temp_dir, ap_store):
|
||||
"""Discarding should fail if output is shared."""
|
||||
output_file = create_test_file(temp_dir / "output.txt", "shared output")
|
||||
cache.put("o1", output_file, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=[],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
entry = cache.get_entry("o1")
|
||||
ap_store.mark_shared(entry.cid)
|
||||
|
||||
result = manager.discard_activity("a1")
|
||||
|
||||
assert result is False # Cannot discard
|
||||
assert cache.has("o1") is True # Output preserved
|
||||
assert activity_store.get("a1") is not None # Activity preserved
|
||||
|
||||
def test_discard_keeps_input_used_elsewhere(self, manager, cache, activity_store, temp_dir):
|
||||
"""Input used by another activity should not be deleted."""
|
||||
input_file = create_test_file(temp_dir / "input.txt", "shared input")
|
||||
cache.put("shared_input", input_file, "test")
|
||||
|
||||
activity1 = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["shared_input"],
|
||||
output_id="o1",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity2 = Activity(
|
||||
activity_id="a2",
|
||||
input_ids=["shared_input"],
|
||||
output_id="o2",
|
||||
intermediate_ids=[],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity1)
|
||||
activity_store.add(activity2)
|
||||
|
||||
manager.discard_activity("a1")
|
||||
|
||||
# Input still used by a2
|
||||
assert cache.has("shared_input") is True
|
||||
|
||||
def test_get_deletable_entries(self, manager, cache, activity_store, temp_dir):
|
||||
"""Should list all deletable entries."""
|
||||
# Orphan (deletable)
|
||||
orphan = create_test_file(temp_dir / "orphan.txt", "orphan")
|
||||
cache.put("orphan", orphan, "test")
|
||||
|
||||
# Intermediate (deletable)
|
||||
inter = create_test_file(temp_dir / "inter.txt", "inter")
|
||||
cache.put("inter", inter, "test")
|
||||
|
||||
# Input (not deletable)
|
||||
inp = create_test_file(temp_dir / "input.txt", "input")
|
||||
cache.put("input", inp, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input"],
|
||||
output_id="output",
|
||||
intermediate_ids=["inter"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
deletable = manager.get_deletable_entries()
|
||||
deletable_ids = {e.node_id for e in deletable}
|
||||
|
||||
assert "orphan" in deletable_ids
|
||||
assert "inter" in deletable_ids
|
||||
assert "input" not in deletable_ids
|
||||
|
||||
def test_cleanup_intermediates(self, manager, cache, activity_store, temp_dir):
|
||||
"""cleanup_intermediates() should delete all intermediate entries."""
|
||||
inter1 = create_test_file(temp_dir / "i1.txt", "inter1")
|
||||
inter2 = create_test_file(temp_dir / "i2.txt", "inter2")
|
||||
cache.put("inter1", inter1, "test")
|
||||
cache.put("inter2", inter2, "test")
|
||||
|
||||
activity = Activity(
|
||||
activity_id="a1",
|
||||
input_ids=["input"],
|
||||
output_id="output",
|
||||
intermediate_ids=["inter1", "inter2"],
|
||||
created_at=time.time(),
|
||||
)
|
||||
activity_store.add(activity)
|
||||
|
||||
deleted = manager.cleanup_intermediates()
|
||||
|
||||
assert deleted == 2
|
||||
assert cache.has("inter1") is False
|
||||
assert cache.has("inter2") is False
|
||||
|
||||
|
||||
class TestMakeIsSharedFn:
    """Tests for make_is_shared_fn factory.

    The factory wraps an ActivityPub store into a cid -> bool predicate
    used by ActivityManager's deletion rules.
    """

    def test_returns_true_for_shared(self, ap_store):
        """Should return True for shared content."""
        is_shared = make_is_shared_fn(ap_store)
        ap_store.mark_shared("hash123")

        assert is_shared("hash123") is True

    def test_returns_false_for_not_shared(self, ap_store):
        """Should return False for non-shared content."""
        is_shared = make_is_shared_fn(ap_store)

        assert is_shared("unknown_hash") is False
||||
163
core/tests/test_cache.py
Normal file
163
core/tests/test_cache.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# tests/test_primitive_new/test_cache.py
|
||||
"""Tests for primitive cache module."""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from artdag.cache import Cache, CacheStats
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache_dir():
|
||||
"""Create temporary cache directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
yield Path(tmpdir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cache(cache_dir):
|
||||
"""Create cache instance."""
|
||||
return Cache(cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file(cache_dir):
|
||||
"""Create a sample file to cache."""
|
||||
file_path = cache_dir / "sample.txt"
|
||||
file_path.write_text("test content")
|
||||
return file_path
|
||||
|
||||
|
||||
class TestCache:
|
||||
"""Test Cache class."""
|
||||
|
||||
def test_cache_creation(self, cache_dir):
|
||||
"""Test cache directory is created."""
|
||||
cache = Cache(cache_dir / "new_cache")
|
||||
assert cache.cache_dir.exists()
|
||||
|
||||
def test_cache_put_and_get(self, cache, sample_file):
|
||||
"""Test putting and getting from cache."""
|
||||
node_id = "abc123"
|
||||
cached_path = cache.put(node_id, sample_file, "TEST")
|
||||
|
||||
assert cached_path.exists()
|
||||
assert cache.has(node_id)
|
||||
|
||||
retrieved = cache.get(node_id)
|
||||
assert retrieved == cached_path
|
||||
|
||||
def test_cache_miss(self, cache):
|
||||
"""Test cache miss returns None."""
|
||||
result = cache.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_cache_stats_hit_miss(self, cache, sample_file):
|
||||
"""Test cache hit/miss stats."""
|
||||
cache.put("abc123", sample_file, "TEST")
|
||||
|
||||
# Miss
|
||||
cache.get("nonexistent")
|
||||
assert cache.stats.misses == 1
|
||||
|
||||
# Hit
|
||||
cache.get("abc123")
|
||||
assert cache.stats.hits == 1
|
||||
|
||||
assert cache.stats.hit_rate == 0.5
|
||||
|
||||
def test_cache_remove(self, cache, sample_file):
|
||||
"""Test removing from cache."""
|
||||
node_id = "abc123"
|
||||
cache.put(node_id, sample_file, "TEST")
|
||||
assert cache.has(node_id)
|
||||
|
||||
cache.remove(node_id)
|
||||
assert not cache.has(node_id)
|
||||
|
||||
def test_cache_clear(self, cache, sample_file):
|
||||
"""Test clearing cache."""
|
||||
cache.put("node1", sample_file, "TEST")
|
||||
cache.put("node2", sample_file, "TEST")
|
||||
|
||||
assert cache.stats.total_entries == 2
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache.stats.total_entries == 0
|
||||
assert not cache.has("node1")
|
||||
assert not cache.has("node2")
|
||||
|
||||
def test_cache_preserves_extension(self, cache, cache_dir):
|
||||
"""Test that cache preserves file extension."""
|
||||
mp4_file = cache_dir / "video.mp4"
|
||||
mp4_file.write_text("fake video")
|
||||
|
||||
cached = cache.put("video_node", mp4_file, "SOURCE")
|
||||
assert cached.suffix == ".mp4"
|
||||
|
||||
def test_cache_list_entries(self, cache, sample_file):
|
||||
"""Test listing cache entries."""
|
||||
cache.put("node1", sample_file, "TYPE1")
|
||||
cache.put("node2", sample_file, "TYPE2")
|
||||
|
||||
entries = cache.list_entries()
|
||||
assert len(entries) == 2
|
||||
|
||||
node_ids = {e.node_id for e in entries}
|
||||
assert "node1" in node_ids
|
||||
assert "node2" in node_ids
|
||||
|
||||
def test_cache_persistence(self, cache_dir, sample_file):
|
||||
"""Test cache persists across instances."""
|
||||
# First instance
|
||||
cache1 = Cache(cache_dir)
|
||||
cache1.put("abc123", sample_file, "TEST")
|
||||
|
||||
# Second instance loads from disk
|
||||
cache2 = Cache(cache_dir)
|
||||
assert cache2.has("abc123")
|
||||
|
||||
def test_cache_prune_by_age(self, cache, sample_file):
|
||||
"""Test pruning by age."""
|
||||
import time
|
||||
|
||||
cache.put("old_node", sample_file, "TEST")
|
||||
|
||||
# Manually set old creation time
|
||||
entry = cache._entries["old_node"]
|
||||
entry.created_at = time.time() - 3600 # 1 hour ago
|
||||
|
||||
removed = cache.prune(max_age_seconds=1800) # 30 minutes
|
||||
|
||||
assert removed == 1
|
||||
assert not cache.has("old_node")
|
||||
|
||||
def test_cache_output_path(self, cache):
|
||||
"""Test getting output path for node."""
|
||||
path = cache.get_output_path("abc123", ".mp4")
|
||||
assert path.suffix == ".mp4"
|
||||
assert "abc123" in str(path)
|
||||
assert path.parent.exists()
|
||||
|
||||
|
||||
class TestCacheStats:
    """Test CacheStats class."""

    def test_hit_rate_calculation(self):
        """Test hit rate calculation (hits / total requests)."""
        stats = CacheStats()

        stats.record_hit()
        stats.record_hit()
        stats.record_miss()

        assert stats.hits == 2
        assert stats.misses == 1
        # 2 hits out of 3 requests -> ~0.666; tolerance for float math.
        assert abs(stats.hit_rate - 0.666) < 0.01

    def test_initial_hit_rate(self):
        """Test hit rate with no requests (must not divide by zero)."""
        stats = CacheStats()
        assert stats.hit_rate == 0.0
||||
271
core/tests/test_dag.py
Normal file
271
core/tests/test_dag.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# tests/test_primitive_new/test_dag.py
|
||||
"""Tests for primitive DAG data structures."""
|
||||
|
||||
import pytest
|
||||
from artdag.dag import Node, NodeType, DAG, DAGBuilder
|
||||
|
||||
|
||||
class TestNode:
    """Test Node class.

    Focuses on the content-addressing contract: node_id is derived from
    (type, config, inputs), so equal content yields equal ids.
    """

    def test_node_creation(self):
        """Test basic node creation."""
        node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"})
        assert node.node_type == NodeType.SOURCE
        assert node.config == {"path": "/test.mp4"}
        assert node.node_id is not None

    def test_node_id_is_content_addressed(self):
        """Same content produces same node_id."""
        node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"})
        node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"})
        assert node1.node_id == node2.node_id

    def test_different_config_different_id(self):
        """Different config produces different node_id."""
        node1 = Node(node_type=NodeType.SOURCE, config={"path": "/test1.mp4"})
        node2 = Node(node_type=NodeType.SOURCE, config={"path": "/test2.mp4"})
        assert node1.node_id != node2.node_id

    def test_node_with_inputs(self):
        """Node with inputs includes them in ID."""
        node1 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["abc123"])
        node2 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["abc123"])
        node3 = Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["def456"])

        # Same inputs -> same id; different inputs -> different id.
        assert node1.node_id == node2.node_id
        assert node1.node_id != node3.node_id

    def test_node_serialization(self):
        """Test node to_dict and from_dict round-trip, including node_id."""
        original = Node(
            node_type=NodeType.SEGMENT,
            config={"duration": 5.0, "offset": 10.0},
            inputs=["abc123"],
            name="my_segment",
        )
        data = original.to_dict()
        restored = Node.from_dict(data)

        assert restored.node_type == original.node_type
        assert restored.config == original.config
        assert restored.inputs == original.inputs
        assert restored.name == original.name
        assert restored.node_id == original.node_id

    def test_custom_node_type(self):
        """Test node with custom string type (not a NodeType member)."""
        node = Node(node_type="CUSTOM_TYPE", config={"custom": True})
        assert node.node_type == "CUSTOM_TYPE"
        assert node.node_id is not None
|
||||
|
||||
class TestDAG:
    """Tests for the DAG container: storage, dedup, ordering, validation, (de)serialization."""

    def test_dag_creation(self):
        """A fresh DAG has no nodes and no output."""
        graph = DAG()
        assert len(graph.nodes) == 0
        assert graph.output_id is None

    def test_add_node(self):
        """add_node stores the node under its content id."""
        graph = DAG()
        node = Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"})
        nid = graph.add_node(node)

        assert nid in graph.nodes
        assert graph.nodes[nid] == node

    def test_node_deduplication(self):
        """Equal-content nodes collapse into a single entry."""
        graph = DAG()
        first = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        second = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))

        assert first == second
        assert len(graph.nodes) == 1

    def test_set_output(self):
        """set_output records an existing node id."""
        graph = DAG()
        nid = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        graph.set_output(nid)

        assert graph.output_id == nid

    def test_set_output_invalid(self):
        """set_output rejects ids that are not in the graph."""
        graph = DAG()
        with pytest.raises(ValueError):
            graph.set_output("nonexistent")

    def test_topological_order(self):
        """Producers sort before their consumers."""
        graph = DAG()
        # Simple chain: source -> segment (the DAG output).
        src_id = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        seg_id = graph.add_node(
            Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=[src_id])
        )
        graph.set_output(seg_id)

        ordering = graph.topological_order()
        assert ordering.index(src_id) < ordering.index(seg_id)

    def test_validate_valid_dag(self):
        """A complete single-node DAG validates without errors."""
        graph = DAG()
        nid = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        graph.set_output(nid)

        assert not graph.validate()

    def test_validate_no_output(self):
        """Validation flags a DAG with no output node."""
        graph = DAG()
        graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))

        problems = graph.validate()
        assert len(problems) > 0
        assert any("output" in p.lower() for p in problems)

    def test_validate_missing_input(self):
        """Validation flags references to nodes that don't exist."""
        graph = DAG()
        nid = graph.add_node(
            Node(node_type=NodeType.SEGMENT, config={"duration": 5}, inputs=["nonexistent"])
        )
        graph.set_output(nid)

        problems = graph.validate()
        assert len(problems) > 0
        assert any("missing" in p.lower() for p in problems)

    def test_dag_serialization(self):
        """to_dict/from_dict round-trips nodes, output, and metadata."""
        graph = DAG(metadata={"name": "test_dag"})
        src_id = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        graph.set_output(src_id)

        restored = DAG.from_dict(graph.to_dict())

        assert len(restored.nodes) == len(graph.nodes)
        assert restored.output_id == graph.output_id
        assert restored.metadata == graph.metadata

    def test_dag_json(self):
        """to_json/from_json round-trips a single-node DAG."""
        graph = DAG()
        nid = graph.add_node(Node(node_type=NodeType.SOURCE, config={"path": "/test.mp4"}))
        graph.set_output(nid)

        restored = DAG.from_json(graph.to_json())

        assert len(restored.nodes) == 1
        assert restored.output_id == nid
|
||||
|
||||
|
||||
class TestDAGBuilder:
    """Tests for the fluent DAGBuilder API."""

    def test_builder_source(self):
        """source() adds a SOURCE node holding the path."""
        b = DAGBuilder()
        sid = b.source("/test.mp4")

        assert sid in b.dag.nodes
        created = b.dag.nodes[sid]
        assert created.node_type == NodeType.SOURCE
        assert created.config["path"] == "/test.mp4"

    def test_builder_segment(self):
        """segment() wires duration/offset config and the upstream input."""
        b = DAGBuilder()
        sid = b.source("/test.mp4")
        seg = b.segment(sid, duration=5.0, offset=10.0)

        created = b.dag.nodes[seg]
        assert created.node_type == NodeType.SEGMENT
        assert created.config["duration"] == 5.0
        assert created.config["offset"] == 10.0
        assert sid in created.inputs

    def test_builder_chain(self):
        """A source -> segment -> resize chain builds and validates cleanly."""
        b = DAGBuilder()
        tail = b.resize(b.segment(b.source("/test.mp4"), duration=5.0), width=1920, height=1080)
        b.set_output(tail)

        graph = b.build()

        assert len(graph.nodes) == 3
        assert graph.output_id == tail
        assert not graph.validate()

    def test_builder_sequence(self):
        """sequence() collects every clip as an input."""
        b = DAGBuilder()
        clip_a = b.source("/clip1.mp4")
        clip_b = b.source("/clip2.mp4")
        seq = b.sequence([clip_a, clip_b], transition={"type": "crossfade", "duration": 0.5})
        b.set_output(seq)

        created = b.build().nodes[seq]
        assert created.node_type == NodeType.SEQUENCE
        assert clip_a in created.inputs
        assert clip_b in created.inputs

    def test_builder_mux(self):
        """mux() takes both the video and the audio input."""
        b = DAGBuilder()
        vid = b.source("/video.mp4")
        aud = b.source("/audio.mp3")
        muxed = b.mux(vid, aud)
        b.set_output(muxed)

        created = b.build().nodes[muxed]
        assert created.node_type == NodeType.MUX
        assert vid in created.inputs
        assert aud in created.inputs

    def test_builder_transform(self):
        """transform() stores the effects dict in the node config."""
        b = DAGBuilder()
        out = b.transform(b.source("/test.mp4"), effects={"saturation": 1.5, "contrast": 1.2})
        b.set_output(out)

        created = b.build().nodes[out]
        assert created.node_type == NodeType.TRANSFORM
        assert created.config["effects"]["saturation"] == 1.5

    def test_builder_validation_fails(self):
        """build() raises when no output was ever set."""
        b = DAGBuilder()
        b.source("/test.mp4")

        with pytest.raises(ValueError):
            b.build()
|
||||
464
core/tests/test_engine.py
Normal file
464
core/tests/test_engine.py
Normal file
@@ -0,0 +1,464 @@
|
||||
# tests/test_primitive_new/test_engine.py
|
||||
"""Tests for primitive engine execution."""
|
||||
|
||||
import pytest
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from artdag.dag import DAG, DAGBuilder, Node, NodeType
|
||||
from artdag.engine import Engine
|
||||
from artdag import nodes # Register executors
|
||||
|
||||
|
||||
@pytest.fixture
def cache_dir():
    """Yield a throwaway directory that is removed when the test ends."""
    with tempfile.TemporaryDirectory() as tmp:
        yield Path(tmp)
|
||||
|
||||
|
||||
@pytest.fixture
def engine(cache_dir):
    """Engine backed by the per-test cache directory."""
    return Engine(cache_dir)
|
||||
|
||||
|
||||
@pytest.fixture
def test_video(cache_dir):
    """Render a short synthetic test clip (video + sine audio) via ffmpeg."""
    out = cache_dir / "test_video.mp4"
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-f", "lavfi", "-i", "testsrc=duration=5:size=320x240:rate=30",
            "-f", "lavfi", "-i", "sine=frequency=440:duration=5",
            "-c:v", "libx264", "-preset", "ultrafast",
            "-c:a", "aac",
            str(out),
        ],
        capture_output=True,
        check=True,
    )
    return out
|
||||
|
||||
|
||||
@pytest.fixture
def test_audio(cache_dir):
    """Render a short synthetic sine-wave mp3 via ffmpeg."""
    out = cache_dir / "test_audio.mp3"
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-f", "lavfi", "-i", "sine=frequency=880:duration=5",
            "-c:a", "libmp3lame",
            str(out),
        ],
        capture_output=True,
        check=True,
    )
    return out
|
||||
|
||||
|
||||
class TestEngineBasic:
    """Smoke tests for Engine construction and its failure modes."""

    def test_engine_creation(self, cache_dir):
        """Constructing an Engine wires up a cache."""
        assert Engine(cache_dir).cache is not None

    def test_invalid_dag(self, engine):
        """Executing an empty DAG (no nodes, no output) fails validation."""
        outcome = engine.execute(DAG())

        assert not outcome.success
        assert "Invalid DAG" in outcome.error

    def test_missing_executor(self, engine):
        """A node type with no registered executor is reported as an error."""
        graph = DAG()
        nid = graph.add_node(Node(node_type="UNKNOWN_TYPE", config={}))
        graph.set_output(nid)

        outcome = engine.execute(graph)

        assert not outcome.success
        assert "No executor" in outcome.error
|
||||
|
||||
|
||||
class TestSourceExecutor:
    """Tests for the SOURCE node executor."""

    def test_source_creates_symlink(self, engine, test_video):
        """A SOURCE node materialises in the cache as a symlink."""
        b = DAGBuilder()
        b.set_output(b.source(str(test_video)))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()
        assert outcome.output_path.is_symlink()

    def test_source_missing_file(self, engine):
        """A SOURCE pointing at a missing file fails with a clear error."""
        b = DAGBuilder()
        b.set_output(b.source("/nonexistent/file.mp4"))

        outcome = engine.execute(b.build())

        assert not outcome.success
        assert "not found" in outcome.error.lower()
|
||||
|
||||
|
||||
class TestSegmentExecutor:
    """Tests for the SEGMENT node executor."""

    @staticmethod
    def _probe_duration(path):
        """Return a media file's duration in seconds, read via ffprobe."""
        probe = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "format=duration",
                "-of", "csv=p=0",
                str(path),
            ],
            capture_output=True,
            text=True,
        )
        return float(probe.stdout.strip())

    def test_segment_duration(self, engine, test_video):
        """A SEGMENT trims the input to the requested duration."""
        b = DAGBuilder()
        b.set_output(b.segment(b.source(str(test_video)), duration=2.0))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert abs(self._probe_duration(outcome.output_path) - 2.0) < 0.1

    def test_segment_with_offset(self, engine, test_video):
        """A SEGMENT honours a start offset."""
        b = DAGBuilder()
        b.set_output(b.segment(b.source(str(test_video)), offset=1.0, duration=2.0))

        outcome = engine.execute(b.build())
        assert outcome.success
|
||||
|
||||
|
||||
class TestResizeExecutor:
    """Tests for the RESIZE node executor."""

    def test_resize_dimensions(self, engine, test_video):
        """RESIZE produces output at the requested width and height."""
        b = DAGBuilder()
        b.set_output(b.resize(b.source(str(test_video)), width=640, height=480, mode="fit"))

        outcome = engine.execute(b.build())
        assert outcome.success

        # Read back the actual stream dimensions with ffprobe.
        probe = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "stream=width,height",
                "-of", "csv=p=0:s=x",
                str(outcome.output_path),
            ],
            capture_output=True,
            text=True,
        )
        first_stream = probe.stdout.strip().split("\n")[0]
        assert "640x480" in first_stream
|
||||
|
||||
|
||||
class TestTransformExecutor:
    """Tests for the TRANSFORM node executor."""

    def test_transform_saturation(self, engine, test_video):
        """A single-effect transform succeeds and writes output."""
        b = DAGBuilder()
        b.set_output(b.transform(b.source(str(test_video)), effects={"saturation": 1.5}))

        outcome = engine.execute(b.build())
        assert outcome.success
        assert outcome.output_path.exists()

    def test_transform_multiple_effects(self, engine, test_video):
        """Several effects can be stacked inside one TRANSFORM node."""
        stack = {
            "saturation": 1.2,
            "contrast": 1.1,
            "brightness": 0.05,
        }
        b = DAGBuilder()
        b.set_output(b.transform(b.source(str(test_video)), effects=stack))

        outcome = engine.execute(b.build())
        assert outcome.success
|
||||
|
||||
|
||||
class TestSequenceExecutor:
    """Tests for the SEQUENCE node executor."""

    @staticmethod
    def _probe_duration(path):
        """Return a media file's duration in seconds, read via ffprobe."""
        probe = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "format=duration",
                "-of", "csv=p=0",
                str(path),
            ],
            capture_output=True,
            text=True,
        )
        return float(probe.stdout.strip())

    def test_sequence_cut(self, engine, test_video):
        """Two 2s segments joined with a hard cut last ~4s in total."""
        b = DAGBuilder()
        src = b.source(str(test_video))
        first = b.segment(src, duration=2.0)
        second = b.segment(src, offset=2.0, duration=2.0)
        b.set_output(b.sequence([first, second], transition={"type": "cut"}))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert abs(self._probe_duration(outcome.output_path) - 4.0) < 0.2

    def test_sequence_crossfade(self, engine, test_video):
        """A crossfade overlaps the clips: 3 + 3 - 0.5 = 5.5s total."""
        b = DAGBuilder()
        src = b.source(str(test_video))
        first = b.segment(src, duration=3.0)
        second = b.segment(src, offset=1.0, duration=3.0)
        b.set_output(
            b.sequence([first, second], transition={"type": "crossfade", "duration": 0.5})
        )

        outcome = engine.execute(b.build())

        assert outcome.success
        assert abs(self._probe_duration(outcome.output_path) - 5.5) < 0.3
|
||||
|
||||
|
||||
class TestMuxExecutor:
    """Tests for the MUX node executor."""

    def test_mux_video_audio(self, engine, test_video, test_audio):
        """Muxing a video stream with an audio stream produces output."""
        b = DAGBuilder()
        b.set_output(b.mux(b.source(str(test_video)), b.source(str(test_audio))))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()
|
||||
|
||||
|
||||
class TestAudioMixExecutor:
    """Tests for the AUDIO_MIX node executor."""

    @staticmethod
    def _render_tone(path, freq, seconds):
        """Synthesise a sine-wave mp3 at the given frequency and duration."""
        subprocess.run(
            [
                "ffmpeg", "-y",
                "-f", "lavfi", "-i", f"sine=frequency={freq}:duration={seconds}",
                "-c:a", "libmp3lame",
                str(path),
            ],
            capture_output=True,
            check=True,
        )

    def test_audio_mix_simple(self, engine, cache_dir):
        """Two tones mix into one output with default gains."""
        tone_a = cache_dir / "audio1.mp3"
        tone_b = cache_dir / "audio2.mp3"
        self._render_tone(tone_a, 440, 3)
        self._render_tone(tone_b, 880, 3)

        b = DAGBuilder()
        b.set_output(b.audio_mix([b.source(str(tone_a)), b.source(str(tone_b))]))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()

    def test_audio_mix_with_gains(self, engine, cache_dir):
        """Per-input gain factors are accepted."""
        tone_a = cache_dir / "audio1.mp3"
        tone_b = cache_dir / "audio2.mp3"
        self._render_tone(tone_a, 440, 3)
        self._render_tone(tone_b, 880, 3)

        b = DAGBuilder()
        b.set_output(
            b.audio_mix([b.source(str(tone_a)), b.source(str(tone_b))], gains=[1.0, 0.3])
        )

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()

    def test_audio_mix_three_inputs(self, engine, cache_dir):
        """Three sources can be mixed with individual gains."""
        tones = []
        for i, freq in enumerate([440, 660, 880]):
            tone = cache_dir / f"audio{i}.mp3"
            self._render_tone(tone, freq, 2)
            tones.append(tone)

        b = DAGBuilder()
        b.set_output(
            b.audio_mix([b.source(str(t)) for t in tones], gains=[1.0, 0.5, 0.3])
        )

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()
|
||||
|
||||
|
||||
class TestCaching:
    """Tests for engine result caching."""

    def test_cache_reuse(self, engine, test_video):
        """The second execution of the same DAG is served entirely from cache."""
        b = DAGBuilder()
        b.set_output(b.source(str(test_video)))
        graph = b.build()

        cold = engine.execute(graph)
        assert cold.success
        assert cold.nodes_cached == 0
        assert cold.nodes_executed == 1

        warm = engine.execute(graph)
        assert warm.success
        assert warm.nodes_cached == 1
        assert warm.nodes_executed == 0

    def test_clear_cache(self, engine, test_video):
        """clear_cache drops every cached entry."""
        b = DAGBuilder()
        b.set_output(b.source(str(test_video)))

        engine.execute(b.build())
        assert engine.cache.stats.total_entries == 1

        engine.clear_cache()
        assert engine.cache.stats.total_entries == 0
|
||||
|
||||
|
||||
class TestProgressCallback:
    """Tests for the engine's progress callback hook."""

    def test_progress_callback(self, engine, test_video):
        """The callback receives per-node status transitions."""
        seen = []
        engine.set_progress_callback(
            lambda progress: seen.append((progress.node_id, progress.status))
        )

        b = DAGBuilder()
        b.set_output(b.source(str(test_video)))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert len(seen) > 0
        # Expect at least the pending -> ... -> completed transitions.
        statuses = [status for _, status in seen]
        assert "pending" in statuses
        assert "completed" in statuses
|
||||
|
||||
|
||||
class TestFullWorkflow:
    """End-to-end pipeline test."""

    def test_full_pipeline(self, engine, test_video, test_audio):
        """source -> segment -> resize -> transform -> mux runs to completion."""
        b = DAGBuilder()

        video = b.source(str(test_video))
        audio = b.source(str(test_audio))
        clip = b.segment(video, duration=3.0)
        small = b.resize(clip, width=640, height=480)
        graded = b.transform(small, effects={"saturation": 1.3})
        b.set_output(b.mux(graded, audio))

        outcome = engine.execute(b.build())

        assert outcome.success
        assert outcome.output_path.exists()
        # Two sources + segment + resize + transform + mux.
        assert outcome.nodes_executed == 6
|
||||
110
core/tests/test_executor.py
Normal file
110
core/tests/test_executor.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# tests/test_primitive_new/test_executor.py
|
||||
"""Tests for primitive executor module."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from artdag.dag import NodeType
|
||||
from artdag.executor import (
|
||||
Executor,
|
||||
register_executor,
|
||||
get_executor,
|
||||
list_executors,
|
||||
clear_executors,
|
||||
)
|
||||
|
||||
|
||||
class TestExecutorRegistry:
    """Tests for the executor registration API."""

    def setup_method(self):
        """Start every test with an empty registry."""
        clear_executors()

    def teardown_method(self):
        """Leave the registry empty for other test modules."""
        clear_executors()

    def test_register_executor(self):
        """Registering by NodeType makes an instance retrievable."""
        @register_executor(NodeType.SOURCE)
        class TestSourceExecutor(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        found = get_executor(NodeType.SOURCE)
        assert found is not None
        assert isinstance(found, TestSourceExecutor)

    def test_register_custom_type(self):
        """Arbitrary string node types can be registered as well."""
        @register_executor("CUSTOM_NODE")
        class CustomExecutor(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        assert get_executor("CUSTOM_NODE") is not None

    def test_get_unregistered(self):
        """Unknown types yield None rather than raising."""
        assert get_executor(NodeType.ANALYZE) is None

    def test_list_executors(self):
        """list_executors reports every registered type name."""
        @register_executor(NodeType.SOURCE)
        class SourceExec(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        @register_executor(NodeType.SEGMENT)
        class SegmentExec(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        registered = list_executors()
        assert "SOURCE" in registered
        assert "SEGMENT" in registered

    def test_overwrite_warning(self, caplog):
        """Re-registering a type replaces the previous executor (with a warning)."""
        @register_executor(NodeType.SOURCE)
        class FirstExecutor(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        # Register again for the same type - the second one should win.
        @register_executor(NodeType.SOURCE)
        class SecondExecutor(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        assert isinstance(get_executor(NodeType.SOURCE), SecondExecutor)
|
||||
|
||||
|
||||
class TestExecutorBase:
    """Tests for Executor base-class default behaviour."""

    @staticmethod
    def _make_executor():
        """Build a minimal concrete Executor subclass instance."""
        class _PassThrough(Executor):
            def execute(self, config, inputs, output_path):
                return output_path

        return _PassThrough()

    def test_validate_config_default(self):
        """The default validate_config accepts any config (no errors)."""
        assert self._make_executor().validate_config({"any": "config"}) == []

    def test_estimate_output_size(self):
        """The default size estimate is the sum of the input sizes."""
        assert self._make_executor().estimate_output_size({}, [100, 200, 300]) == 600
|
||||
301
core/tests/test_ipfs_access.py
Normal file
301
core/tests/test_ipfs_access.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Tests for IPFS access consistency.
|
||||
|
||||
All IPFS access should use IPFS_API (multiaddr format) for consistency
|
||||
with art-celery's ipfs_client.py. This ensures Docker deployments work
|
||||
correctly since IPFS_API is set to /dns/ipfs/tcp/5001.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def multiaddr_to_url(multiaddr: str) -> str:
    """
    Convert an IPFS multiaddr into an HTTP base URL.

    Mirrors the canonical conversion in ipfs_client.py:
    /dns/, /dns4/, /dns6/ and /ip4/ multiaddrs map to http://host:port,
    http(s) URLs pass through unchanged, and anything unrecognised
    falls back to the local daemon.
    """
    # One alternation covers both the DNS and the IPv4 multiaddr shapes.
    matched = re.match(r"/(?:dns[46]?|ip4)/([^/]+)/tcp/(\d+)", multiaddr)
    if matched:
        host, port = matched.groups()
        return f"http://{host}:{port}"

    # Already an HTTP(S) URL: pass through as-is.
    if multiaddr.startswith("http"):
        return multiaddr

    # Unknown format: default to the local IPFS API.
    return "http://127.0.0.1:5001"
|
||||
|
||||
|
||||
class TestMultiaddrConversion:
    """Tests for multiaddr to URL conversion."""

    def test_dns_format(self) -> None:
        """Docker DNS multiaddrs map to http://host:port."""
        assert multiaddr_to_url("/dns/ipfs/tcp/5001") == "http://ipfs:5001"

    def test_dns4_format(self) -> None:
        """dns4 multiaddrs are handled like dns."""
        assert multiaddr_to_url("/dns4/ipfs.example.com/tcp/5001") == "http://ipfs.example.com:5001"

    def test_ip4_format(self) -> None:
        """IPv4 multiaddrs map to http://address:port."""
        assert multiaddr_to_url("/ip4/127.0.0.1/tcp/5001") == "http://127.0.0.1:5001"

    def test_already_url(self) -> None:
        """Plain HTTP URLs pass through unchanged."""
        assert multiaddr_to_url("http://localhost:5001") == "http://localhost:5001"

    def test_fallback(self) -> None:
        """Unrecognised input falls back to the local daemon."""
        assert multiaddr_to_url("garbage") == "http://127.0.0.1:5001"
|
||||
|
||||
|
||||
class TestIPFSConfigConsistency:
    """
    Tests to ensure IPFS configuration is consistent.

    The effect executor should use IPFS_API (like ipfs_client.py)
    rather than a separate IPFS_GATEWAY variable.
    """

    def test_effect_module_should_not_use_gateway_var(self) -> None:
        """
        Regression test: Effect module should use IPFS_API, not IPFS_GATEWAY.

        Bug found 2026-01-12: artdag/nodes/effect.py used IPFS_GATEWAY which
        defaulted to http://127.0.0.1:8080 - broken in Docker, where the IPFS
        node is a separate container. ipfs_client.py uses IPFS_API, which
        docker-compose sets correctly.
        """
        from artdag.nodes import effect

        # After the fix the module should expose IPFS_API (or a URL helper),
        # not the old IPFS_GATEWAY variable.
        uses_gateway = hasattr(effect, 'IPFS_GATEWAY')
        uses_api = hasattr(effect, 'IPFS_API') or hasattr(effect, '_get_ipfs_base_url')

        if uses_gateway and not uses_api:
            pytest.fail(
                "Effect module uses IPFS_GATEWAY instead of IPFS_API. "
                "This breaks Docker deployments where IPFS_API=/dns/ipfs/tcp/5001 "
                "but IPFS_GATEWAY defaults to localhost."
            )

    def test_ipfs_api_default_is_localhost(self) -> None:
        """IPFS_API should default to localhost for local development."""
        url = multiaddr_to_url("/ip4/127.0.0.1/tcp/5001")
        assert "127.0.0.1" in url
        assert "5001" in url

    def test_docker_ipfs_api_uses_service_name(self) -> None:
        """In Docker, IPFS_API should resolve to the compose service name."""
        url = multiaddr_to_url("/dns/ipfs/tcp/5001")
        assert url == "http://ipfs:5001"
        assert "127.0.0.1" not in url
|
||||
|
||||
|
||||
class TestEffectFetchURL:
    """Tests for the URL used to fetch effects from IPFS."""

    def test_fetch_should_use_api_cat_endpoint(self) -> None:
        """
        Effect fetch should use /api/v0/cat endpoint (like ipfs_client.py).

        The IPFS API's cat endpoint works reliably in Docker.
        The gateway endpoint (port 8080) requires separate configuration.
        """
        cid = "QmTestCid123"
        # Compose the URL exactly the way the fixed executor should.
        correct_url = "{}/api/v0/cat?arg={}".format("http://ipfs:5001", cid)

        assert "/api/v0/cat" in correct_url
        assert "arg=" in correct_url

    def test_gateway_url_is_different_from_api(self) -> None:
        """
        Document the difference between gateway and API URLs.

        Gateway: http://ipfs:8080/ipfs/{cid} (requires IPFS_GATEWAY config)
        API: http://ipfs:5001/api/v0/cat?arg={cid} (uses IPFS_API config)

        Using the API is more reliable since IPFS_API is already configured
        correctly in docker-compose.yml.
        """
        cid = "QmTestCid123"

        gateway_url = f"http://ipfs:8080/ipfs/{cid}"        # old broken style
        api_url = f"http://ipfs:5001/api/v0/cat?arg={cid}"  # correct style

        # The two approaches must be distinguishable by port.
        assert gateway_url != api_url
        assert ":8080" in gateway_url
        assert ":5001" in api_url
|
||||
|
||||
|
||||
class TestEffectDependencies:
    """Tests for effect dependency handling."""

    def test_parse_pep723_dependencies(self) -> None:
        """Should parse PEP 723 dependencies from effect source."""
        source = '''
# /// script
# requires-python = ">=3.10"
# dependencies = ["numpy", "opencv-python"]
# ///
"""
@effect test_effect
"""

def process_frame(frame, params, state):
    return frame, state
'''
        # Import the function after the fix is applied.
        from artdag.nodes.effect import _parse_pep723_dependencies

        assert _parse_pep723_dependencies(source) == ["numpy", "opencv-python"]

    def test_parse_pep723_no_dependencies(self) -> None:
        """Should return empty list if no dependencies block."""
        source = '''
"""
@effect simple_effect
"""

def process_frame(frame, params, state):
    return frame, state
'''
        from artdag.nodes.effect import _parse_pep723_dependencies

        assert _parse_pep723_dependencies(source) == []

    def test_ensure_dependencies_already_installed(self) -> None:
        """Should return True if dependencies are already installed."""
        from artdag.nodes.effect import _ensure_dependencies

        # "os" ships with the stdlib, so it is always importable.
        assert _ensure_dependencies(["os"], "QmTest123") is True

    def test_effect_with_missing_dependency_gives_clear_error(self, tmp_path: Path) -> None:
        """
        Regression test: Missing dependencies should give clear error message.

        Bug found 2026-01-12: Effect with numpy dependency failed with
        "No module named 'numpy'" but this was swallowed and reported as
        "Unknown effect: invert" - very confusing.
        """
        effect_cid = "QmTestEffectWithDeps"
        # Lay out a cached effect that imports a non-existent module.
        effect_dir = tmp_path / "_effects" / effect_cid
        effect_dir.mkdir(parents=True)
        (effect_dir / "effect.py").write_text('''
# /// script
# requires-python = ">=3.10"
# dependencies = ["some_nonexistent_package_xyz"]
# ///
"""
@effect test_effect
"""
import some_nonexistent_package_xyz

def process_frame(frame, params, state):
    return frame, state
''')

        # Sanity check: the cached effect file is in place.
        assert (effect_dir / "effect.py").exists()

        # When loading fails due to the missing import, the error should
        # mention the dependency instead of masquerading as "Unknown effect".
        with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}):
            from artdag.nodes.effect import _load_cached_effect

            # Currently returns None, which causes "Unknown effect" downstream;
            # the real issue is that the dependency isn't installed.
            assert _load_cached_effect(effect_cid) is None
|
||||
|
||||
|
||||
class TestEffectCacheAndFetch:
|
||||
"""Integration tests for effect caching and fetching."""
|
||||
|
||||
def test_effect_loads_from_cache_without_ipfs(self, tmp_path: Path) -> None:
    """When effect is in cache, IPFS should not be contacted."""
    effect_cid = "QmTestEffect123"

    # Pre-populate the on-disk effect cache.
    cached_dir = tmp_path / "_effects" / effect_cid
    cached_dir.mkdir(parents=True)
    (cached_dir / "effect.py").write_text('''
def process_frame(frame, params, state):
    return frame, state
''')

    # With CACHE_DIR pointed at the fixture, loading must succeed purely
    # from disk — no IPFS round-trip.
    with patch.dict(os.environ, {"CACHE_DIR": str(tmp_path)}):
        from artdag.nodes.effect import _load_cached_effect

        assert _load_cached_effect(effect_cid) is not None
|
||||
|
||||
def test_effect_fetch_uses_correct_endpoint(self, tmp_path: Path) -> None:
|
||||
"""When fetching from IPFS, should use API endpoint."""
|
||||
effects_dir = tmp_path / "_effects"
|
||||
effects_dir.mkdir(parents=True)
|
||||
effect_cid = "QmNonExistentEffect"
|
||||
|
||||
with patch.dict(os.environ, {
|
||||
"CACHE_DIR": str(tmp_path),
|
||||
"IPFS_API": "/dns/ipfs/tcp/5001"
|
||||
}):
|
||||
with patch('requests.post') as mock_post:
|
||||
# Set up mock to return effect source
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.content = b'def process_frame(f, p, s): return f, s'
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
from artdag.nodes.effect import _load_cached_effect
|
||||
|
||||
# Try to load - should attempt IPFS fetch
|
||||
_load_cached_effect(effect_cid)
|
||||
|
||||
# After fix, this should use the API endpoint
|
||||
# Check if requests.post was called (API style)
|
||||
# or requests.get was called (gateway style)
|
||||
# The fix should make it use POST to /api/v0/cat
|
||||
Reference in New Issue
Block a user