Import L1 (celery) as l1/
This commit is contained in:
1
l1/tests/__init__.py
Normal file
1
l1/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for art-celery
|
||||
93
l1/tests/conftest.py
Normal file
93
l1/tests/conftest.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Pytest fixtures for art-celery tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
@pytest.fixture
def sample_compiled_nodes() -> List[Dict[str, Any]]:
    """Sample nodes as produced by the S-expression compiler."""

    def _node(
        node_id: str,
        node_type: str,
        config: Dict[str, Any],
        inputs: List[str],
        name=None,
    ) -> Dict[str, Any]:
        # Mirror the exact dict shape the compiler emits for every node.
        return {
            "id": node_id,
            "type": node_type,
            "config": config,
            "inputs": inputs,
            "name": name,
        }

    return [
        _node("source_1", "SOURCE", {"asset": "cat"}, []),
        _node(
            "source_2",
            "SOURCE",
            {
                "input": True,
                "name": "Second Video",
                "description": "A user-provided video",
            },
            [],
            name="second-video",
        ),
        _node("effect_1", "EFFECT", {"effect": "identity"}, ["source_1"]),
        _node("effect_2", "EFFECT", {"effect": "invert", "intensity": 1.0}, ["source_2"]),
        _node("sequence_1", "SEQUENCE", {}, ["effect_1", "effect_2"]),
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_registry() -> Dict[str, Dict[str, Any]]:
    """Sample registry with assets and effects."""
    assets: Dict[str, Any] = {
        "cat": {
            "cid": "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ",
            "url": "https://example.com/cat.jpg",
        },
    }
    effects: Dict[str, Any] = {
        "identity": {"cid": "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko"},
        "invert": {"cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"},
    }
    return {"assets": assets, "effects": effects}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_recipe(
    sample_compiled_nodes: List[Dict[str, Any]],
    sample_registry: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Sample compiled recipe assembled from the node and registry fixtures."""
    recipe: Dict[str, Any] = {
        "name": "test-recipe",
        "version": "1.0",
        "description": "A test recipe",
        "owner": "@test@example.com",
        "registry": sample_registry,
        "dag": {
            "nodes": sample_compiled_nodes,
            "output": "sequence_1",
        },
        "recipe_id": "Qmtest123",
    }
    return recipe
|
||||
42
l1/tests/test_auth.py
Normal file
42
l1/tests/test_auth.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Tests for authentication service.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestUserContext:
    """Contract tests for the UserContext dataclass."""

    def test_user_context_accepts_l2_server(self) -> None:
        """
        Regression test: UserContext must accept l2_server field.

        Bug found 2026-01-12: auth_service.py passes l2_server to UserContext
        but the art-common library was pinned to old version without this field.
        """
        from artdag_common.middleware.auth import UserContext

        # Constructing with l2_server must not raise TypeError.
        context = UserContext(
            username="testuser",
            actor_id="@testuser@example.com",
            token="test-token",
            l2_server="https://l2.example.com",
        )

        expected = {
            "username": "testuser",
            "actor_id": "@testuser@example.com",
            "token": "test-token",
            "l2_server": "https://l2.example.com",
        }
        for attr, value in expected.items():
            assert getattr(context, attr) == value

    def test_user_context_l2_server_optional(self) -> None:
        """l2_server should be optional (default None)."""
        from artdag_common.middleware.auth import UserContext

        context = UserContext(
            username="testuser",
            actor_id="@testuser@example.com",
        )

        assert context.l2_server is None
|
||||
397
l1/tests/test_cache_manager.py
Normal file
397
l1/tests/test_cache_manager.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# tests/test_cache_manager.py
|
||||
"""Tests for the L1 cache manager."""
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from cache_manager import (
|
||||
L1CacheManager,
|
||||
L2SharedChecker,
|
||||
CachedFile,
|
||||
file_hash,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir():
    """Yield a Path to a throwaway directory, removed after the test."""
    with tempfile.TemporaryDirectory() as dirname:
        yield Path(dirname)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_l2():
    """Patch the requests module used by cache_manager with a mock L2 server."""
    with patch("cache_manager.requests") as fake_requests:
        # Default response: content unknown to L2 (HTTP 404).
        fake_requests.get.return_value = Mock(status_code=404)
        yield fake_requests
|
||||
|
||||
|
||||
@pytest.fixture
def manager(temp_dir, mock_l2):
    """L1CacheManager wired to a temp cache dir and the mocked L2 server."""
    cache_root = temp_dir / "cache"
    return L1CacheManager(cache_dir=cache_root, l2_server="http://mock-l2:8200")
|
||||
|
||||
|
||||
def create_test_file(path: Path, content: str = "test content") -> Path:
    """Write *content* to *path*, creating parent directories, and return it."""
    parent = path.parent
    parent.mkdir(parents=True, exist_ok=True)
    path.write_text(content)
    return path
|
||||
|
||||
|
||||
class TestFileHash:
    """Behavioural tests for the file_hash helper."""

    def test_consistent_hash(self, temp_dir):
        """Identical bytes hash identically, regardless of filename."""
        first = create_test_file(temp_dir / "f1.txt", "hello")
        second = create_test_file(temp_dir / "f2.txt", "hello")
        assert file_hash(first) == file_hash(second)

    def test_different_content_different_hash(self, temp_dir):
        """Distinct bytes must produce distinct digests."""
        first = create_test_file(temp_dir / "f1.txt", "hello")
        second = create_test_file(temp_dir / "f2.txt", "world")
        assert file_hash(first) != file_hash(second)

    def test_sha3_256_length(self, temp_dir):
        """Digest is hex-encoded SHA3-256, i.e. exactly 64 characters."""
        target = create_test_file(temp_dir / "f.txt", "test")
        assert len(file_hash(target)) == 64
|
||||
|
||||
|
||||
class TestL2SharedChecker:
    """Behaviour of the L2 shared-status checker."""

    def test_not_shared_returns_false(self, mock_l2):
        """A 404 from L2 means the content is not shared."""
        subject = L2SharedChecker("http://mock:8200")
        mock_l2.get.return_value = Mock(status_code=404)
        assert subject.is_shared("abc123") is False

    def test_shared_returns_true(self, mock_l2):
        """A 200 from L2 means the content is published."""
        subject = L2SharedChecker("http://mock:8200")
        mock_l2.get.return_value = Mock(status_code=200)
        assert subject.is_shared("abc123") is True

    def test_caches_result(self, mock_l2):
        """Repeated queries for one hash hit the API only once."""
        subject = L2SharedChecker("http://mock:8200", cache_ttl=60)
        mock_l2.get.return_value = Mock(status_code=200)
        for _ in range(2):
            subject.is_shared("abc123")
        assert mock_l2.get.call_count == 1

    def test_mark_shared(self, mock_l2):
        """mark_shared seeds the cache so no API call is needed."""
        subject = L2SharedChecker("http://mock:8200")
        subject.mark_shared("abc123")
        assert subject.is_shared("abc123") is True
        assert mock_l2.get.call_count == 0

    def test_invalidate(self, mock_l2):
        """invalidate drops the cached answer for a hash."""
        subject = L2SharedChecker("http://mock:8200")
        mock_l2.get.return_value = Mock(status_code=200)
        subject.is_shared("abc123")
        subject.invalidate("abc123")
        mock_l2.get.return_value = Mock(status_code=404)
        assert subject.is_shared("abc123") is False

    def test_error_returns_true(self, mock_l2):
        """Network failures are treated as 'shared' so nothing gets deleted."""
        subject = L2SharedChecker("http://mock:8200")
        mock_l2.get.side_effect = Exception("Network error")
        # On error, assume the content IS shared to prevent accidental deletion.
        assert subject.is_shared("abc123") is True
|
||||
|
||||
|
||||
class TestL1CacheManagerStorage:
    """Storage-level operations of the L1 cache."""

    def test_put_and_get_by_cid(self, manager, temp_dir):
        """Stored content is retrievable by its content hash."""
        source = create_test_file(temp_dir / "input.txt", "hello world")
        entry, _cid = manager.put(source, node_type="test")
        fetched = manager.get_by_cid(entry.cid)
        assert fetched is not None
        assert fetched.read_text() == "hello world"

    def test_put_with_custom_node_id(self, manager, temp_dir):
        """An explicit node_id is stored and indexable."""
        source = create_test_file(temp_dir / "input.txt", "content")
        entry, _cid = manager.put(source, node_id="custom-node-123", node_type="test")
        assert entry.node_id == "custom-node-123"
        assert manager.get_by_node_id("custom-node-123") is not None

    def test_has_content(self, manager, temp_dir):
        """has_content distinguishes stored hashes from unknown ones."""
        source = create_test_file(temp_dir / "input.txt", "data")
        entry, _cid = manager.put(source, node_type="test")
        assert manager.has_content(entry.cid) is True
        assert manager.has_content("nonexistent") is False

    def test_list_all(self, manager, temp_dir):
        """Every stored file shows up in list_all."""
        for idx, text in enumerate(("one", "two"), start=1):
            manager.put(create_test_file(temp_dir / f"f{idx}.txt", text), node_type="test")
        assert len(manager.list_all()) == 2

    def test_deduplication(self, manager, temp_dir):
        """Identical content collapses to a single cache entry."""
        first = create_test_file(temp_dir / "f1.txt", "identical")
        second = create_test_file(temp_dir / "f2.txt", "identical")
        entry_a, _ = manager.put(first, node_type="test")
        entry_b, _ = manager.put(second, node_type="test")
        assert entry_a.cid == entry_b.cid
        assert len(manager.list_all()) == 1
|
||||
|
||||
|
||||
class TestL1CacheManagerActivities:
    """Activity (provenance) tracking."""

    def test_record_simple_activity(self, manager, temp_dir):
        """Recording an activity captures inputs, output and run id."""
        in_entry, _ = manager.put(
            create_test_file(temp_dir / "input.txt", "input"), node_type="source"
        )
        out_entry, _ = manager.put(
            create_test_file(temp_dir / "output.txt", "output"), node_type="effect"
        )

        activity = manager.record_simple_activity(
            input_hashes=[in_entry.cid],
            output_cid=out_entry.cid,
            run_id="run-001",
        )

        assert activity.activity_id == "run-001"
        assert in_entry.cid in activity.input_ids
        assert activity.output_id == out_entry.cid

    def test_list_activities(self, manager, temp_dir):
        """Every recorded activity appears in list_activities."""
        for i in range(3):
            in_entry, _ = manager.put(
                create_test_file(temp_dir / f"in{i}.txt", f"input{i}"), node_type="source"
            )
            out_entry, _ = manager.put(
                create_test_file(temp_dir / f"out{i}.txt", f"output{i}"), node_type="effect"
            )
            manager.record_simple_activity([in_entry.cid], out_entry.cid)

        assert len(manager.list_activities()) == 3

    def test_find_activities_by_inputs(self, manager, temp_dir):
        """Lookup by input hash returns every matching activity."""
        shared_entry, _ = manager.put(
            create_test_file(temp_dir / "shared_input.txt", "shared"), node_type="source"
        )

        # Two activities recorded against the same input.
        for run_id, fname, text in (("run1", "out1.txt", "output1"), ("run2", "out2.txt", "output2")):
            out_entry, _ = manager.put(create_test_file(temp_dir / fname, text), node_type="effect")
            manager.record_simple_activity([shared_entry.cid], out_entry.cid, run_id)

        assert len(manager.find_activities_by_inputs([shared_entry.cid])) == 2
|
||||
|
||||
def test_find_activities_by_inputs(self, manager, temp_dir):
|
||||
"""Can find activities with same inputs."""
|
||||
input_file = create_test_file(temp_dir / "shared_input.txt", "shared")
|
||||
input_cached, _ = manager.put(input_file, node_type="source")
|
||||
|
||||
# Two activities with same input
|
||||
out1 = create_test_file(temp_dir / "out1.txt", "output1")
|
||||
out2 = create_test_file(temp_dir / "out2.txt", "output2")
|
||||
out1_c, _ = manager.put(out1, node_type="effect")
|
||||
out2_c, _ = manager.put(out2, node_type="effect")
|
||||
|
||||
manager.record_simple_activity([input_cached.cid], out1_c.cid, "run1")
|
||||
manager.record_simple_activity([input_cached.cid], out2_c.cid, "run2")
|
||||
|
||||
found = manager.find_activities_by_inputs([input_cached.cid])
|
||||
assert len(found) == 2
|
||||
|
||||
|
||||
class TestL1CacheManagerDeletionRules:
    """Enforcement of the cache deletion rules."""

    @staticmethod
    def _record_pair(manager, temp_dir):
        """Store an input/output pair and link them via one activity."""
        in_entry, _ = manager.put(
            create_test_file(temp_dir / "input.txt", "input"), node_type="source"
        )
        out_entry, _ = manager.put(
            create_test_file(temp_dir / "output.txt", "output"), node_type="effect"
        )
        manager.record_simple_activity([in_entry.cid], out_entry.cid)
        return in_entry, out_entry

    def test_can_delete_orphaned_item(self, manager, temp_dir):
        """Items referenced by nothing are deletable."""
        entry, _ = manager.put(create_test_file(temp_dir / "orphan.txt", "orphan"), node_type="test")
        allowed, _reason = manager.can_delete(entry.cid)
        assert allowed is True

    def test_cannot_delete_activity_input(self, manager, temp_dir):
        """An activity's input is protected from deletion."""
        in_entry, _out_entry = self._record_pair(manager, temp_dir)
        allowed, reason = manager.can_delete(in_entry.cid)
        assert allowed is False
        assert "input" in reason.lower()

    def test_cannot_delete_activity_output(self, manager, temp_dir):
        """An activity's output is protected from deletion."""
        _in_entry, out_entry = self._record_pair(manager, temp_dir)
        allowed, reason = manager.can_delete(out_entry.cid)
        assert allowed is False
        assert "output" in reason.lower()

    def test_cannot_delete_pinned_item(self, manager, temp_dir):
        """Pinned (published) items are protected from deletion."""
        entry, _ = manager.put(create_test_file(temp_dir / "shared.txt", "shared"), node_type="test")
        manager.pin(entry.cid, reason="published")
        allowed, reason = manager.can_delete(entry.cid)
        assert allowed is False
        assert "pinned" in reason

    def test_delete_orphaned_item(self, manager, temp_dir):
        """Deleting an orphan succeeds and removes the content."""
        entry, _ = manager.put(create_test_file(temp_dir / "orphan.txt", "orphan"), node_type="test")
        ok, _msg = manager.delete_by_cid(entry.cid)
        assert ok is True
        assert manager.has_content(entry.cid) is False

    def test_delete_protected_item_fails(self, manager, temp_dir):
        """Deleting a protected item fails and keeps the content."""
        in_entry, _out_entry = self._record_pair(manager, temp_dir)
        ok, _msg = manager.delete_by_cid(in_entry.cid)
        assert ok is False
        assert manager.has_content(in_entry.cid) is True
|
||||
|
||||
|
||||
class TestL1CacheManagerActivityDiscard:
    """Discarding recorded activities."""

    @staticmethod
    def _record_run(manager, temp_dir, run_id="run-001"):
        """Record one activity and return its (input, output) entries."""
        in_entry, _ = manager.put(
            create_test_file(temp_dir / "input.txt", "input"), node_type="source"
        )
        out_entry, _ = manager.put(
            create_test_file(temp_dir / "output.txt", "output"), node_type="effect"
        )
        manager.record_simple_activity([in_entry.cid], out_entry.cid, run_id)
        return in_entry, out_entry

    def test_can_discard_unshared_activity(self, manager, temp_dir):
        """Activities whose items are all unshared may be discarded."""
        self._record_run(manager, temp_dir)
        allowed, _reason = manager.can_discard_activity("run-001")
        assert allowed is True

    def test_cannot_discard_activity_with_pinned_output(self, manager, temp_dir):
        """A pinned (published) output blocks discarding its activity."""
        _in_entry, out_entry = self._record_run(manager, temp_dir)
        manager.pin(out_entry.cid, reason="published")
        allowed, reason = manager.can_discard_activity("run-001")
        assert allowed is False
        assert "pinned" in reason

    def test_discard_activity_cleans_up(self, manager, temp_dir):
        """Discarding removes the activity record itself."""
        self._record_run(manager, temp_dir)
        ok, _msg = manager.discard_activity("run-001")
        assert ok is True
        assert manager.get_activity("run-001") is None
|
||||
|
||||
|
||||
class TestL1CacheManagerStats:
    """Cache statistics reporting."""

    def test_get_stats(self, manager, temp_dir):
        """get_stats reports entry count, byte total, and an activities section."""
        for idx in (1, 2):
            manager.put(
                create_test_file(temp_dir / f"f{idx}.txt", f"content{idx}"),
                node_type="test",
            )

        stats = manager.get_stats()

        assert stats["total_entries"] == 2
        assert stats["total_size_bytes"] > 0
        assert "activities" in stats
|
||||
492
l1/tests/test_dag_transform.py
Normal file
492
l1/tests/test_dag_transform.py
Normal file
@@ -0,0 +1,492 @@
|
||||
"""
|
||||
Tests for DAG transformation and input binding.
|
||||
|
||||
These tests verify the critical path that was causing bugs:
|
||||
- Node transformation from compiled format to artdag format
|
||||
- Asset/effect CID resolution from registry
|
||||
- Variable input name mapping
|
||||
- Input binding
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import pytest
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
# Standalone implementations of the functions for testing
|
||||
# This avoids importing the full app which requires external dependencies
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_variable_input(config: Dict[str, Any]) -> bool:
    """Return True when a SOURCE node's config marks it as a variable input."""
    flag = config.get("input")
    return bool(flag)
|
||||
|
||||
|
||||
def transform_node(
    node: Dict[str, Any],
    assets: Dict[str, Dict[str, Any]],
    effects: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Transform a compiled node into the artdag execution format.

    SOURCE nodes referencing a registry asset and EFFECT nodes referencing
    a registry effect get their ``cid`` resolved into the config. Unknown
    names are tolerated: the node simply receives no ``cid`` key.
    """
    config = dict(node.get("config", {}))
    raw_type = node.get("type")

    # Resolve registry CIDs; the executor reads config["cid"].
    if raw_type == "SOURCE" and "asset" in config:
        asset_entry = assets.get(config["asset"])
        if asset_entry is not None:
            config["cid"] = asset_entry.get("cid")
    elif raw_type == "EFFECT" and "effect" in config:
        effect_entry = effects.get(config["effect"])
        if effect_entry is not None:
            config["cid"] = effect_entry.get("cid")

    return {
        "node_id": node.get("id", ""),
        "node_type": node.get("type", "EFFECT"),
        "config": config,
        "inputs": node.get("inputs", []),
        "name": node.get("name"),
    }
|
||||
|
||||
|
||||
def build_input_name_mapping(nodes: Dict[str, Dict[str, Any]]) -> Dict[str, str]:
    """Map every alias of a variable-input SOURCE node to its node ID.

    Aliases registered per node: the node ID itself, the human-readable
    config name (verbatim, snake_case, kebab-case), and the def-binding
    node name (verbatim and with dashes turned into underscores).
    """
    mapping: Dict[str, str] = {}

    for node_id, node in nodes.items():
        config = node.get("config", {})
        # Only SOURCE nodes flagged as variable inputs participate.
        if node.get("node_type") != "SOURCE" or not config.get("input"):
            continue

        aliases = [node_id]

        display_name = config.get("name")
        if display_name:
            lowered = display_name.lower()
            aliases += [display_name, lowered.replace(" ", "_"), lowered.replace(" ", "-")]

        binding_name = node.get("name")
        if binding_name:
            aliases += [binding_name, binding_name.replace("-", "_")]

        for alias in aliases:
            mapping[alias] = node_id

    return mapping
|
||||
|
||||
|
||||
def bind_inputs(
    nodes: Dict[str, Dict[str, Any]],
    input_name_to_node: Dict[str, str],
    user_inputs: Dict[str, str],
) -> List[str]:
    """Bind user-provided input CIDs onto SOURCE nodes.

    Each *user_inputs* entry is matched first against a node ID, then
    against the alias mapping; a matched node gets ``config["cid"]`` set.

    Args:
        nodes: Execution-format nodes keyed by node ID (mutated in place).
        input_name_to_node: Alias -> node ID mapping produced by
            build_input_name_mapping.
        user_inputs: Mapping of input name or alias to content CID.

    Returns:
        Human-readable warnings for inputs that could not be bound.
    """
    warnings: List[str] = []

    for input_name, cid in user_inputs.items():
        # Direct hit: the caller used the node ID itself.
        node = nodes.get(input_name)
        if node is not None and node.get("node_type") == "SOURCE":
            node["config"]["cid"] = cid
            continue

        # Alias hit: resolve through the name mapping. Guard against a stale
        # mapping entry whose node ID is no longer present in `nodes` —
        # previously this raised KeyError instead of producing a warning.
        node_id = input_name_to_node.get(input_name)
        if node_id is not None and node_id in nodes:
            nodes[node_id]["config"]["cid"] = cid
            continue

        warnings.append(f"Input '{input_name}' not found in recipe")

    return warnings
|
||||
|
||||
|
||||
def prepare_dag_for_execution(
    recipe: Dict[str, Any],
    user_inputs: Dict[str, str],
) -> Tuple[str, List[str]]:
    """Prepare a recipe DAG for execution.

    Deep-copies the recipe DAG, converts list-form nodes into the
    execution format (resolving registry CIDs), binds user-provided input
    CIDs, and normalises the output/metadata keys.

    Returns:
        A tuple of (serialised DAG JSON, binding warnings).

    Raises:
        ValueError: If the recipe carries no usable DAG definition.
    """
    source_dag = recipe.get("dag")
    if not isinstance(source_dag, dict) or not source_dag:
        raise ValueError("Recipe has no DAG definition")

    # JSON round-trip yields an independent deep copy of the DAG.
    dag = json.loads(json.dumps(source_dag))

    registry = recipe.get("registry") or {}
    asset_table = registry.get("assets", {})
    effect_table = registry.get("effects", {})

    nodes = dag.get("nodes", {})
    if isinstance(nodes, list):
        # Compiler output is a node list; convert it to an ID-keyed dict.
        nodes = {
            entry["id"]: transform_node(entry, asset_table, effect_table)
            for entry in nodes
            if entry.get("id")
        }
        dag["nodes"] = nodes

    alias_map = build_input_name_mapping(nodes)
    warnings = bind_inputs(nodes, alias_map, user_inputs)

    # Normalise to the executor's expected schema.
    if "output" in dag:
        dag["output_id"] = dag.pop("output")
    dag.setdefault("metadata", {})

    return json.dumps(dag), warnings
|
||||
|
||||
|
||||
class TestTransformNode:
    """Tests for transform_node."""

    def test_source_node_with_asset_resolves_cid(
        self,
        sample_registry: Dict[str, Dict[str, Any]],
    ) -> None:
        """A SOURCE node referencing a registry asset gains that asset's CID."""
        compiled = {"id": "source_1", "type": "SOURCE", "config": {"asset": "cat"}, "inputs": []}

        transformed = transform_node(
            compiled, sample_registry["assets"], sample_registry["effects"]
        )

        assert transformed["node_id"] == "source_1"
        assert transformed["node_type"] == "SOURCE"
        assert transformed["config"]["cid"] == "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ"

    def test_effect_node_resolves_cid(
        self,
        sample_registry: Dict[str, Dict[str, Any]],
    ) -> None:
        """An EFFECT node gains its CID while keeping the rest of its config."""
        compiled = {
            "id": "effect_1",
            "type": "EFFECT",
            "config": {"effect": "invert", "intensity": 1.0},
            "inputs": ["source_1"],
        }

        transformed = transform_node(
            compiled, sample_registry["assets"], sample_registry["effects"]
        )

        assert transformed["node_id"] == "effect_1"
        assert transformed["node_type"] == "EFFECT"
        assert transformed["config"]["effect"] == "invert"
        assert transformed["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"
        assert transformed["config"]["intensity"] == 1.0

    def test_variable_input_node_preserves_config(
        self,
        sample_registry: Dict[str, Dict[str, Any]],
    ) -> None:
        """A variable-input SOURCE keeps its config and gets no CID yet."""
        compiled = {
            "id": "source_2",
            "type": "SOURCE",
            "config": {
                "input": True,
                "name": "Second Video",
                "description": "User-provided video",
            },
            "inputs": [],
            "name": "second-video",
        }

        transformed = transform_node(
            compiled, sample_registry["assets"], sample_registry["effects"]
        )

        assert transformed["config"]["input"] is True
        assert transformed["config"]["name"] == "Second Video"
        assert transformed["name"] == "second-video"
        # The CID is only bound later, at execution time.
        assert "cid" not in transformed["config"]

    def test_unknown_asset_not_resolved(
        self,
        sample_registry: Dict[str, Dict[str, Any]],
    ) -> None:
        """An unknown asset name is tolerated; the node simply has no CID."""
        compiled = {
            "id": "source_1",
            "type": "SOURCE",
            "config": {"asset": "unknown_asset"},
            "inputs": [],
        }

        transformed = transform_node(
            compiled, sample_registry["assets"], sample_registry["effects"]
        )

        assert "cid" not in transformed["config"]
|
||||
|
||||
|
||||
class TestBuildInputNameMapping:
    """Tests for build_input_name_mapping."""

    @staticmethod
    def _source_node(config: Dict[str, Any], **extra: Any) -> Dict[str, Any]:
        """Build a minimal SOURCE node in execution format."""
        node: Dict[str, Any] = {
            "node_id": "source_2",
            "node_type": "SOURCE",
            "config": config,
            "inputs": [],
        }
        node.update(extra)
        return node

    def test_maps_by_node_id(self) -> None:
        """The node's own ID is always an alias."""
        nodes = {"source_2": self._source_node({"input": True, "name": "Test"})}
        assert build_input_name_mapping(nodes)["source_2"] == "source_2"

    def test_maps_by_config_name(self) -> None:
        """The display name maps verbatim, in snake_case, and in kebab-case."""
        nodes = {"source_2": self._source_node({"input": True, "name": "Second Video"})}
        mapping = build_input_name_mapping(nodes)
        for alias in ("Second Video", "second_video", "second-video"):
            assert mapping[alias] == "source_2"

    def test_maps_by_def_name(self) -> None:
        """The def-binding name maps verbatim and with dashes as underscores."""
        nodes = {
            "source_2": self._source_node(
                {"input": True, "name": "Second Video"}, name="inverted-video"
            )
        }
        mapping = build_input_name_mapping(nodes)
        for alias in ("inverted-video", "inverted_video"):
            assert mapping[alias] == "source_2"

    def test_ignores_non_source_nodes(self) -> None:
        """EFFECT nodes never contribute aliases."""
        nodes = {
            "effect_1": {
                "node_id": "effect_1",
                "node_type": "EFFECT",
                "config": {"effect": "invert"},
                "inputs": [],
            }
        }
        assert "effect_1" not in build_input_name_mapping(nodes)

    def test_ignores_fixed_sources(self) -> None:
        """SOURCE nodes without the 'input' flag are fixed, not variable."""
        nodes = {
            "source_1": {
                "node_id": "source_1",
                "node_type": "SOURCE",
                "config": {"asset": "cat", "cid": "Qm123"},
                "inputs": [],
            }
        }
        assert "source_1" not in build_input_name_mapping(nodes)
|
||||
|
||||
|
||||
class TestBindInputs:
    """Tests for bind_inputs."""

    @staticmethod
    def _variable_source(display_name: str) -> Dict[str, Any]:
        """Minimal variable-input SOURCE node in execution format."""
        return {
            "node_id": "source_2",
            "node_type": "SOURCE",
            "config": {"input": True, "name": display_name},
            "inputs": [],
        }

    def test_binds_by_direct_node_id(self) -> None:
        """Using the node ID itself binds the CID."""
        nodes = {"source_2": self._variable_source("Test")}

        issues = bind_inputs(
            nodes, {"source_2": "source_2"}, {"source_2": "QmUserInput123"}
        )

        assert nodes["source_2"]["config"]["cid"] == "QmUserInput123"
        assert len(issues) == 0

    def test_binds_by_name_lookup(self) -> None:
        """Using a registered alias binds the CID."""
        nodes = {"source_2": self._variable_source("Second Video")}
        mapping = {
            "source_2": "source_2",
            "Second Video": "source_2",
            "second-video": "source_2",
        }

        issues = bind_inputs(nodes, mapping, {"second-video": "QmUserInput123"})

        assert nodes["source_2"]["config"]["cid"] == "QmUserInput123"
        assert len(issues) == 0

    def test_warns_on_unknown_input(self) -> None:
        """An unrecognised input name yields a warning rather than a crash."""
        nodes = {"source_2": self._variable_source("Test")}
        mapping = {"source_2": "source_2", "Test": "source_2"}

        issues = bind_inputs(nodes, mapping, {"unknown-input": "QmUserInput123"})

        assert len(issues) == 1
        assert "unknown-input" in issues[0]
|
||||
|
||||
|
||||
class TestRegressions:
    """Tests pinning specific bugs found in production."""

    def test_effect_cid_key_not_effect_hash(
        self,
        sample_registry: Dict[str, Dict[str, Any]],
    ) -> None:
        """
        Regression test: Effect CID must use 'cid' key, not 'effect_hash'.

        Bug found 2026-01-12: The transform_node function was setting
        config["effect_hash"] but the executor looks for config["cid"].
        This caused "Unknown effect: invert" errors.
        """
        compiled = {
            "id": "effect_1",
            "type": "EFFECT",
            "config": {"effect": "invert"},
            "inputs": ["source_1"],
        }

        transformed = transform_node(
            compiled, sample_registry["assets"], sample_registry["effects"]
        )

        # The executor reads config.get("cid") — nothing else.
        assert "cid" in transformed["config"], "Effect CID must be stored as 'cid'"
        assert "effect_hash" not in transformed["config"], "Must not use 'effect_hash' key"
        assert transformed["config"]["cid"] == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"

    def test_source_cid_binding_persists(
        self,
        sample_recipe: Dict[str, Any],
    ) -> None:
        """
        Regression test: Bound CIDs must appear in final DAG JSON.

        Bug found 2026-01-12: Variable input CIDs were being bound but
        not appearing in the serialized DAG sent to the executor.
        """
        serialized, _ = prepare_dag_for_execution(
            sample_recipe, {"second-video": "QmTestUserInput123"}
        )
        dag = json.loads(serialized)

        # The bound CID must survive serialisation.
        assert dag["nodes"]["source_2"]["config"]["cid"] == "QmTestUserInput123"
|
||||
|
||||
|
||||
class TestPrepareDagForExecution:
    """Integration tests for prepare_dag_for_execution."""

    def test_full_pipeline(self, sample_recipe: Dict[str, Any]) -> None:
        """End-to-end check of the DAG preparation pipeline."""
        user_inputs = {
            "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS"
        }

        dag_json, warnings = prepare_dag_for_execution(sample_recipe, user_inputs)
        dag = json.loads(dag_json)

        # Top-level structure.
        assert "nodes" in dag
        assert "output_id" in dag
        assert dag["output_id"] == "sequence_1"

        # Each node that should carry a CID, and the CID it should carry:
        # fixed source, bound variable input, and both effects.
        expected_cids = {
            "source_1": "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ",
            "source_2": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS",
            "effect_1": "QmcWhw6wbHr1GDmorM2KDz8S3yCGTfjuyPR6y8khS2tvko",
            "effect_2": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J",
        }
        for node_id, cid in expected_cids.items():
            assert dag["nodes"][node_id]["config"]["cid"] == cid

        # Effect parameters survive the pipeline untouched.
        assert dag["nodes"]["effect_2"]["config"]["intensity"] == 1.0

        # Nothing to warn about.
        assert len(warnings) == 0

    def test_missing_input_produces_warning(
        self,
        sample_recipe: Dict[str, Any],
    ) -> None:
        """Missing inputs should produce warnings but not fail."""
        # No user inputs provided at all.
        dag_json, warnings = prepare_dag_for_execution(sample_recipe, {})

        # Still produces valid JSON with the expected shape.
        dag = json.loads(dag_json)
        assert "nodes" in dag

        # The unbound variable input carries no CID.
        assert "cid" not in dag["nodes"]["source_2"]["config"]

    def test_raises_on_missing_dag(self) -> None:
        """Should raise ValueError if recipe has no DAG."""
        broken_recipe = {"name": "broken", "registry": {}}

        with pytest.raises(ValueError, match="no DAG"):
            prepare_dag_for_execution(broken_recipe, {})
|
||||
327
l1/tests/test_effect_loading.py
Normal file
327
l1/tests/test_effect_loading.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Tests for effect loading from cache and IPFS.
|
||||
|
||||
These tests verify that:
|
||||
- Effects can be loaded from the local cache directory
|
||||
- IPFS gateway configuration is correct for Docker environments
|
||||
- The effect executor correctly resolves CIDs from config
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# Minimal effect loading implementation for testing
|
||||
# This mirrors the logic in artdag/nodes/effect.py
|
||||
|
||||
|
||||
def get_effects_cache_dir_impl(env_vars: Dict[str, str]) -> Optional[Path]:
    """Locate the effects cache directory.

    Environment variables take precedence (CACHE_DIR first, then
    ARTDAG_CACHE_DIR); after that a couple of conventional default
    locations are probed.  The first base whose ``_effects`` subdirectory
    exists wins; ``None`` means no candidate was found.
    """
    candidates = []
    for env_var in ("CACHE_DIR", "ARTDAG_CACHE_DIR"):
        value = env_vars.get(env_var)
        if value:
            candidates.append(Path(value))
    # Default locations, tried only after the env-configured ones.
    candidates.append(Path.home() / ".artdag" / "cache")
    candidates.append(Path("/var/cache/artdag"))

    for base in candidates:
        effects_dir = base / "_effects"
        if effects_dir.exists():
            return effects_dir
    return None
|
||||
|
||||
|
||||
def effect_path_for_cid(effects_dir: Path, effect_cid: str) -> Path:
    """Return the canonical location of an effect's source file.

    Layout convention: ``<effects_dir>/<cid>/effect.py``.
    """
    return Path(effects_dir, effect_cid, "effect.py")
|
||||
|
||||
|
||||
class TestEffectCacheDirectory:
    """Tests for effect cache directory resolution."""

    def test_cache_dir_from_env(self, tmp_path: Path) -> None:
        """CACHE_DIR env var should determine effects directory."""
        expected = tmp_path / "_effects"
        expected.mkdir(parents=True)

        assert get_effects_cache_dir_impl({"CACHE_DIR": str(tmp_path)}) == expected

    def test_artdag_cache_dir_fallback(self, tmp_path: Path) -> None:
        """ARTDAG_CACHE_DIR should work as fallback."""
        expected = tmp_path / "_effects"
        expected.mkdir(parents=True)

        assert (
            get_effects_cache_dir_impl({"ARTDAG_CACHE_DIR": str(tmp_path)})
            == expected
        )

    def test_no_env_returns_none_if_no_default_exists(self) -> None:
        """Should return None if no cache directory exists."""
        result = get_effects_cache_dir_impl({})

        # Without env vars the outcome depends on whether the default
        # locations exist on this machine; either way, a non-None result
        # must point at a real directory.
        if result is not None:
            assert result.exists()
|
||||
|
||||
|
||||
class TestEffectPathResolution:
    """Tests for effect path resolution."""

    def test_effect_path_structure(self, tmp_path: Path) -> None:
        """Effect should be at _effects/{cid}/effect.py."""
        base = tmp_path / "_effects"

        assert effect_path_for_cid(base, "QmTestEffect123") == (
            base / "QmTestEffect123" / "effect.py"
        )

    def test_effect_file_exists_after_upload(self, tmp_path: Path) -> None:
        """After upload, effect.py should exist in the right location."""
        base = tmp_path / "_effects"
        cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"

        # Simulate effect upload (as done by app/routers/effects.py).
        target_dir = base / cid
        target_dir.mkdir(parents=True)
        effect_source = '''"""
@effect invert
@version 1.0.0
"""

def process_frame(frame, params, state):
    return 255 - frame, state
'''
        (target_dir / "effect.py").write_text(effect_source)

        # The resolver must point at the freshly written file.
        resolved = effect_path_for_cid(base, cid)
        assert resolved.exists()
        assert "process_frame" in resolved.read_text()
|
||||
|
||||
|
||||
class TestIPFSAPIConfiguration:
    """Tests for IPFS API configuration (consistent across codebase)."""

    def test_ipfs_api_multiaddr_conversion(self) -> None:
        """
        IPFS_API multiaddr should convert to correct URL.

        Both ipfs_client.py and artdag/nodes/effect.py now use IPFS_API
        with multiaddr format for consistency.
        """
        import re

        def multiaddr_to_url(multiaddr: str) -> str:
            """Convert multiaddr to URL (same logic as ipfs_client.py)."""
            # /dns, /dns4, /dns6 hostnames first, then raw IPv4 addresses.
            for pattern in (r"/dns[46]?/([^/]+)/tcp/(\d+)", r"/ip4/([^/]+)/tcp/(\d+)"):
                match = re.match(pattern, multiaddr)
                if match:
                    return f"http://{match.group(1)}:{match.group(2)}"
            return "http://127.0.0.1:5001"

        # Docker config
        assert multiaddr_to_url("/dns/ipfs/tcp/5001") == "http://ipfs:5001"

        # Local dev config
        assert multiaddr_to_url("/ip4/127.0.0.1/tcp/5001") == "http://127.0.0.1:5001"

    def test_all_ipfs_access_uses_api_not_gateway(self) -> None:
        """
        All IPFS access should use IPFS_API (port 5001), not IPFS_GATEWAY (port 8080).

        Fixed 2026-01-12: artdag/nodes/effect.py was using a separate IPFS_GATEWAY
        variable. Now it uses IPFS_API like ipfs_client.py for consistency.
        """
        # The API endpoint that both modules use
        api_endpoint = "/api/v0/cat"

        # Correct: the HTTP API path.
        assert "api/v0" in api_endpoint

        # A gateway URL would contain /ipfs/{cid} - we don't use this anymore.
        assert "/ipfs/" not in api_endpoint
|
||||
|
||||
|
||||
class TestEffectExecutorConfigResolution:
    """Tests for how the effect executor resolves CID from config."""

    @staticmethod
    def _resolve(config):
        # The exact lookup the executor performs (artdag/nodes/effect.py:258).
        return config.get("cid") or config.get("hash")

    def test_executor_should_use_cid_key(self) -> None:
        """
        Effect executor must look for 'cid' key in config.

        The transform_node function sets config["cid"] for effects.
        The executor must read from the same key.
        """
        config = {
            "effect": "invert",
            "cid": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J",
            "intensity": 1.0,
        }

        assert self._resolve(config) == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"

    def test_executor_should_not_use_effect_hash(self) -> None:
        """
        Regression test: 'effect_hash' is not a valid config key.

        Bug found 2026-01-12: transform_node was using config["effect_hash"]
        but executor only checks config["cid"] or config["hash"].
        """
        config = {
            "effect": "invert",
            "effect_hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J",
        }

        # The buggy key is simply ignored by the executor's lookup.
        assert self._resolve(config) is None, "effect_hash should NOT be recognized"

    def test_hash_key_is_legacy_fallback(self) -> None:
        """'hash' key should work as legacy fallback for 'cid'."""
        config = {
            "effect": "invert",
            "hash": "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J",
        }

        assert self._resolve(config) == "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"
|
||||
|
||||
|
||||
class TestEffectLoadingIntegration:
|
||||
"""Integration tests for complete effect loading path."""
|
||||
|
||||
def test_effect_loads_from_cache_when_present(self, tmp_path: Path) -> None:
|
||||
"""Effect should load from cache without hitting IPFS."""
|
||||
effects_dir = tmp_path / "_effects"
|
||||
effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"
|
||||
|
||||
# Create effect file in cache
|
||||
effect_dir = effects_dir / effect_cid
|
||||
effect_dir.mkdir(parents=True)
|
||||
(effect_dir / "effect.py").write_text('''
|
||||
def process_frame(frame, params, state):
|
||||
"""Invert colors."""
|
||||
return 255 - frame, state
|
||||
''')
|
||||
|
||||
# Verify the effect can be found
|
||||
effect_path = effect_path_for_cid(effects_dir, effect_cid)
|
||||
assert effect_path.exists()
|
||||
|
||||
# Load and verify it has the expected function
|
||||
import importlib.util
|
||||
spec = importlib.util.spec_from_file_location("test_effect", effect_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
assert hasattr(module, "process_frame")
|
||||
|
||||
def test_effect_fetch_uses_ipfs_api(self, tmp_path: Path) -> None:
|
||||
"""Effect fetch should use IPFS API endpoint, not gateway."""
|
||||
import re
|
||||
|
||||
def multiaddr_to_url(multiaddr: str) -> str:
|
||||
dns_match = re.match(r"/dns[46]?/([^/]+)/tcp/(\d+)", multiaddr)
|
||||
if dns_match:
|
||||
return f"http://{dns_match.group(1)}:{dns_match.group(2)}"
|
||||
ip4_match = re.match(r"/ip4/([^/]+)/tcp/(\d+)", multiaddr)
|
||||
if ip4_match:
|
||||
return f"http://{ip4_match.group(1)}:{ip4_match.group(2)}"
|
||||
return "http://127.0.0.1:5001"
|
||||
|
||||
# In Docker, IPFS_API=/dns/ipfs/tcp/5001
|
||||
docker_multiaddr = "/dns/ipfs/tcp/5001"
|
||||
base_url = multiaddr_to_url(docker_multiaddr)
|
||||
effect_cid = "QmTestCid123"
|
||||
|
||||
# Should use API endpoint
|
||||
api_url = f"{base_url}/api/v0/cat?arg={effect_cid}"
|
||||
|
||||
assert "ipfs:5001" in api_url
|
||||
assert "/api/v0/cat" in api_url
|
||||
assert "127.0.0.1" not in api_url
|
||||
|
||||
|
||||
class TestSharedVolumeScenario:
|
||||
"""
|
||||
Tests simulating the Docker shared volume scenario.
|
||||
|
||||
In Docker:
|
||||
- l1-server uploads effect to /data/cache/_effects/{cid}/effect.py
|
||||
- l1-worker should find it at the same path via shared volume
|
||||
"""
|
||||
|
||||
def test_effect_visible_on_shared_volume(self, tmp_path: Path) -> None:
|
||||
"""Effect uploaded on server should be visible to worker."""
|
||||
# Simulate shared volume mounted at /data/cache on both containers
|
||||
shared_volume = tmp_path / "data" / "cache"
|
||||
effects_dir = shared_volume / "_effects"
|
||||
|
||||
# Server uploads effect
|
||||
effect_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"
|
||||
effect_upload_dir = effects_dir / effect_cid
|
||||
effect_upload_dir.mkdir(parents=True)
|
||||
(effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s')
|
||||
(effect_upload_dir / "metadata.json").write_text('{"cid": "' + effect_cid + '"}')
|
||||
|
||||
# Worker should find the effect
|
||||
env_vars = {"CACHE_DIR": str(shared_volume)}
|
||||
worker_effects_dir = get_effects_cache_dir_impl(env_vars)
|
||||
|
||||
assert worker_effects_dir is not None
|
||||
assert worker_effects_dir == effects_dir
|
||||
|
||||
worker_effect_path = effect_path_for_cid(worker_effects_dir, effect_cid)
|
||||
assert worker_effect_path.exists()
|
||||
|
||||
def test_effect_cid_matches_registry(self, tmp_path: Path) -> None:
|
||||
"""CID in recipe registry must match the uploaded effect directory name."""
|
||||
shared_volume = tmp_path
|
||||
effects_dir = shared_volume / "_effects"
|
||||
|
||||
# The CID used in the recipe registry
|
||||
registry_cid = "QmPWaW5E5WFrmDjT6w8enqvtJhM8c5jvQu7XN1doHA3Z7J"
|
||||
|
||||
# Upload creates directory with CID as name
|
||||
effect_upload_dir = effects_dir / registry_cid
|
||||
effect_upload_dir.mkdir(parents=True)
|
||||
(effect_upload_dir / "effect.py").write_text('def process_frame(f, p, s): return f, s')
|
||||
|
||||
# Executor receives the same CID from DAG config
|
||||
dag_config_cid = registry_cid # This comes from transform_node
|
||||
|
||||
# These must match for the lookup to work
|
||||
assert dag_config_cid == registry_cid
|
||||
|
||||
# And the path must exist
|
||||
lookup_path = effects_dir / dag_config_cid / "effect.py"
|
||||
assert lookup_path.exists()
|
||||
367
l1/tests/test_effects_web.py
Normal file
367
l1/tests/test_effects_web.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Tests for Effects web UI.
|
||||
|
||||
Tests effect metadata parsing, listing, and templates.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import re
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def parse_effect_metadata_standalone(source: str) -> dict:
    """
    Standalone copy of parse_effect_metadata for testing.

    This avoids import issues with the router module.

    Parses two sections of an effect source file:
    1. an optional PEP 723 ``# /// script`` block for dependencies and the
       required Python version;
    2. the first module docstring for ``@effect``/``@version``/``@author``/
       ``@temporal``/``@description``/``@param`` (with ``@range``/``@default``)
       tags.

    Returns a dict with keys: name, version, author, temporal, description,
    params, dependencies, requires_python (defaults below when absent).
    """
    # Defaults returned when a tag is missing from the source.
    metadata = {
        "name": "",
        "version": "1.0.0",
        "author": "",
        "temporal": False,
        "description": "",
        "params": [],
        "dependencies": [],
        "requires_python": ">=3.10",
    }

    # Parse PEP 723 dependencies
    pep723_match = re.search(r"# /// script\n(.*?)# ///", source, re.DOTALL)
    if pep723_match:
        block = pep723_match.group(1)
        deps_match = re.search(r'# dependencies = \[(.*?)\]', block, re.DOTALL)
        if deps_match:
            # Each dependency is a double-quoted string inside the list.
            metadata["dependencies"] = re.findall(r'"([^"]+)"', deps_match.group(1))
        python_match = re.search(r'# requires-python = "([^"]+)"', block)
        if python_match:
            metadata["requires_python"] = python_match.group(1)

    # Parse docstring @-tags (first triple-quoted string of either style).
    docstring_match = re.search(r'"""(.*?)"""', source, re.DOTALL)
    if not docstring_match:
        docstring_match = re.search(r"'''(.*?)'''", source, re.DOTALL)

    if docstring_match:
        docstring = docstring_match.group(1)
        lines = docstring.split("\n")

        # current_param: the @param being collected; flushed into
        # metadata["params"] when the next @param/@example starts or at end.
        current_param = None
        # desc_lines / in_description: free-text lines gathered after
        # @description until the next tag terminates the collection.
        desc_lines = []
        in_description = False

        for line in lines:
            stripped = line.strip()

            if stripped.startswith("@effect "):
                metadata["name"] = stripped[8:].strip()
                in_description = False

            elif stripped.startswith("@version "):
                metadata["version"] = stripped[9:].strip()

            elif stripped.startswith("@author "):
                metadata["author"] = stripped[8:].strip()

            elif stripped.startswith("@temporal "):
                # Truthy spellings accepted: true / yes / 1 (case-insensitive).
                val = stripped[10:].strip().lower()
                metadata["temporal"] = val in ("true", "yes", "1")

            elif stripped.startswith("@description"):
                # Start (or restart) collecting description lines.
                in_description = True
                desc_lines = []

            elif stripped.startswith("@param "):
                # A new parameter terminates any in-progress description
                # and flushes the previous parameter.
                if in_description:
                    metadata["description"] = " ".join(desc_lines)
                    in_description = False
                if current_param:
                    metadata["params"].append(current_param)
                parts = stripped[7:].split()
                if len(parts) >= 2:
                    # Expected shape: "@param <name> <type>".
                    current_param = {
                        "name": parts[0],
                        "type": parts[1],
                        "description": "",
                    }
                else:
                    # Malformed @param (missing type): discard it.
                    current_param = None

            elif stripped.startswith("@range ") and current_param:
                range_parts = stripped[7:].split()
                if len(range_parts) >= 2:
                    try:
                        current_param["range"] = [float(range_parts[0]), float(range_parts[1])]
                    except ValueError:
                        # Non-numeric bounds: leave "range" unset.
                        pass

            elif stripped.startswith("@default ") and current_param:
                # Defaults are kept as raw strings; callers coerce by "type".
                current_param["default"] = stripped[9:].strip()

            elif stripped.startswith("@example"):
                # Examples terminate both description and parameter collection.
                if in_description:
                    metadata["description"] = " ".join(desc_lines)
                    in_description = False
                if current_param:
                    metadata["params"].append(current_param)
                current_param = None

            elif in_description and stripped:
                desc_lines.append(stripped)

            elif current_param and stripped and not stripped.startswith("@"):
                # A continuation line under a @param becomes its description.
                # NOTE(review): each such line overwrites the previous one, so
                # only the last continuation line survives — presumably
                # intentional for one-line descriptions; confirm upstream.
                current_param["description"] = stripped

        # Flush whatever was still being collected at end of docstring.
        if in_description:
            metadata["description"] = " ".join(desc_lines)

        if current_param:
            metadata["params"].append(current_param)

    return metadata
|
||||
|
||||
|
||||
class TestEffectMetadataParsing:
    """Tests for parse_effect_metadata function."""

    def test_parses_pep723_dependencies(self) -> None:
        """Should extract dependencies from PEP 723 script block."""
        src = '''
# /// script
# requires-python = ">=3.10"
# dependencies = ["numpy", "opencv-python"]
# ///
"""
@effect test_effect
"""
def process_frame(frame, params, state):
    return frame, state
'''
        meta = parse_effect_metadata_standalone(src)

        assert meta["dependencies"] == ["numpy", "opencv-python"]
        assert meta["requires_python"] == ">=3.10"

    def test_parses_effect_name(self) -> None:
        """Should extract effect name from @effect tag."""
        src = '''
"""
@effect brightness
@version 2.0.0
@author @artist@example.com
"""
def process_frame(frame, params, state):
    return frame, state
'''
        meta = parse_effect_metadata_standalone(src)

        assert meta["name"] == "brightness"
        assert meta["version"] == "2.0.0"
        assert meta["author"] == "@artist@example.com"

    def test_parses_parameters(self) -> None:
        """Should extract parameter definitions."""
        src = '''
"""
@effect brightness
@param level float
@range -1.0 1.0
@default 0.0
Brightness adjustment level
"""
def process_frame(frame, params, state):
    return frame, state
'''
        meta = parse_effect_metadata_standalone(src)

        assert len(meta["params"]) == 1
        param = meta["params"][0]
        assert param["name"] == "level"
        assert param["type"] == "float"
        assert param["range"] == [-1.0, 1.0]
        assert param["default"] == "0.0"

    def test_parses_temporal_flag(self) -> None:
        """Should parse temporal flag correctly."""
        temporal_src = '''
"""
@effect motion_blur
@temporal true
"""
def process_frame(frame, params, state):
    return frame, state
'''
        static_src = '''
"""
@effect brightness
@temporal false
"""
def process_frame(frame, params, state):
    return frame, state
'''

        assert parse_effect_metadata_standalone(temporal_src)["temporal"] is True
        assert parse_effect_metadata_standalone(static_src)["temporal"] is False

    def test_handles_missing_metadata(self) -> None:
        """Should return sensible defaults for minimal source."""
        src = '''
def process_frame(frame, params, state):
    return frame, state
'''
        meta = parse_effect_metadata_standalone(src)

        assert meta["name"] == ""
        assert meta["version"] == "1.0.0"
        assert meta["dependencies"] == []
        assert meta["params"] == []

    def test_parses_description(self) -> None:
        """Should extract description text."""
        src = '''
"""
@effect test
@description
This is a multi-line
description of the effect.
@param x float
"""
def process_frame(frame, params, state):
    return frame, state
'''
        meta = parse_effect_metadata_standalone(src)

        assert "multi-line" in meta["description"]
|
||||
|
||||
|
||||
class TestHomePageEffectsCount:
    """Test that effects count is shown on home page."""

    # Fix: derive the app directory from this test file's location instead of
    # the hard-coded absolute path /home/giles/art/art-celery, which only
    # existed on one developer machine.  Assumes tests live in <root>/tests
    # alongside <root>/app — adjust parents[] if the layout differs.
    APP_DIR = Path(__file__).resolve().parents[1] / "app"

    def test_home_template_has_effects_card(self) -> None:
        """Home page should display effects count."""
        content = (self.APP_DIR / "templates" / "home.html").read_text()

        assert 'stats.effects' in content, \
            "Home page should display stats.effects count"
        assert 'href="/effects"' in content, \
            "Home page should link to /effects"

    def test_home_route_provides_effects_count(self) -> None:
        """Home route should provide effects count in stats."""
        content = (self.APP_DIR / "routers" / "home.py").read_text()

        assert 'stats["effects"]' in content, \
            "Home route should populate stats['effects']"
|
||||
|
||||
|
||||
class TestNavigationIncludesEffects:
    """Test that Effects link is in navigation."""

    # Fix: resolve the template relative to this test file instead of the
    # hard-coded absolute path /home/giles/art/art-celery, which breaks on
    # any other checkout or CI machine.  Assumes <root>/tests next to
    # <root>/app — adjust parents[] if the layout differs.
    BASE_TEMPLATE = Path(__file__).resolve().parents[1] / "app" / "templates" / "base.html"

    def test_base_template_has_effects_link(self) -> None:
        """Base template should have Effects navigation link."""
        content = self.BASE_TEMPLATE.read_text()

        assert 'href="/effects"' in content
        assert "Effects" in content
        assert "active_tab == 'effects'" in content

    def test_effects_link_between_recipes_and_media(self) -> None:
        """Effects link should be positioned between Recipes and Media."""
        content = self.BASE_TEMPLATE.read_text()

        recipes_pos = content.find('href="/recipes"')
        effects_pos = content.find('href="/effects"')
        media_pos = content.find('href="/media"')

        assert recipes_pos < effects_pos < media_pos, \
            "Effects link should be between Recipes and Media"
|
||||
|
||||
|
||||
class TestEffectsTemplatesExist:
    """Tests for effects template files."""

    # Fix: locate templates relative to this test file instead of the
    # hard-coded absolute path /home/giles/art/art-celery, which only worked
    # on one machine.  Assumes <root>/tests next to <root>/app.
    TEMPLATES_DIR = Path(__file__).resolve().parents[1] / "app" / "templates" / "effects"

    def test_list_template_exists(self) -> None:
        """List template should exist."""
        assert (self.TEMPLATES_DIR / "list.html").exists(), \
            "effects/list.html template should exist"

    def test_detail_template_exists(self) -> None:
        """Detail template should exist."""
        assert (self.TEMPLATES_DIR / "detail.html").exists(), \
            "effects/detail.html template should exist"

    def test_list_template_extends_base(self) -> None:
        """List template should extend base.html."""
        content = (self.TEMPLATES_DIR / "list.html").read_text()
        assert '{% extends "base.html" %}' in content

    def test_detail_template_extends_base(self) -> None:
        """Detail template should extend base.html."""
        content = (self.TEMPLATES_DIR / "detail.html").read_text()
        assert '{% extends "base.html" %}' in content

    def test_list_template_has_upload_button(self) -> None:
        """List template should have upload functionality."""
        content = (self.TEMPLATES_DIR / "list.html").read_text()
        assert 'Upload Effect' in content or 'upload' in content.lower()

    def test_detail_template_shows_parameters(self) -> None:
        """Detail template should display parameters."""
        content = (self.TEMPLATES_DIR / "detail.html").read_text()
        assert 'params' in content.lower() or 'parameter' in content.lower()

    def test_detail_template_shows_source_code(self) -> None:
        """Detail template should show source code."""
        content = (self.TEMPLATES_DIR / "detail.html").read_text()
        assert 'source' in content.lower()
        assert 'language-python' in content

    def test_detail_template_shows_dependencies(self) -> None:
        """Detail template should display dependencies."""
        content = (self.TEMPLATES_DIR / "detail.html").read_text()
        assert 'dependencies' in content.lower()
|
||||
|
||||
|
||||
class TestEffectsRouterExists:
    """Tests for effects router configuration."""

    # Fix: derive the router path from this test file's location instead of
    # the hard-coded absolute path /home/giles/art/art-celery, which breaks
    # everywhere else.  Assumes <root>/tests next to <root>/app.
    ROUTER_PATH = Path(__file__).resolve().parents[1] / "app" / "routers" / "effects.py"

    def test_effects_router_file_exists(self) -> None:
        """Effects router should exist."""
        assert self.ROUTER_PATH.exists(), "effects.py router should exist"

    def test_effects_router_has_list_endpoint(self) -> None:
        """Effects router should have list endpoint."""
        content = self.ROUTER_PATH.read_text()
        assert '@router.get("")' in content or "@router.get('')" in content

    def test_effects_router_has_detail_endpoint(self) -> None:
        """Effects router should have detail endpoint."""
        content = self.ROUTER_PATH.read_text()
        assert '@router.get("/{cid}")' in content

    def test_effects_router_has_upload_endpoint(self) -> None:
        """Effects router should have upload endpoint."""
        content = self.ROUTER_PATH.read_text()
        assert '@router.post("/upload")' in content

    def test_effects_router_renders_templates(self) -> None:
        """Effects router should render HTML templates."""
        content = self.ROUTER_PATH.read_text()
        assert 'effects/list.html' in content
        assert 'effects/detail.html' in content
|
||||
529
l1/tests/test_execute_recipe.py
Normal file
529
l1/tests/test_execute_recipe.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
Tests for execute_recipe SOURCE node resolution.
|
||||
|
||||
These tests verify that SOURCE nodes with :input true are correctly
|
||||
resolved from input_hashes at execution time.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
class MockStep:
    """Mock ExecutionStep for testing."""

    def __init__(
        self,
        step_id: str,
        node_type: str,
        config: Dict[str, Any],
        cache_id: str,
        input_steps: list = None,
        level: int = 0,
    ):
        # Identity and type of the step within the plan.
        self.step_id = step_id
        self.node_type = node_type
        self.config = config
        self.cache_id = cache_id
        # "or []" gives each instance its own list when None is passed.
        self.input_steps = input_steps or []
        self.level = level
        # Display name comes from the config, if any.
        self.name = config.get("name")
        # Populated by the executor; starts empty.
        self.outputs = []
|
||||
|
||||
|
||||
class MockPlan:
    """Mock ExecutionPlanSexp for testing."""

    def __init__(self, steps: list, output_step_id: str, plan_id: str = "test-plan"):
        # Store the plan's steps, its terminal step id, and its identifier.
        self.steps = steps
        self.output_step_id = output_step_id
        self.plan_id = plan_id

    def to_string(self, pretty: bool = False) -> str:
        # Tests only need a stable placeholder; `pretty` is accepted for
        # interface compatibility but has no effect.
        return "(plan test)"
|
||||
|
||||
def resolve_source_cid(step_config: Dict[str, Any], input_hashes: Dict[str, str]) -> "str | None":
    """
    Resolve CID for a SOURCE node.

    This is the logic that should be in execute_recipe - extracted here for
    unit testing.

    Args:
        step_config: The SOURCE node's config; may carry "cid" (fixed source),
            "input" (variable input flag) and "name" (display name).
        input_hashes: User-supplied CIDs keyed by input name.

    Returns:
        The resolved CID, or None when the node has neither a fixed "cid"
        nor "input" set.  (Fix: the annotation previously claimed `-> str`,
        but the None case is deliberate and exercised by the tests below.)

    Raises:
        ValueError: if the node is a variable input but no matching key
            exists in input_hashes.
    """
    source_cid = step_config.get("cid")

    # If source has :input true, resolve CID from input_hashes
    if not source_cid and step_config.get("input"):
        source_name = step_config.get("name", "")
        # Try various key formats for lookup: exact name first, then
        # lower-cased with spaces as dashes/underscores, then plain lower.
        name_variants = [
            source_name,
            source_name.lower().replace(" ", "-"),
            source_name.lower().replace(" ", "_"),
            source_name.lower(),
        ]
        for variant in name_variants:
            if variant in input_hashes:
                source_cid = input_hashes[variant]
                break

        if not source_cid:
            # Include the available keys to make the failure actionable.
            raise ValueError(
                f"SOURCE '{source_name}' not found in input_hashes. "
                f"Available: {list(input_hashes.keys())}"
            )

    return source_cid
|
||||
|
||||
|
||||
class TestSourceCidResolution:
    """Tests for SOURCE node CID resolution from input_hashes.

    Each test exercises one branch of resolve_source_cid: fixed :cid,
    the four name-normalisation lookups, the missing-input error, and
    the :cid-over-:input priority rule.
    """

    def test_source_with_fixed_cid(self):
        """SOURCE with :cid should use that directly."""
        config = {"cid": "QmFixedCid123"}
        input_hashes = {}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmFixedCid123"

    def test_source_with_input_true_exact_match(self):
        """SOURCE with :input true should resolve from input_hashes by exact name."""
        config = {"input": True, "name": "my-video"}
        input_hashes = {"my-video": "QmInputVideo456"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmInputVideo456"

    def test_source_with_input_true_normalized_dash(self):
        """SOURCE with :input true should resolve from normalized dash format."""
        config = {"input": True, "name": "Second Video"}
        input_hashes = {"second-video": "QmSecondVideo789"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmSecondVideo789"

    def test_source_with_input_true_normalized_underscore(self):
        """SOURCE with :input true should resolve from normalized underscore format."""
        config = {"input": True, "name": "Second Video"}
        input_hashes = {"second_video": "QmSecondVideoUnderscore"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmSecondVideoUnderscore"

    def test_source_with_input_true_lowercase(self):
        """SOURCE with :input true should resolve from lowercase format."""
        config = {"input": True, "name": "MyVideo"}
        input_hashes = {"myvideo": "QmLowercaseVideo"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmLowercaseVideo"

    def test_source_with_input_true_missing_raises(self):
        """SOURCE with :input true should raise if not in input_hashes."""
        config = {"input": True, "name": "Missing Video"}
        input_hashes = {"other-video": "QmOther123"}

        with pytest.raises(ValueError) as excinfo:
            resolve_source_cid(config, input_hashes)

        # The error must name the missing source AND list what IS available,
        # so a failing run is diagnosable from the message alone.
        assert "Missing Video" in str(excinfo.value)
        assert "not found in input_hashes" in str(excinfo.value)
        assert "other-video" in str(excinfo.value)  # Shows available keys

    def test_source_without_cid_or_input_returns_none(self):
        """SOURCE without :cid or :input should return None."""
        config = {"name": "some-source"}
        input_hashes = {}

        cid = resolve_source_cid(config, input_hashes)
        assert cid is None

    def test_source_priority_cid_over_input(self):
        """SOURCE with both :cid and :input true should use :cid."""
        config = {"cid": "QmDirectCid", "input": True, "name": "my-video"}
        input_hashes = {"my-video": "QmInputHashCid"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmDirectCid"  # :cid takes priority
|
||||
|
||||
|
||||
class TestSourceNameVariants:
    """Tests for the various input name normalization formats."""

    # Each pair is (name as written in the recipe, key as supplied by the
    # caller in input_hashes); all must resolve to the same CID.
    @pytest.mark.parametrize("source_name,input_key", [
        ("video", "video"),
        ("My Video", "my-video"),
        ("My Video", "my_video"),
        ("My Video", "My Video"),
        ("CamelCase", "camelcase"),
        ("multiple spaces", "multiple spaces"),  # Only single replace
    ])
    def test_name_variant_matching(self, source_name: str, input_key: str):
        """Various name formats should match."""
        config = {"input": True, "name": source_name}
        input_hashes = {input_key: "QmTestCid"}

        cid = resolve_source_cid(config, input_hashes)
        assert cid == "QmTestCid"
|
||||
|
||||
|
||||
class TestExecuteRecipeIntegration:
    """Integration tests for execute_recipe with SOURCE nodes."""

    def test_recipe_with_user_input_source(self):
        """
        Recipe execution should resolve SOURCE nodes with :input true.

        This is the bug that was causing "No executor for node type: SOURCE".
        """
        # Create mock plan with a SOURCE that has :input true
        source_step = MockStep(
            step_id="source_1",
            node_type="SOURCE",
            config={"input": True, "name": "Second Video", "description": "User input"},
            cache_id="abc123",
            level=0,
        )

        effect_step = MockStep(
            step_id="effect_1",
            node_type="EFFECT",
            config={"effect": "invert"},
            cache_id="def456",
            input_steps=["source_1"],
            level=1,
        )

        # NOTE(review): `plan` is built but never executed below — only CID
        # resolution is exercised; consider feeding it through execute_recipe.
        plan = MockPlan(
            steps=[source_step, effect_step],
            output_step_id="effect_1",
        )

        # Input hashes provided by user
        input_hashes = {
            "second-video": "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS"
        }

        # Verify source CID resolution works
        resolved_cid = resolve_source_cid(source_step.config, input_hashes)
        assert resolved_cid == "QmS4885aRikrjDB4yHPg9yTiPcBFWadZKVfAEvUy7B32zS"

    def test_recipe_with_fixed_and_user_sources(self):
        """
        Recipe with both fixed (asset) and user-input sources.

        This is the dog-invert-concat recipe pattern:
        - Fixed source: cat asset with known CID
        - User input: Second Video from input_hashes
        """
        fixed_source = MockStep(
            step_id="source_cat",
            node_type="SOURCE",
            config={"cid": "QmCatVideo123", "asset": "cat"},
            cache_id="cat_cache",
            level=0,
        )

        user_source = MockStep(
            step_id="source_user",
            node_type="SOURCE",
            config={"input": True, "name": "Second Video"},
            cache_id="user_cache",
            level=0,
        )

        input_hashes = {
            "second-video": "QmUserProvidedVideo456"
        }

        # Fixed source uses its cid
        fixed_cid = resolve_source_cid(fixed_source.config, input_hashes)
        assert fixed_cid == "QmCatVideo123"

        # User source resolves from input_hashes
        user_cid = resolve_source_cid(user_source.config, input_hashes)
        assert user_cid == "QmUserProvidedVideo456"
|
||||
|
||||
|
||||
class TestCompoundNodeHandling:
    """
    Tests for COMPOUND node handling.

    COMPOUND nodes are collapsed effect chains that should be executed
    sequentially through their respective effect executors.

    Bug fixed: "No executor for node type: COMPOUND"
    """

    def test_compound_node_has_filter_chain(self):
        """COMPOUND nodes must have a filter_chain config."""
        step = MockStep(
            step_id="compound_1",
            node_type="COMPOUND",
            config={
                "filter_chain": [
                    {"type": "EFFECT", "config": {"effect": "identity"}},
                    {"type": "EFFECT", "config": {"effect": "dog"}},
                ],
                "inputs": ["source_1"],
            },
            cache_id="compound_cache",
            level=1,
        )

        assert step.node_type == "COMPOUND"
        assert "filter_chain" in step.config
        assert len(step.config["filter_chain"]) == 2

    def test_compound_filter_chain_has_effects(self):
        """COMPOUND filter_chain should contain EFFECT items with effect names."""
        filter_chain = [
            {"type": "EFFECT", "config": {"effect": "identity", "cid": "Qm123"}},
            {"type": "EFFECT", "config": {"effect": "dog", "cid": "Qm456"}},
        ]

        # Every chain entry must carry both the effect name and its code CID.
        for item in filter_chain:
            assert item["type"] == "EFFECT"
            assert "effect" in item["config"]
            assert "cid" in item["config"]

    def test_compound_requires_inputs(self):
        """COMPOUND nodes must have input steps."""
        step = MockStep(
            step_id="compound_1",
            node_type="COMPOUND",
            config={"filter_chain": [], "inputs": []},
            cache_id="compound_cache",
            input_steps=[],
            level=1,
        )

        # Empty inputs should be detected as error
        assert len(step.input_steps) == 0
        # The execute_recipe should raise ValueError for empty inputs
|
||||
|
||||
|
||||
class TestCacheIdLookup:
    """
    Tests for cache lookup by code-addressed cache_id.

    Bug fixed: Cache lookups by cache_id (code hash) were failing because
    only IPFS CID was indexed. Now we also index by node_id when different.
    """

    def test_cache_id_is_code_addressed(self):
        """cache_id should be a SHA3-256 hash (64 hex chars), not IPFS CID."""
        # Code-addressed hash example
        code_hash = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4"

        assert len(code_hash) == 64
        # Every character must be a lowercase hex digit.
        assert set(code_hash) <= set("0123456789abcdef")
        assert not code_hash.startswith("Qm")  # Not IPFS CID

    def test_ipfs_cid_format(self):
        """IPFS CIDs start with 'Qm' (v0) or 'bafy' (v1)."""
        cid_v0 = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ"
        cid_v1 = "bafybeigdyrzt5sfp7udm7hu76uh7y26nf3efuylqabf3oclgtqy55fbzdi"

        assert cid_v0[:2] == "Qm"
        assert cid_v1[:4] == "bafy"

    def test_cache_id_differs_from_ipfs_cid(self):
        """
        Code-addressed cache_id is computed BEFORE execution.
        IPFS CID is computed AFTER execution from file content.
        They will differ for the same logical step.
        """
        # Same step can have:
        code_hash = "5702aaec14adaddda9baefa94d5842143749ee19e6bb7c1fa7068dce21f51ed4"
        content_cid = "QmXrj6tSSn1vQXxxEY2Tyoudvt4CeeqR9gGQwSt7WFrhMZ"

        assert code_hash != content_cid
        # Both should be indexable for cache lookups
|
||||
|
||||
|
||||
class TestPlanStepTypes:
    """
    Tests verifying all node types in S-expression plans are handled.

    These tests document the node types that execute_recipe must handle.
    """

    def test_source_node_types(self):
        """SOURCE nodes: fixed asset or user input."""
        # Fixed asset source
        fixed = MockStep("s1", "SOURCE", {"cid": "Qm123"}, "cache1")
        assert fixed.node_type == "SOURCE"
        assert "cid" in fixed.config

        # User input source
        user = MockStep("s2", "SOURCE", {"input": True, "name": "video"}, "cache2")
        assert user.node_type == "SOURCE"
        assert user.config.get("input") is True

    def test_effect_node_type(self):
        """EFFECT nodes: single effect application."""
        step = MockStep(
            "e1", "EFFECT",
            {"effect": "invert", "cid": "QmEffect123", "intensity": 1.0},
            "cache3"
        )
        assert step.node_type == "EFFECT"
        assert "effect" in step.config

    def test_compound_node_type(self):
        """COMPOUND nodes: collapsed effect chains."""
        step = MockStep(
            "c1", "COMPOUND",
            {"filter_chain": [{"type": "EFFECT", "config": {}}]},
            "cache4"
        )
        assert step.node_type == "COMPOUND"
        assert "filter_chain" in step.config

    def test_sequence_node_type(self):
        """SEQUENCE nodes: concatenate multiple clips."""
        # SEQUENCE is the only type here that needs multiple input steps.
        step = MockStep(
            "seq1", "SEQUENCE",
            {"transition": {"type": "cut"}},
            "cache5",
            input_steps=["clip1", "clip2"]
        )
        assert step.node_type == "SEQUENCE"
|
||||
|
||||
|
||||
class TestNodeTypeCaseSensitivity:
    """
    Tests for node type case handling.

    Bug fixed: S-expression plans use lowercase (source, compound, effect)
    but code was checking uppercase (SOURCE, COMPOUND, EFFECT).
    """

    def test_source_lowercase_from_sexp(self):
        """S-expression plans produce lowercase node types."""
        # From plan: (source :cid "Qm...")
        step = MockStep("s1", "source", {"cid": "Qm123"}, "cache1")

        # Code should handle lowercase
        assert step.node_type.upper() == "SOURCE"

    def test_compound_lowercase_from_sexp(self):
        """COMPOUND from S-expression is lowercase."""
        # From plan: (compound :filter_chain ...)
        step = MockStep("c1", "compound", {"filter_chain": []}, "cache2")

        assert step.node_type.upper() == "COMPOUND"

    def test_effect_lowercase_from_sexp(self):
        """EFFECT from S-expression is lowercase."""
        # From plan: (effect :effect "invert" ...)
        step = MockStep("e1", "effect", {"effect": "invert"}, "cache3")

        assert step.node_type.upper() == "EFFECT"

    def test_sequence_lowercase_from_sexp(self):
        """SEQUENCE from S-expression is lowercase."""
        # From plan: (sequence :transition ...)
        step = MockStep("seq1", "sequence", {"transition": {}}, "cache4")

        assert step.node_type.upper() == "SEQUENCE"

    def test_node_type_comparison_must_be_case_insensitive(self):
        """
        Node type comparisons must be case-insensitive.

        This is the actual bug - checking step.node_type == "SOURCE"
        fails when step.node_type is "source" from S-expression.
        """
        sexp_types = ["source", "compound", "effect", "sequence"]
        code_types = ["SOURCE", "COMPOUND", "EFFECT", "SEQUENCE"]

        for sexp, code in zip(sexp_types, code_types):
            # Wrong: direct comparison fails
            assert sexp != code, f"{sexp} should not equal {code}"

            # Right: case-insensitive comparison works
            assert sexp.upper() == code, f"{sexp}.upper() should equal {code}"
|
||||
|
||||
|
||||
class TestExecuteRecipeErrorHandling:
    """Tests for error handling in execute_recipe."""

    def test_missing_input_hash_error_message(self):
        """Error should list available input keys when source not found."""
        cfg = {"input": True, "name": "Unknown Video"}
        hashes = {"video-a": "Qm1", "video-b": "Qm2"}

        with pytest.raises(ValueError) as excinfo:
            resolve_source_cid(cfg, hashes)

        # The message must be actionable: name the missing source and
        # show at least one of the keys that WAS provided.
        message = str(excinfo.value)
        assert "Unknown Video" in message
        assert "video-a" in message or "video-b" in message

    def test_source_no_cid_no_input_error(self):
        """SOURCE without cid or input flag should return None (invalid)."""
        cfg = {"name": "broken-source"}  # Missing both cid and input

        resolved = resolve_source_cid(cfg, {})
        assert resolved is None  # execute_recipe should catch this
|
||||
|
||||
|
||||
class TestRecipeOutputRequired:
    """
    Tests verifying recipes must produce output to succeed.

    Bug: Recipe was returning success=True with output_cid=None
    """

    @staticmethod
    def _succeeded(output_cid):
        # Mirror of the success check execute_recipe must perform.
        return output_cid is not None

    def test_recipe_without_output_should_fail(self):
        """
        Recipe execution must fail if no output is produced.

        This catches the bug where execute_recipe returned success=True
        but output_cid was None.
        """
        # Simulate the check that should happen in execute_recipe
        step_results = {"step1": {"status": "executed", "path": "/tmp/x"}}
        missing_output = None

        assert self._succeeded(missing_output) is False, \
            "Recipe with no output_cid should fail"

    def test_recipe_with_output_should_succeed(self):
        """Recipe with valid output should succeed."""
        assert self._succeeded("QmOutputCid123") is True

    def test_output_step_result_must_have_cid(self):
        """Output step result must contain cid or cache_id."""
        def extract_cid(result):
            # Either identifier makes the output addressable.
            return result.get("cid") or result.get("cache_id")

        # Step result without cid
        without_cid = {"status": "executed", "path": "/tmp/output.mkv"}
        assert extract_cid(without_cid) is None, "Should detect missing cid"

        # Step result with cid
        with_cid = {"status": "executed", "path": "/tmp/output.mkv", "cid": "Qm123"}
        assert extract_cid(with_cid) == "Qm123"

    def test_output_step_must_exist_in_results(self):
        """Output step must be present in step_results."""
        step_results = {
            "source_1": {"status": "source", "cid": "QmSrc"},
            "effect_1": {"status": "executed", "cid": "QmEffect"},
            # Note: output_step "sequence_1" is missing!
        }

        absent = step_results.get("sequence_1", {})
        assert absent.get("cid") is None, \
            "Missing output step should result in no cid"
|
||||
185
l1/tests/test_frame_compatibility.py
Normal file
185
l1/tests/test_frame_compatibility.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Integration tests for GPU/CPU frame compatibility.
|
||||
|
||||
These tests verify that all primitives work correctly with both:
|
||||
- numpy arrays (CPU frames)
|
||||
- CuPy arrays (GPU frames)
|
||||
- GPUFrame wrapper objects
|
||||
|
||||
Run with: pytest tests/test_frame_compatibility.py -v
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# Try to import CuPy
|
||||
try:
|
||||
import cupy as cp
|
||||
HAS_GPU = True
|
||||
except ImportError:
|
||||
cp = None
|
||||
HAS_GPU = False
|
||||
|
||||
|
||||
def create_test_frame(on_gpu=False):
    """Create a random 100x100 RGB test frame (uint8).

    Returns a CuPy array when *on_gpu* is requested and a GPU is
    available, otherwise a plain numpy array.
    """
    data = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
    wants_device = on_gpu and HAS_GPU
    return cp.asarray(data) if wants_device else data
|
||||
|
||||
|
||||
class MockGPUFrame:
    """Lightweight stand-in for GPUFrame when the full GPU stack is absent.

    Wraps a frame array and exposes the ``cpu``/``gpu`` views the real
    GPUFrame provides.
    """

    def __init__(self, data):
        # Wrapped frame; may be a numpy array or (if available) a CuPy array.
        self._data = data

    @property
    def cpu(self):
        """Return the wrapped frame as a host (numpy) array."""
        # CuPy arrays expose .get() to copy device memory to the host.
        on_device = HAS_GPU and hasattr(self._data, 'get')
        return self._data.get() if on_device else self._data

    @property
    def gpu(self):
        """Return the wrapped frame as a device (CuPy) array."""
        if not HAS_GPU:
            raise RuntimeError("No GPU available")
        if hasattr(self._data, 'get'):
            # Already a device array.
            return self._data
        return cp.asarray(self._data)
|
||||
|
||||
|
||||
class TestColorOps:
    """Test color_ops primitives with different frame types."""

    # NOTE(review): imports are done inside each test, presumably so that
    # collection succeeds when sexp_effects is not installed — confirm.

    def test_shift_hsv_numpy(self):
        """shift-hsv should work with numpy arrays."""
        from sexp_effects.primitive_libs.color_ops import prim_shift_hsv
        frame = create_test_frame(on_gpu=False)
        result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9)
        assert isinstance(result, np.ndarray)
        assert result.shape == frame.shape

    @pytest.mark.skipif(not HAS_GPU, reason="No GPU")
    def test_shift_hsv_cupy(self):
        """shift-hsv should work with CuPy arrays."""
        from sexp_effects.primitive_libs.color_ops import prim_shift_hsv
        frame = create_test_frame(on_gpu=True)
        result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9)
        assert isinstance(result, np.ndarray)  # Should return numpy

    def test_shift_hsv_gpuframe(self):
        """shift-hsv should work with GPUFrame objects."""
        from sexp_effects.primitive_libs.color_ops import prim_shift_hsv
        frame = MockGPUFrame(create_test_frame(on_gpu=False))
        result = prim_shift_hsv(frame, h=30, s=1.2, v=0.9)
        assert isinstance(result, np.ndarray)

    def test_invert_numpy(self):
        """invert-img should work with numpy arrays."""
        from sexp_effects.primitive_libs.color_ops import prim_invert_img
        frame = create_test_frame(on_gpu=False)
        result = prim_invert_img(frame)
        assert isinstance(result, np.ndarray)

    def test_adjust_numpy(self):
        """adjust should work with numpy arrays."""
        from sexp_effects.primitive_libs.color_ops import prim_adjust
        frame = create_test_frame(on_gpu=False)
        result = prim_adjust(frame, brightness=10, contrast=1.2)
        assert isinstance(result, np.ndarray)
|
||||
|
||||
|
||||
class TestGeometry:
    """Test geometry primitives with different frame types."""

    def test_rotate_numpy(self):
        """rotate should work with numpy arrays."""
        from sexp_effects.primitive_libs.geometry import prim_rotate
        frame = create_test_frame(on_gpu=False)
        # 45-degree rotation; only the return type is asserted here.
        result = prim_rotate(frame, 45)
        assert isinstance(result, np.ndarray)

    def test_scale_numpy(self):
        """scale should work with numpy arrays."""
        from sexp_effects.primitive_libs.geometry import prim_scale
        frame = create_test_frame(on_gpu=False)
        result = prim_scale(frame, 1.5)
        assert isinstance(result, np.ndarray)
|
||||
|
||||
|
||||
class TestBlending:
    """Test blending primitives with different frame types."""

    def test_blend_numpy(self):
        """blend should work with numpy arrays."""
        from sexp_effects.primitive_libs.blending import prim_blend
        frame_a = create_test_frame(on_gpu=False)
        frame_b = create_test_frame(on_gpu=False)
        # 0.5 = equal-weight blend of the two frames.
        result = prim_blend(frame_a, frame_b, 0.5)
        assert isinstance(result, np.ndarray)
|
||||
|
||||
|
||||
class TestInterpreterConversion:
    """Test the interpreter's frame conversion.

    All four tests need the same minimal interpreter; construction is
    factored into ``_make_interpreter`` so the throwaway .sexp file is
    created (and removed) in one place.
    """

    @staticmethod
    def _make_interpreter():
        """Build a StreamInterpreter from a minimal throwaway .sexp file.

        Fixes a resource leak in the old inline code: it used
        NamedTemporaryFile(delete=False) and never unlinked, leaking one
        temp file per test run.  The file is removed once the
        interpreter has been constructed from it.
        """
        from streaming.stream_sexp_generic import StreamInterpreter
        import tempfile
        import os

        with tempfile.NamedTemporaryFile(mode='w', suffix='.sexp', delete=False) as f:
            f.write('(stream "test" :fps 30 :width 100 :height 100 (frame frame))')
            f.flush()
            path = f.name
        try:
            return StreamInterpreter(path)
        finally:
            os.unlink(path)

    def test_maybe_to_numpy_none(self):
        """_maybe_to_numpy should handle None."""
        interp = self._make_interpreter()

        assert interp._maybe_to_numpy(None) is None

    def test_maybe_to_numpy_numpy(self):
        """_maybe_to_numpy should pass through numpy arrays."""
        interp = self._make_interpreter()

        frame = create_test_frame(on_gpu=False)
        result = interp._maybe_to_numpy(frame)
        assert result is frame  # Should be same object

    @pytest.mark.skipif(not HAS_GPU, reason="No GPU")
    def test_maybe_to_numpy_cupy(self):
        """_maybe_to_numpy should convert CuPy to numpy."""
        interp = self._make_interpreter()

        frame = create_test_frame(on_gpu=True)
        result = interp._maybe_to_numpy(frame)
        assert isinstance(result, np.ndarray)

    def test_maybe_to_numpy_gpuframe(self):
        """_maybe_to_numpy should convert GPUFrame to numpy."""
        interp = self._make_interpreter()

        frame = MockGPUFrame(create_test_frame(on_gpu=False))
        result = interp._maybe_to_numpy(frame)
        assert isinstance(result, np.ndarray)
|
||||
|
||||
|
||||
# Allow running this module directly (python test_frame_compatibility.py)
# instead of invoking pytest from the command line.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|
||||
272
l1/tests/test_item_visibility.py
Normal file
272
l1/tests/test_item_visibility.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Tests for item visibility in L1 web UI.
|
||||
|
||||
Bug found 2026-01-12: L1 run succeeded but web UI not showing:
|
||||
- Runs
|
||||
- Recipes
|
||||
- Created media
|
||||
|
||||
Root causes identified:
|
||||
1. Recipes: owner field filtering but owner never set in loaded recipes
|
||||
2. Media: item_types table entries not created on upload/import
|
||||
3. Run outputs: outputs not registered in item_types table
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import tempfile
|
||||
|
||||
|
||||
class TestRecipeVisibility:
    """Tests for recipe listing visibility."""

    def test_recipe_filter_allows_none_owner(self) -> None:
        """
        Regression test: The recipe filter should allow recipes where owner is None.

        Bug: recipe_service.list_recipes() filtered by owner == actor_id,
        but owner field is None in recipes loaded from S-expression files.
        This caused ALL recipes to be filtered out.

        Fix: The filter is now: actor_id is None or owner is None or owner == actor_id
        """
        # (actor_id, owner, expected_visible, description)
        cases = [
            (None, None, True, "No filter, no owner -> visible"),
            (None, "@someone@example.com", True, "No filter, has owner -> visible"),
            ("@testuser@example.com", None, True, "Has filter, no owner -> visible (shared)"),
            ("@testuser@example.com", "@testuser@example.com", True, "Filter matches owner -> visible"),
            ("@testuser@example.com", "@other@example.com", False, "Filter doesn't match -> hidden"),
        ]

        for actor_id, owner, expected_visible, description in cases:
            # FIXED filter logic from recipe_service.py line 86.
            visible = actor_id is None or owner is None or owner == actor_id

            assert visible == expected_visible, f"Failed: {description}"

    def test_recipe_filter_old_logic_was_broken(self) -> None:
        """Document that the old filter logic excluded all recipes with owner=None."""
        # OLD filter: actor_id is None or owner == actor_id
        # This broke when owner=None and actor_id was provided
        actor_id = "@testuser@example.com"
        owner = None  # This is what compiled sexp produces

        # OLD logic (broken): hid shared recipes whose owner was unset.
        assert (actor_id is None or owner == actor_id) is False, \
            "Old logic incorrectly hid owner=None recipes"

        # NEW logic (fixed):
        assert (actor_id is None or owner is None or owner == actor_id) is True, \
            "New logic should show owner=None recipes"
|
||||
|
||||
|
||||
class TestMediaVisibility:
    """Tests for media visibility after upload."""

    @pytest.mark.asyncio
    async def test_upload_content_creates_item_type_record(self) -> None:
        """
        Test: Uploaded media must be registered in item_types table via save_item_metadata.

        The save_item_metadata function creates entries in item_types table,
        enabling the media to appear in list_media queries.
        """
        # NOTE(review): hardcoded absolute path ties this test to one dev
        # machine — consider deriving the path from __file__ instead.
        import importlib.util
        spec = importlib.util.spec_from_file_location(
            "cache_service",
            "/home/giles/art/art-celery/app/services/cache_service.py"
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        CacheService = module.CacheService

        # Create mocks
        mock_db = AsyncMock()
        mock_db.create_cache_item = AsyncMock()
        mock_db.save_item_metadata = AsyncMock()

        mock_cache = MagicMock()
        cached_result = MagicMock()
        cached_result.cid = "QmUploadedContent123"
        # cache.put returns (cached item, IPFS CID)
        mock_cache.put.return_value = (cached_result, "QmIPFSCid123")

        service = CacheService(database=mock_db, cache_manager=mock_cache)

        # Upload content
        cid, ipfs_cid, error = await service.upload_content(
            content=b"test video content",
            filename="test.mp4",
            actor_id="@testuser@example.com",
        )

        assert error is None, f"Upload failed: {error}"
        assert cid is not None

        # Verify save_item_metadata was called (which creates item_types entry)
        mock_db.save_item_metadata.assert_called_once()

        # Verify it was called with correct actor_id and a media type (not mime type)
        call_kwargs = mock_db.save_item_metadata.call_args[1]
        assert call_kwargs.get('actor_id') == "@testuser@example.com", \
            "save_item_metadata must be called with the uploading user's actor_id"
        # item_type should be media category like "video", "image", "audio", "unknown"
        # NOT mime type like "video/mp4"
        item_type = call_kwargs.get('item_type')
        assert item_type in ("video", "image", "audio", "unknown"), \
            f"item_type should be media category, got '{item_type}'"

    @pytest.mark.asyncio
    async def test_import_from_ipfs_creates_item_type_record(self) -> None:
        """
        Test: Imported media must be registered in item_types table via save_item_metadata.

        The save_item_metadata function creates entries in item_types table with
        detected media type, enabling the media to appear in list_media queries.
        """
        # NOTE(review): same hardcoded absolute path as the upload test above.
        import importlib.util
        spec = importlib.util.spec_from_file_location(
            "cache_service",
            "/home/giles/art/art-celery/app/services/cache_service.py"
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        CacheService = module.CacheService

        mock_db = AsyncMock()
        mock_db.create_cache_item = AsyncMock()
        mock_db.save_item_metadata = AsyncMock()

        mock_cache = MagicMock()
        cached_result = MagicMock()
        cached_result.cid = "QmImportedContent123"
        mock_cache.put.return_value = (cached_result, "QmIPFSCid456")

        service = CacheService(database=mock_db, cache_manager=mock_cache)
        service.cache_dir = Path(tempfile.gettempdir())

        # We need to mock the ipfs_client module at the right location
        import importlib
        ipfs_module = MagicMock()
        ipfs_module.get_file = MagicMock(return_value=True)

        # Patch at module level
        with patch.dict('sys.modules', {'ipfs_client': ipfs_module}):
            # Import from IPFS
            cid, error = await service.import_from_ipfs(
                ipfs_cid="QmSourceIPFSCid",
                actor_id="@testuser@example.com",
            )

        # Verify save_item_metadata was called (which creates item_types entry)
        mock_db.save_item_metadata.assert_called_once()

        # Verify it was called with detected media type (not hardcoded "media")
        call_kwargs = mock_db.save_item_metadata.call_args[1]
        item_type = call_kwargs.get('item_type')
        assert item_type in ("video", "image", "audio", "unknown"), \
            f"item_type should be detected media category, got '{item_type}'"
|
||||
|
||||
|
||||
class TestRunOutputVisibility:
    """Tests for run output visibility."""

    @pytest.mark.asyncio
    async def test_completed_run_output_visible_in_media_list(self) -> None:
        """
        Run outputs should be accessible in media listings.

        When a run completes, its output should be registered in item_types
        so it appears in the user's media gallery.
        """
        # This test documents the expected behavior
        # Run outputs are stored in run_cache but should also be in item_types
        # for the media gallery to show them

        # The fix should either:
        # 1. Add item_types entry when run completes, OR
        # 2. Modify list_media to also check run_cache outputs
        # TODO: replace this placeholder with a real assertion once one of
        # the two fixes above lands.
        pass  # Placeholder for implementation test
|
||||
|
||||
|
||||
class TestDatabaseItemTypes:
    """Tests for item_types database operations."""

    @pytest.mark.asyncio
    async def test_add_item_type_function_exists(self) -> None:
        """Verify add_item_type function exists and has correct signature."""
        import database

        assert hasattr(database, 'add_item_type'), \
            "database.add_item_type function should exist"

        # Check it's an async function
        import inspect
        assert inspect.iscoroutinefunction(database.add_item_type), \
            "add_item_type should be an async function"

    @pytest.mark.asyncio
    async def test_get_user_items_returns_items_from_item_types(self) -> None:
        """
        Verify get_user_items queries item_types table.

        If item_types has no entries for a user, they see no media.
        """
        # This is a documentation test showing the data flow:
        # 1. User uploads content -> should create item_types entry
        # 2. list_media -> calls get_user_items -> queries item_types
        # 3. If step 1 didn't create item_types entry, step 2 returns empty
        # TODO: turn this into an executable assertion against a test DB.
        pass
|
||||
|
||||
|
||||
class TestTemplateRendering:
    """Tests for template variable passing."""

    def test_cache_not_found_template_receives_content_hash(self) -> None:
        """
        Regression test: cache/not_found.html template requires content_hash.

        Bug: The template uses {{ content_hash[:24] }} but the route
        doesn't pass content_hash to the render context.

        Error: jinja2.exceptions.UndefinedError: 'content_hash' is undefined
        """
        # Documentation-only test: the template expects content_hash, yet
        # the route at /app/app/routers/cache.py line 57 does not supply it.
        # The fix is verified by inspecting the route code.
        pass
|
||||
|
||||
|
||||
class TestOwnerFieldInRecipes:
    """Tests for owner field handling in recipes."""

    def test_sexp_recipe_has_none_owner_by_default(self) -> None:
        """
        S-expression recipes have owner=None by default.

        The compiled recipe includes owner field but it's None,
        so the list_recipes filter must allow owner=None to show
        shared/public recipes.
        """
        from artdag.sexp import compile_string

        sample_sexp = """
        (recipe "test"
            (-> (source :input true :name "video")
                (fx identity)))
        """

        recipe_dict = compile_string(sample_sexp).to_dict()

        # owner is present in the compiled dict but unset, so the
        # visibility filter must accept it:
        #   if actor_id is None or owner is None or owner == actor_id
        assert recipe_dict.get("owner") is None, \
            "Compiled S-expression should have owner=None"
|
||||
517
l1/tests/test_jax_pipeline_integration.py
Normal file
517
l1/tests/test_jax_pipeline_integration.py
Normal file
@@ -0,0 +1,517 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Integration tests comparing JAX and Python rendering pipelines.
|
||||
|
||||
These tests ensure the JAX-compiled effect chains produce identical output
|
||||
to the Python/NumPy path. They test:
|
||||
1. Full effect pipelines through both interpreters
|
||||
2. Multi-frame sequences (to catch phase-dependent bugs)
|
||||
3. Compiled effect chain fusion
|
||||
4. Edge cases like shrinking/zooming that affect boundary handling
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
import numpy as np
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the art-celery module is importable
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sexp_effects.primitive_libs import core as core_mod
|
||||
|
||||
|
||||
# Path to test resources
|
||||
TEST_DIR = Path('/home/giles/art/test')
|
||||
EFFECTS_DIR = TEST_DIR / 'sexp_effects' / 'effects'
|
||||
TEMPLATES_DIR = TEST_DIR / 'templates'
|
||||
|
||||
|
||||
def create_test_image(h=96, w=128):
    """Create a test image with distinct patterns.

    The background is a gradient (horizontal in channel 0, vertical in
    channel 1, constant channel 2) with a filled circle and rectangle on
    top, so geometric effects have recognisable features to act on.

    Args:
        h: image height in pixels.
        w: image width in pixels.

    Returns:
        (h, w, 3) uint8 numpy array.
    """
    import cv2
    img = np.zeros((h, w, 3), dtype=np.uint8)

    # Gradient background, vectorized. Integer floor-division matches the
    # original per-pixel int(255 * x / w) truncation exactly, but does
    # O(h + w) numpy work instead of an O(h * w) Python loop.
    img[:, :, 0] = (np.arange(w) * 255 // w)[np.newaxis, :]  # horizontal gradient
    img[:, :, 1] = (np.arange(h) * 255 // h)[:, np.newaxis]  # vertical gradient
    img[:, :, 2] = 128                                       # constant

    # Add features with distinct shapes/colors.
    cv2.circle(img, (w//2, h//2), 20, (255, 0, 0), -1)
    cv2.rectangle(img, (10, 10), (30, 30), (0, 255, 0), -1)

    return img
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
def test_env(tmp_path_factory):
    """Set up test environment with sexp files and test media.

    Yields a dict with:
        'dir':        the temporary working directory (Path)
        'image_path': path to the generated test image
        'test_img':   the test image as a numpy array

    The fixture chdirs into the temp dir; the try/finally guarantees the
    original working directory is restored even if setup fails partway,
    so later tests in the session are not affected.
    """
    test_dir = tmp_path_factory.mktemp('sexp_test')
    original_dir = os.getcwd()
    os.chdir(test_dir)

    try:
        # Create directories expected by the sexp programs under test.
        (test_dir / 'effects').mkdir()
        (test_dir / 'sexp_effects' / 'effects').mkdir(parents=True)

        # Create test image
        import cv2
        test_img = create_test_image()
        cv2.imwrite(str(test_dir / 'test_image.png'), test_img)

        # Copy required effect files from the shared test resources.
        # Missing files are skipped silently — individual tests will fail
        # loudly if an effect they need is absent.
        for effect in ['rotate', 'zoom', 'blend', 'invert', 'hue_shift']:
            src = EFFECTS_DIR / f'{effect}.sexp'
            dst = test_dir / 'sexp_effects' / 'effects' / f'{effect}.sexp'
            if src.exists():
                shutil.copy(src, dst)

        yield {
            'dir': test_dir,
            'image_path': test_dir / 'test_image.png',
            'test_img': test_img,
        }
    finally:
        os.chdir(original_dir)
|
||||
|
||||
|
||||
def create_sexp_file(test_dir, content, filename='test.sexp'):
    """Write *content* to <test_dir>/effects/<filename> and return it as str."""
    target = test_dir / 'effects' / filename
    target.write_text(content)
    return str(target)
|
||||
|
||||
|
||||
class TestJaxPythonPipelineEquivalence:
    """Test that JAX and Python pipelines produce equivalent output.

    Each test renders one frame of the same sexp program through both the
    Python/NumPy interpreter and the JAX interpreter (seeded identically)
    and compares the results pixel-wise. The per-test harness code that
    was previously copy-pasted five times lives in _render_both.
    """

    def _render_both(self, test_env, sexp_content):
        """Render one frame of *sexp_content* through both interpreters.

        Returns (py_result, jax_result, test_img): the Python and JAX
        frames as numpy arrays plus the source image. Construction and
        evaluation ordering mirrors the original tests exactly: both
        interpreters are built (each after reseeding the shared RNG),
        then both evaluate the same frame environment.
        """
        sexp_path = create_sexp_file(test_env['dir'], sexp_content)

        from streaming.stream_sexp_generic import StreamInterpreter, Context
        import cv2

        test_img = cv2.imread(str(test_env['image_path']))

        interps = []
        for use_jax in (False, True):
            # Reset the shared RNG so both paths start from identical state.
            core_mod.set_random_seed(42)
            interp = StreamInterpreter(sexp_path, use_jax=use_jax)
            interp._init()
            interps.append(interp)

        ctx = Context(fps=10)
        ctx.t = 0.5
        ctx.frame_num = 5

        frame_env = {
            'ctx': {'t': ctx.t, 'frame-num': ctx.frame_num, 'fps': 10},
            't': ctx.t, 'frame-num': ctx.frame_num,
        }

        # Inject the test image into both interpreters' globals.
        for interp in interps:
            interp.globals['frame'] = test_img

        # Evaluate the frame pipeline; _maybe_force resolves any deferred
        # (fused) effect chain into a concrete array.
        results = [
            np.asarray(interp._maybe_force(
                interp._eval(interp.frame_pipeline, frame_env)))
            for interp in interps
        ]
        return results[0], results[1], test_img

    @staticmethod
    def _abs_diff(a, b):
        """Absolute per-pixel difference between two frames, as float."""
        return np.abs(a.astype(float) - b.astype(float))

    def test_single_rotate_effect(self, test_env):
        """Test that a single rotate effect matches between Python and JAX."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")

  (frame (rotate frame :angle 15 :speed 0))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate effect: mean diff {mean_diff:.2f} exceeds threshold"

    def test_rotate_then_zoom(self, test_env):
        """Test rotate followed by zoom - tests effect chain fusion."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (-> (rotate frame :angle 15 :speed 0)
             (zoom :amount 0.95 :speed 0)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))
        assert mean_diff < 2.0, f"Rotate+zoom chain: mean diff {mean_diff:.2f} exceeds threshold"

    def test_zoom_shrink_boundary_handling(self, test_env):
        """Test zoom with shrinking factor - critical for boundary handling."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame (zoom frame :amount 0.8 :speed 0))
)
'''
        py_result, jax_result, test_img = self._render_both(test_env, sexp_content)

        # Check corners specifically - these are most affected by boundary handling
        h, w = test_img.shape[:2]
        corners = [(0, 0), (0, w-1), (h-1, 0), (h-1, w-1)]
        for y, x in corners:
            py_val = py_result[y, x]
            jax_val = jax_result[y, x]
            corner_diff = np.abs(py_val.astype(float) - jax_val.astype(float)).max()
            assert corner_diff < 10.0, f"Corner ({y},{x}): diff {corner_diff} - py={py_val}, jax={jax_val}"

    def test_blend_two_clips(self, test_env):
        """Test blending two effect chains - the core bug scenario."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        diff = self._abs_diff(py_result, jax_result)
        mean_diff = np.mean(diff)
        # Top row is the region most sensitive to boundary handling bugs.
        edge_diff = diff[0, :].mean()

        assert mean_diff < 3.0, f"Blend: mean diff {mean_diff:.2f} exceeds threshold"
        assert edge_diff < 10.0, f"Blend edge: diff {edge_diff:.2f} exceeds threshold"

    def test_blend_with_invert(self, test_env):
        """Test blending with invert - matches the problematic recipe pattern."""
        sexp_content = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (require-primitives "core")
  (require-primitives "image")
  (require-primitives "blending")
  (require-primitives "color_ops")
  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")
  (effect blend :path "../sexp_effects/effects/blend.sexp")
  (effect invert :path "../sexp_effects/effects/invert.sexp")

  (frame
    (let [clip_a (-> (rotate frame :angle 5 :speed 0)
                     (zoom :amount 1.05 :speed 0)
                     (invert :amount 1))
          clip_b (-> (rotate frame :angle -5 :speed 0)
                     (zoom :amount 0.95 :speed 0))]
      (blend :base clip_a :overlay clip_b :opacity 0.5)))
)
'''
        py_result, jax_result, _ = self._render_both(test_env, sexp_content)

        mean_diff = np.mean(self._abs_diff(py_result, jax_result))
        assert mean_diff < 3.0, f"Blend+invert: mean diff {mean_diff:.2f} exceeds threshold"
|
||||
|
||||
|
||||
class TestDeferredEffectChainFusion:
    """Test the DeferredEffectChain fusion mechanism specifically.

    Both tests apply the same rotate→zoom pair two ways — manual
    step-by-step calls versus a fused DeferredEffectChain — and compare
    the results. The previously duplicated harness lives in
    _manual_and_fused.
    """

    # Minimal program whose only job is to load the rotate/zoom effects;
    # the frame pipeline itself is the identity.
    _BASE_SEXP = '''(stream "test"
  :width 128
  :height 96
  :seed 42

  (effect rotate :path "../sexp_effects/effects/rotate.sexp")
  (effect zoom :path "../sexp_effects/effects/zoom.sexp")

  (frame frame)
)
'''

    def _manual_and_fused(self, test_env, rot_angle, zoom_amount):
        """Apply rotate then zoom manually and via a fused chain.

        Returns (manual_result, fused_result) as numpy arrays, both
        computed with t=0.5, frame_num=5, seed=42.
        """
        import jax.numpy as jnp
        from streaming.stream_sexp_generic import StreamInterpreter, DeferredEffectChain

        sexp_path = create_sexp_file(test_env['dir'], self._BASE_SEXP)

        core_mod.set_random_seed(42)
        interp = StreamInterpreter(sexp_path, use_jax=True)
        interp._init()

        jax_frame = jnp.array(test_env['test_img'])
        t = 0.5
        frame_num = 5

        # Manual step-by-step application.
        rotate_fn = interp.jax_effects['rotate']
        zoom_fn = interp.jax_effects['zoom']

        manual = rotate_fn(jax_frame, t=t, frame_num=frame_num, seed=42,
                           angle=rot_angle, speed=0)
        manual = zoom_fn(manual, t=t, frame_num=frame_num, seed=42,
                         amount=zoom_amount, speed=0)

        # Fused chain application of the same two effects.
        chain = DeferredEffectChain(
            ['rotate'],
            [{'angle': rot_angle, 'speed': 0}],
            jax_frame, t, frame_num
        )
        chain = chain.extend('zoom', {'amount': zoom_amount, 'speed': 0})

        return np.asarray(manual), np.asarray(interp._force_deferred(chain))

    def test_manual_vs_fused_chain(self, test_env):
        """Compare manual application vs fused DeferredEffectChain."""
        manual_result, fused_result = self._manual_and_fused(test_env, -5.0, 0.95)

        diff = np.abs(manual_result.astype(float) - fused_result.astype(float))
        mean_diff = np.mean(diff)
        assert mean_diff < 1.0, f"Manual vs fused: mean diff {mean_diff:.2f}"

        # Check specific pixels
        h, w = test_env['test_img'].shape[:2]
        for y in [0, h//2, h-1]:
            for x in [0, w//2, w-1]:
                manual_val = manual_result[y, x]
                fused_val = fused_result[y, x]
                pixel_diff = np.abs(manual_val.astype(float) - fused_val.astype(float)).max()
                assert pixel_diff < 2.0, f"Pixel ({y},{x}): manual={manual_val}, fused={fused_val}"

    def test_chain_with_shrink_zoom_boundary(self, test_env):
        """Test that shrinking zoom handles boundaries correctly in chain."""
        # zoom < 1.0 pulls in from the edges, exposing boundary handling.
        manual_result, fused_result = self._manual_and_fused(test_env, -4.555, 0.9494)

        # Check top edge specifically - this is where boundary issues manifest
        top_edge_manual = manual_result[0, :]
        top_edge_fused = fused_result[0, :]

        edge_diff = np.abs(top_edge_manual.astype(float) - top_edge_fused.astype(float))
        mean_edge_diff = np.mean(edge_diff)

        assert mean_edge_diff < 2.0, f"Top edge diff: {mean_edge_diff:.2f}"

        # Check for zeros at edges that shouldn't be there
        manual_edge_sum = np.sum(top_edge_manual)
        fused_edge_sum = np.sum(top_edge_fused)

        if manual_edge_sum > 100:  # If manual has significant values
            assert fused_edge_sum > manual_edge_sum * 0.5, \
                f"Fused has too many zeros: manual sum={manual_edge_sum}, fused sum={fused_edge_sum}"
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Allow running this module directly; delegates to pytest in verbose mode.
    pytest.main([__file__, '-v'])
|
||||
334
l1/tests/test_jax_primitives.py
Normal file
334
l1/tests/test_jax_primitives.py
Normal file
@@ -0,0 +1,334 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test framework to verify JAX primitives match Python primitives.
|
||||
|
||||
Compares output of each primitive through:
|
||||
1. Python/NumPy path
|
||||
2. JAX path (CPU)
|
||||
3. JAX path (GPU) - if available
|
||||
|
||||
Reports any mismatches with detailed diffs.
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, '/home/giles/art/art-celery')
|
||||
|
||||
import numpy as np
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Test configuration
|
||||
TEST_WIDTH = 64
|
||||
TEST_HEIGHT = 48
|
||||
TOLERANCE_MEAN = 1.0 # Max allowed mean difference
|
||||
TOLERANCE_MAX = 10.0 # Max allowed single pixel difference
|
||||
TOLERANCE_PCT = 0.95 # Min % of pixels within ±1
|
||||
|
||||
|
||||
@dataclass
class TestResult:
    """Outcome of comparing one primitive across the Python and JAX paths."""

    # Prevent pytest from collecting this record as a test class: its name
    # starts with "Test" but it is a plain dataclass with an __init__.
    __test__ = False

    primitive: str             # fully-qualified name, e.g. "color_ops:invert"
    passed: bool               # True when all tolerance checks succeeded
    python_mean: float = 0.0   # mean value of the Python output
    jax_mean: float = 0.0      # mean value of the JAX output
    diff_mean: float = 0.0     # mean absolute difference between outputs
    diff_max: float = 0.0      # max absolute difference between outputs
    pct_within_1: float = 0.0  # fraction of elements differing by <= 1
    error: str = ""            # human-readable failure reason, if any
|
||||
|
||||
|
||||
def create_test_frame(w=TEST_WIDTH, h=TEST_HEIGHT, pattern='gradient'):
    """Create a test frame with known pattern.

    Supported patterns: 'gradient' (diagonal RGB ramps), 'checkerboard'
    (8x8 black/white squares), 'solid' (mid-gray), anything else → random.
    Returns an (h, w, 3) uint8 array.
    """
    if pattern == 'gradient':
        # Channel 0 ramps left→right, channel 1 top→bottom, channel 2 with both.
        yy, xx = np.mgrid[0:h, 0:w]
        red = (xx * 255 / w).astype(np.uint8)
        green = (yy * 255 / h).astype(np.uint8)
        blue = ((xx + yy) * 127 / (w + h)).astype(np.uint8)
        return np.stack([red, green, blue], axis=2)

    if pattern == 'checkerboard':
        yy, xx = np.mgrid[0:h, 0:w]
        cell = ((xx // 8) + (yy // 8)) % 2
        level = (cell * 255).astype(np.uint8)
        return np.stack([level, level, level], axis=2)

    if pattern == 'solid':
        return np.full((h, w, 3), 128, dtype=np.uint8)

    # Fallback: random noise (high bound 255 is exclusive, as before).
    return np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
|
||||
|
||||
|
||||
def compare_outputs(py_out, jax_out) -> Tuple[float, float, float]:
    """Compare two outputs, return (mean_diff, max_diff, pct_within_1)."""
    incomparable = (float('inf'), float('inf'), 0.0)

    if py_out is None or jax_out is None:
        return incomparable

    if isinstance(py_out, dict) and isinstance(jax_out, dict):
        # Coordinate maps: compare every key present in both dicts whose
        # arrays agree in shape; anything else is silently skipped.
        per_key = []
        for key, py_val in py_out.items():
            if key not in jax_out:
                continue
            a = np.asarray(py_val)
            b = np.asarray(jax_out[key])
            if a.shape == b.shape:
                per_key.append(np.abs(a.astype(float) - b.astype(float)))
        if not per_key:
            return incomparable
        flat = np.concatenate([d.flatten() for d in per_key])
        return float(np.mean(flat)), float(np.max(flat)), float(np.mean(flat <= 1))

    a = np.asarray(py_out)
    b = np.asarray(jax_out)
    if a.shape != b.shape:
        return incomparable

    delta = np.abs(a.astype(float) - b.astype(float))
    return float(np.mean(delta)), float(np.max(delta)), float(np.mean(delta <= 1))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Primitive Test Definitions
|
||||
# ============================================================================
|
||||
|
||||
# Mapping of "<library>:<primitive>" to a test specification:
#   'args'    - positional arguments; the sentinels 'frame' and 'frame2'
#               are substituted with test images by the harness
#   'returns' - expected result kind: 'frame', 'coords', 'array', 'scalar'
PRIMITIVE_TESTS = {
    # Geometry primitives
    'geometry:ripple-displace': {
        'args': [TEST_WIDTH, TEST_HEIGHT, 5, 10, TEST_WIDTH/2, TEST_HEIGHT/2, 1, 0.5],
        'returns': 'coords',
    },
    'geometry:rotate-img': {
        'args': ['frame', 45],
        'returns': 'frame',
    },
    'geometry:scale-img': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'geometry:flip-h': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'geometry:flip-v': {
        'args': ['frame'],
        'returns': 'frame',
    },

    # Color operations
    'color_ops:invert': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:grayscale': {
        'args': ['frame'],
        'returns': 'frame',
    },
    'color_ops:brightness': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:contrast': {
        'args': ['frame', 1.5],
        'returns': 'frame',
    },
    'color_ops:hue-shift': {
        'args': ['frame', 90],
        'returns': 'frame',
    },

    # Image operations
    'image:width': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:height': {
        'args': ['frame'],
        'returns': 'scalar',
    },
    'image:channel': {
        'args': ['frame', 0],
        'returns': 'array',
    },

    # Blending
    'blending:blend': {
        'args': ['frame', 'frame2', 0.5],
        'returns': 'frame',
    },
    'blending:blend-add': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
    'blending:blend-multiply': {
        'args': ['frame', 'frame2'],
        'returns': 'frame',
    },
}
|
||||
|
||||
|
||||
def run_python_primitive(interp, prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the Python interpreter.

    Returns the primitive's output, or None if the primitive is unknown
    or raised during execution.
    """
    if prim_name not in interp.primitives:
        return None
    func = interp.primitives[prim_name]

    # Substitute the frame sentinels; copies keep the inputs pristine.
    resolved = []
    for spec in test_def['args']:
        if spec == 'frame':
            resolved.append(frame.copy())
        elif spec == 'frame2':
            resolved.append(frame2.copy())
        else:
            resolved.append(spec)

    try:
        return func(*resolved)
    except Exception:
        return None
|
||||
|
||||
|
||||
def run_jax_primitive(prim_name: str, test_def: dict, frame: np.ndarray, frame2: np.ndarray) -> Any:
    """Run a primitive through the JAX compiler.

    Builds the S-expression (prim_name arg1 arg2 ...) and evaluates it
    with the JAX compiler. Returns a numpy array (or raw value), or None
    on any failure (missing modules, compile/eval errors).
    """
    try:
        from streaming.sexp_to_jax import JaxCompiler
        import jax.numpy as jnp

        from sexp_effects.parser import Symbol, Keyword

        compiler = JaxCompiler()

        # Frames live in the environment; sentinel args become Symbols
        # that resolve to them at eval time.
        env = {'frame': jnp.array(frame), 'frame2': jnp.array(frame2)}
        resolved = [
            Symbol(spec) if spec in ('frame', 'frame2') else spec
            for spec in test_def['args']
        ]

        # Expression: (prim_name arg1 arg2 ...)
        outcome = compiler._eval([Symbol(prim_name)] + resolved, env)

        if hasattr(outcome, '__array__'):
            return np.asarray(outcome)
        return outcome

    except Exception:
        return None
|
||||
|
||||
|
||||
def test_primitive(interp, prim_name: str, test_def: dict) -> TestResult:
    """Test a single primitive.

    Runs it through the Python path and the JAX path on fixed test
    frames, compares the outputs, and returns a populated TestResult.
    """
    frame = create_test_frame(pattern='gradient')
    frame2 = create_test_frame(pattern='checkerboard')

    outcome = TestResult(primitive=prim_name, passed=False)

    # Python reference run.
    try:
        py_out = run_python_primitive(interp, prim_name, test_def, frame, frame2)
        if py_out is not None and hasattr(py_out, 'shape'):
            outcome.python_mean = float(np.mean(py_out))
    except Exception as exc:
        outcome.error = f"Python error: {exc}"
        return outcome

    # JAX run.
    try:
        jax_out = run_jax_primitive(prim_name, test_def, frame, frame2)
        if jax_out is not None and hasattr(jax_out, 'shape'):
            outcome.jax_mean = float(np.mean(jax_out))
    except Exception as exc:
        outcome.error = f"JAX error: {exc}"
        return outcome

    if py_out is None:
        outcome.error = "Python returned None"
        return outcome
    if jax_out is None:
        outcome.error = "JAX returned None"
        return outcome

    # Compare the two outputs against the module tolerances.
    diff_mean, diff_max, pct = compare_outputs(py_out, jax_out)
    outcome.diff_mean = diff_mean
    outcome.diff_max = diff_max
    outcome.pct_within_1 = pct

    outcome.passed = (
        diff_mean <= TOLERANCE_MEAN
        and diff_max <= TOLERANCE_MAX
        and pct >= TOLERANCE_PCT
    )

    if not outcome.passed:
        outcome.error = f"Diff too large: mean={diff_mean:.2f}, max={diff_max:.1f}, pct={pct:.1%}"

    return outcome
|
||||
|
||||
|
||||
def discover_primitives(interp) -> List[str]:
    """Discover all primitives available in the interpreter."""
    # Iterating the dict yields its keys; sorted() returns them as a list.
    return sorted(interp.primitives)
|
||||
|
||||
|
||||
def run_all_tests(verbose=True):
    """Run all primitive tests and print a pass/fail report.

    Returns the list of TestResult objects. NOTE(review): chdirs into a
    hardcoded path (/home/giles/art/test) and loads a fixed sexp file —
    this only runs on the original author's machine; confirm before reuse.
    """
    import warnings
    warnings.filterwarnings('ignore')

    import os
    os.chdir('/home/giles/art/test')

    from streaming.stream_sexp_generic import StreamInterpreter
    from sexp_effects.primitive_libs import core as core_mod

    core_mod.set_random_seed(42)

    # Create interpreter to get Python primitives
    interp = StreamInterpreter('effects/quick_test_explicit.sexp', use_jax=False)
    interp._init()

    print("=" * 60)
    print("JAX PRIMITIVE TEST SUITE")
    print("=" * 60)

    # Exercise every defined primitive, collecting results as we go.
    results = []
    for prim_name, test_def in PRIMITIVE_TESTS.items():
        outcome = test_primitive(interp, prim_name, test_def)
        results.append(outcome)

        if verbose:
            status = "✓ PASS" if outcome.passed else "✗ FAIL"
            print(f"{status} {prim_name}")
            if not outcome.passed:
                print(f"  {outcome.error}")

    # Summary
    passed = sum(1 for r in results if r.passed)
    failed = len(results) - passed

    print("\n" + "=" * 60)
    print(f"SUMMARY: {passed} passed, {failed} failed")
    print("=" * 60)

    if failed > 0:
        print("\nFailed primitives:")
        for r in results:
            if not r.passed:
                print(f"  - {r.primitive}: {r.error}")

    return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: run the full primitive comparison suite.
    run_all_tests()
|
||||
246
l1/tests/test_naming_service.py
Normal file
246
l1/tests/test_naming_service.py
Normal file
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Tests for the friendly naming service.
|
||||
"""
|
||||
|
||||
import re
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Copy the pure functions from naming_service for testing
|
||||
# This avoids import issues with the app module
|
||||
|
||||
CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz"
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
    """Copy of normalize_name for testing.

    Lowercases, converts whitespace/underscores to dashes, strips other
    non [a-z0-9-] characters, collapses/trims dashes; empty → "unnamed".
    """
    lowered = name.lower()
    dashed = re.sub(r"[\s_]+", "-", lowered)
    cleaned = re.sub(r"[^a-z0-9-]", "", dashed)
    collapsed = re.sub(r"-+", "-", cleaned)
    trimmed = collapsed.strip("-")
    return trimmed if trimmed else "unnamed"
|
||||
|
||||
|
||||
def parse_friendly_name(friendly_name: str):
    """Copy of parse_friendly_name for testing.

    Splits "<base> <version>" on the first space; version is None when
    no space is present.
    """
    base_name, sep, remainder = friendly_name.strip().partition(" ")
    return base_name, (remainder if sep else None)
|
||||
|
||||
|
||||
def format_friendly_name(base_name: str, version_id: str) -> str:
    """Copy of format_friendly_name for testing: "<base> <version>"."""
    return " ".join((base_name, version_id))
|
||||
|
||||
|
||||
def format_l2_name(actor_id: str, base_name: str, version_id: str) -> str:
    """Copy of format_l2_name for testing: "<actor> <base> <version>"."""
    return " ".join((actor_id, base_name, version_id))
|
||||
|
||||
|
||||
class TestNameNormalization:
    """Tests for name normalization."""

    def test_normalize_simple_name(self) -> None:
        """A capitalised word is simply lowercased."""
        assert normalize_name("Brightness") == "brightness"

    def test_normalize_spaces_to_dashes(self) -> None:
        """Runs of spaces turn into dashes."""
        assert normalize_name("My Cool Effect") == "my-cool-effect"

    def test_normalize_underscores_to_dashes(self) -> None:
        """Underscores behave like spaces and become dashes."""
        assert normalize_name("brightness_v2") == "brightness-v2"

    def test_normalize_removes_special_chars(self) -> None:
        """Characters outside [a-z0-9-] are dropped, not replaced."""
        expectations = {
            "Test!!!Effect": "testeffect",
            "cool@effect#1": "cooleffect1",
            # Spaces/underscores become dashes before the removal pass.
            "Test Effect!": "test-effect",
        }
        for raw, expected in expectations.items():
            assert normalize_name(raw) == expected

    def test_normalize_collapses_dashes(self) -> None:
        """Runs of dashes collapse to a single dash."""
        assert normalize_name("test--effect") == "test-effect"
        assert normalize_name("test___effect") == "test-effect"

    def test_normalize_strips_edge_dashes(self) -> None:
        """Dashes at either end of the result are trimmed."""
        assert normalize_name("-test-effect-") == "test-effect"

    def test_normalize_empty_returns_unnamed(self) -> None:
        """Inputs that normalize to nothing fall back to 'unnamed'."""
        for degenerate in ("", "---", "!!!"):
            assert normalize_name(degenerate) == "unnamed"
|
||||
|
||||
|
||||
class TestFriendlyNameParsing:
    """Tests for friendly name parsing."""

    def test_parse_base_name_only(self) -> None:
        """Without a space there is no version component."""
        assert parse_friendly_name("my-effect") == ("my-effect", None)

    def test_parse_with_version(self) -> None:
        """The first space separates base name from version id."""
        assert parse_friendly_name("my-effect 01hw3x9k") == ("my-effect", "01hw3x9k")

    def test_parse_strips_whitespace(self) -> None:
        """Surrounding whitespace is ignored before splitting."""
        assert parse_friendly_name("  my-effect  ") == ("my-effect", None)
|
||||
|
||||
|
||||
class TestFriendlyNameFormatting:
    """Tests for friendly name formatting."""

    def test_format_friendly_name(self) -> None:
        """Base and version are joined with a single space."""
        formatted = format_friendly_name("my-effect", "01hw3x9k")
        assert formatted == "my-effect 01hw3x9k"

    def test_format_l2_name(self) -> None:
        """The L2 form prefixes the actor id."""
        formatted = format_l2_name("@alice@example.com", "my-effect", "01hw3x9k")
        assert formatted == "@alice@example.com my-effect 01hw3x9k"
|
||||
|
||||
|
||||
class TestDatabaseSchemaExists:
    """Tests that verify database schema includes friendly_names table.

    Source-inspection tests: database.py is read as text and scanned for
    the expected schema fragments.  The path build + read was repeated in
    every method; it is factored into one helper.
    """

    @staticmethod
    def _db_source() -> str:
        """Return the text of database.py, one directory above tests/."""
        return (Path(__file__).parent.parent / "database.py").read_text()

    def test_schema_has_friendly_names_table(self) -> None:
        """Database schema should include friendly_names table."""
        assert "CREATE TABLE IF NOT EXISTS friendly_names" in self._db_source()

    def test_schema_has_required_columns(self) -> None:
        """Friendly names table should have required columns."""
        content = self._db_source()
        for column in ("actor_id", "base_name", "version_id",
                       "item_type", "display_name"):
            assert column in content

    def test_schema_has_unique_constraints(self) -> None:
        """Friendly names table should have unique constraints."""
        content = self._db_source()
        # One name per (actor, base, version) and one per (actor, cid).
        assert "UNIQUE(actor_id, base_name, version_id)" in content
        assert "UNIQUE(actor_id, cid)" in content
|
||||
|
||||
|
||||
class TestDatabaseFunctionsExist:
    """Tests that verify database functions exist.

    Five methods previously repeated the identical path build, read and
    `async def X(` assertion; the shared logic now lives in two helpers.
    """

    @staticmethod
    def _db_source() -> str:
        """Return the text of database.py, one directory above tests/."""
        return (Path(__file__).parent.parent / "database.py").read_text()

    def _assert_has_async_def(self, name: str) -> None:
        """Assert that database.py defines ``async def <name>(``."""
        assert f"async def {name}(" in self._db_source()

    def test_create_friendly_name_exists(self) -> None:
        """create_friendly_name function should exist."""
        self._assert_has_async_def("create_friendly_name")

    def test_get_friendly_name_by_cid_exists(self) -> None:
        """get_friendly_name_by_cid function should exist."""
        self._assert_has_async_def("get_friendly_name_by_cid")

    def test_resolve_friendly_name_exists(self) -> None:
        """resolve_friendly_name function should exist."""
        self._assert_has_async_def("resolve_friendly_name")

    def test_list_friendly_names_exists(self) -> None:
        """list_friendly_names function should exist."""
        self._assert_has_async_def("list_friendly_names")

    def test_delete_friendly_name_exists(self) -> None:
        """delete_friendly_name function should exist."""
        self._assert_has_async_def("delete_friendly_name")
|
||||
|
||||
|
||||
class TestNamingServiceModuleExists:
    """Tests that verify naming service module structure.

    Every method previously rebuilt the same module path and re-read the
    file; the shared logic is factored into a class attribute and helper.
    """

    # Location of the naming service module relative to this test file.
    _MODULE = Path(__file__).parent.parent / "app" / "services" / "naming_service.py"

    def _source(self) -> str:
        """Return the naming service module source text."""
        return self._MODULE.read_text()

    def test_module_file_exists(self) -> None:
        """Naming service module file should exist."""
        assert self._MODULE.exists()

    def test_module_has_normalize_name(self) -> None:
        """Module should have normalize_name function."""
        assert "def normalize_name(" in self._source()

    def test_module_has_generate_version_id(self) -> None:
        """Module should have generate_version_id function."""
        assert "def generate_version_id(" in self._source()

    def test_module_has_naming_service_class(self) -> None:
        """Module should have NamingService class."""
        assert "class NamingService:" in self._source()

    def test_naming_service_has_assign_name(self) -> None:
        """NamingService should have assign_name method."""
        assert "async def assign_name(" in self._source()

    def test_naming_service_has_resolve(self) -> None:
        """NamingService should have resolve method."""
        assert "async def resolve(" in self._source()

    def test_naming_service_has_get_by_cid(self) -> None:
        """NamingService should have get_by_cid method."""
        assert "async def get_by_cid(" in self._source()
|
||||
|
||||
|
||||
class TestVersionIdProperties:
    """Tests for version ID format properties (using actual function).

    All three methods read the naming service source; the repeated path
    build + read is factored into one helper.
    """

    @staticmethod
    def _source() -> str:
        """Return the naming service module source text."""
        path = Path(__file__).parent.parent / "app" / "services" / "naming_service.py"
        return path.read_text()

    def test_version_id_format(self) -> None:
        """Version ID should use base32-crockford alphabet."""
        # The service must define the same lowercase Crockford alphabet.
        assert 'CROCKFORD_ALPHABET = "0123456789abcdefghjkmnpqrstvwxyz"' in self._source()

    def test_version_id_uses_hmac(self) -> None:
        """Version ID generation should use HMAC for server verification."""
        assert "hmac.new(" in self._source()

    def test_version_id_uses_timestamp(self) -> None:
        """Version ID generation should be timestamp-based."""
        assert "time.time()" in self._source()
|
||||
150
l1/tests/test_recipe_visibility.py
Normal file
150
l1/tests/test_recipe_visibility.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Tests for recipe visibility in web UI.
|
||||
|
||||
Bug found 2026-01-12: Recipes not showing in list even after upload.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestRecipeListingFlow:
    """Tests for recipe listing data flow.

    Source-inspection tests over the implementation files.  Files are
    located relative to this test file (repo root is two levels up — the
    same convention test_naming_service.py uses) instead of the original
    hardcoded /home/giles/art/art-celery absolute paths, which only
    worked on one machine.
    """

    # Repo root: <root>/tests/<this file> -> parent.parent == <root>.
    # assumes this file lives directly under <root>/tests — TODO confirm.
    _ROOT = Path(__file__).parent.parent

    def _read(self, relpath: str) -> str:
        """Read a repo source file as text."""
        return (self._ROOT / relpath).read_text()

    def test_cache_manager_has_list_by_type(self) -> None:
        """L1CacheManager should have list_by_type method."""
        content = self._read('cache_manager.py')
        assert 'def list_by_type' in content, \
            "L1CacheManager should have list_by_type method"

    def test_list_by_type_returns_node_id(self) -> None:
        """list_by_type should return entry.node_id values (IPFS CID)."""
        content = self._read('cache_manager.py')
        assert 'cids.append(entry.node_id)' in content, \
            "list_by_type should append entry.node_id (IPFS CID) to results"

    def test_recipe_service_uses_database_items(self) -> None:
        """Recipe service should use database.get_user_items for listing."""
        content = self._read('app/services/recipe_service.py')
        assert 'get_user_items' in content, \
            "Recipe service should use database.get_user_items for listing"

    def test_recipe_upload_uses_recipe_node_type(self) -> None:
        """Recipe upload should store with node_type='recipe'."""
        content = self._read('app/services/recipe_service.py')
        assert 'node_type="recipe"' in content, \
            "Recipe upload should use node_type='recipe'"

    def test_get_by_cid_uses_find_by_cid(self) -> None:
        """get_by_cid should use cache.find_by_cid to locate entries."""
        content = self._read('cache_manager.py')
        assert 'find_by_cid(cid)' in content, \
            "get_by_cid should use find_by_cid to locate entries"

    def test_no_duplicate_get_by_cid_methods(self) -> None:
        """
        Regression test: There should only be ONE get_by_cid method.

        Bug: Two get_by_cid methods existed, the second shadowed the first,
        breaking recipe lookup because the comprehensive method was hidden.
        """
        content = self._read('cache_manager.py')
        count = content.count('def get_by_cid')
        assert count == 1, \
            f"Should have exactly 1 get_by_cid method, found {count}"
|
||||
|
||||
|
||||
class TestRecipeFilterLogic:
    """Tests for recipe filtering logic via database.

    The recipe service is located relative to this test file (repo root
    two levels up) instead of the original hardcoded absolute path,
    which only worked on one machine.
    """

    # assumes this file lives directly under <root>/tests — TODO confirm.
    _SERVICE = Path(__file__).parent.parent / 'app' / 'services' / 'recipe_service.py'

    def test_recipes_filtered_by_actor_id(self) -> None:
        """list_recipes should filter by actor_id parameter."""
        content = self._SERVICE.read_text()
        assert 'actor_id' in content, \
            "list_recipes should accept actor_id parameter"

    def test_recipes_uses_item_type_filter(self) -> None:
        """list_recipes should filter by item_type='recipe'."""
        content = self._SERVICE.read_text()
        assert 'item_type="recipe"' in content, \
            "Recipe listing should filter by item_type='recipe'"
|
||||
|
||||
|
||||
class TestCacheEntryHasCid:
    """Tests for cache entry cid field."""

    def test_artdag_cache_entry_has_cid(self) -> None:
        """artdag CacheEntry should have cid field."""
        import dataclasses
        from artdag import CacheEntry

        field_names = {field.name for field in dataclasses.fields(CacheEntry)}
        assert 'cid' in field_names, \
            "CacheEntry should have cid field"

    def test_artdag_cache_put_computes_cid(self) -> None:
        """artdag Cache.put should compute and store cid in metadata."""
        import inspect
        from artdag import Cache

        put_source = inspect.getsource(Cache.put)
        assert '"cid":' in put_source or "'cid':" in put_source, \
            "Cache.put should store cid in metadata"
|
||||
|
||||
|
||||
class TestListByTypeReturnsEntries:
    """Tests for list_by_type returning cached entries.

    cache_manager.py is located relative to this test file (repo root is
    two levels up) rather than via the original hardcoded absolute path,
    which only worked on one machine.
    """

    # assumes this file lives directly under <root>/tests — TODO confirm.
    _CACHE_MANAGER = Path(__file__).parent.parent / 'cache_manager.py'

    def test_list_by_type_iterates_cache_entries(self) -> None:
        """list_by_type should iterate self.cache.list_entries()."""
        content = self._CACHE_MANAGER.read_text()
        assert 'self.cache.list_entries()' in content, \
            "list_by_type should iterate cache entries"

    def test_list_by_type_filters_by_node_type(self) -> None:
        """list_by_type should filter entries by node_type."""
        content = self._CACHE_MANAGER.read_text()
        assert 'entry.node_type == node_type' in content, \
            "list_by_type should filter by node_type"

    def test_list_by_type_returns_node_id(self) -> None:
        """list_by_type should return entry.node_id (IPFS CID)."""
        content = self._CACHE_MANAGER.read_text()
        assert 'cids.append(entry.node_id)' in content, \
            "list_by_type should append entry.node_id (IPFS CID)"

    def test_artdag_cache_list_entries_returns_all(self) -> None:
        """artdag Cache.list_entries should return all entries."""
        import inspect
        from artdag import Cache

        source = inspect.getsource(Cache.list_entries)
        # Implementation stores entries in the _entries dict.
        assert '_entries' in source, \
            "list_entries should access _entries dict"
|
||||
111
l1/tests/test_run_artifacts.py
Normal file
111
l1/tests/test_run_artifacts.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Tests for run artifacts data structure and template variables.
|
||||
|
||||
Bug found 2026-01-12: runs/detail.html template expects artifact.cid
|
||||
but the route provides artifact.hash, causing UndefinedError.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestCacheNotFoundTemplate:
    """Tests for cache/not_found.html template.

    Files are located relative to this test file (repo root two levels
    up) instead of the original hardcoded machine-specific absolute
    paths.
    """

    # assumes this file lives directly under <root>/tests — TODO confirm.
    _ROOT = Path(__file__).parent.parent

    def test_template_uses_cid_not_content_hash(self) -> None:
        """
        Regression test: not_found.html must use 'cid' variable.

        Bug: Template used 'content_hash' but route passes 'cid'.
        Fix: Changed template to use 'cid'.
        """
        content = (self._ROOT / 'app/templates/cache/not_found.html').read_text()
        assert 'cid' in content, "Template should use 'cid' variable"
        assert 'content_hash' not in content, \
            "Template should not use 'content_hash' (route passes 'cid')"

    def test_route_passes_cid_to_template(self) -> None:
        """Verify route passes 'cid' variable to not_found template."""
        content = (self._ROOT / 'app/routers/cache.py').read_text()
        # The render call for not_found.html must forward the cid.
        assert 'cid=cid' in content, \
            "Route should pass cid=cid to not_found.html template"
|
||||
|
||||
|
||||
class TestInputPreviewsDataStructure:
    """Tests for input_previews dict keys matching template expectations.

    Files are located relative to this test file (repo root two levels
    up) instead of the original hardcoded absolute paths.
    """

    # assumes this file lives directly under <root>/tests — TODO confirm.
    _ROOT = Path(__file__).parent.parent

    def test_run_card_template_expects_inp_cid(self) -> None:
        """Run card template uses inp.cid for input previews."""
        content = (self._ROOT / 'app/templates/runs/_run_card.html').read_text()
        assert 'inp.cid' in content, \
            "Run card template should use inp.cid for input previews"

    def test_input_previews_use_cid_key(self) -> None:
        """
        Regression test: input_previews must use 'cid' key not 'hash'.

        Bug: Router created input_previews with 'hash' key but template expected 'cid'.
        """
        content = (self._ROOT / 'app/routers/runs.py').read_text()
        # Should have: "cid": input_hash, not "hash": input_hash
        assert '"cid": input_hash' in content, \
            "input_previews should use 'cid' key"
        assert '"hash": input_hash' not in content, \
            "input_previews should not use 'hash' key (template expects 'cid')"
|
||||
|
||||
|
||||
class TestArtifactDataStructure:
    """Tests for artifact dict keys matching template expectations.

    Files are located relative to this test file (repo root two levels
    up) instead of the original hardcoded absolute paths.
    """

    # assumes this file lives directly under <root>/tests — TODO confirm.
    _ROOT = Path(__file__).parent.parent

    def test_template_expects_cid_key(self) -> None:
        """Template uses artifact.cid - verify this expectation."""
        content = (self._ROOT / 'app/templates/runs/detail.html').read_text()
        # Template uses artifact.cid in multiple places.
        assert 'artifact.cid' in content, "Template should reference artifact.cid"

    def test_run_service_artifacts_have_cid_key(self) -> None:
        """
        Regression test: get_run_artifacts must return dicts with 'cid' key.

        Bug: Service returned artifacts with 'hash' key but template expected 'cid'.
        Fix: Changed service to use 'cid' key for consistency.
        """
        content = (self._ROOT / 'app/services/run_service.py').read_text()
        # The artifact-info code should emit "cid": cid, not "hash": cid.
        assert '"cid": cid' in content or "'cid': cid" in content, \
            "get_run_artifacts should return artifacts with 'cid' key, not 'hash'"

    def test_router_artifacts_have_cid_key(self) -> None:
        """
        Regression test: inline artifact creation in router must use 'cid' key.

        Bug: Router created artifacts with 'hash' key but template expected 'cid'.
        """
        content = (self._ROOT / 'app/routers/runs.py').read_text()
        # artifacts.append must use the 'cid' key and never 'hash'.
        hash_pattern_count = content.count('"hash": output_cid')
        cid_pattern_count = content.count('"cid": output_cid')
        assert cid_pattern_count > 0, \
            "Router should create artifacts with 'cid' key"
        assert hash_pattern_count == 0, \
            "Router should not use 'hash' key for artifacts (template expects 'cid')"
|
||||
305
l1/tests/test_xector.py
Normal file
305
l1/tests/test_xector.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Tests for xector primitives - parallel array operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from sexp_effects.primitive_libs.xector import (
|
||||
Xector,
|
||||
xector_red, xector_green, xector_blue, xector_rgb,
|
||||
xector_x_coords, xector_y_coords, xector_x_norm, xector_y_norm,
|
||||
xector_dist_from_center,
|
||||
alpha_add, alpha_sub, alpha_mul, alpha_div, alpha_sqrt, alpha_clamp,
|
||||
alpha_sin, alpha_cos, alpha_sq,
|
||||
alpha_lt, alpha_gt, alpha_eq,
|
||||
beta_add, beta_mul, beta_min, beta_max, beta_mean, beta_count,
|
||||
xector_where, xector_fill, xector_zeros, xector_ones,
|
||||
is_xector,
|
||||
)
|
||||
|
||||
|
||||
class TestXectorBasics:
    """Basic construction and operator behaviour of the Xector class."""

    def test_create_from_list(self):
        """A Python list becomes a xector of the same length."""
        vec = Xector([1, 2, 3])
        assert len(vec) == 3
        assert is_xector(vec)

    def test_create_from_numpy(self):
        """A numpy array round-trips (stored as float32)."""
        source = np.array([1.0, 2.0, 3.0])
        vec = Xector(source)
        assert len(vec) == 3
        np.testing.assert_array_equal(vec.to_numpy(), source.astype(np.float32))

    def test_implicit_add(self):
        """`+` between two xectors works element-wise."""
        total = Xector([1, 2, 3]) + Xector([4, 5, 6])
        np.testing.assert_array_equal(total.to_numpy(), [5, 7, 9])

    def test_implicit_mul(self):
        """`*` between two xectors works element-wise."""
        product = Xector([1, 2, 3]) * Xector([2, 2, 2])
        np.testing.assert_array_equal(product.to_numpy(), [2, 4, 6])

    def test_scalar_broadcast(self):
        """Adding a scalar broadcasts it over every element."""
        shifted = Xector([1, 2, 3]) + 10
        np.testing.assert_array_equal(shifted.to_numpy(), [11, 12, 13])

    def test_scalar_broadcast_rmul(self):
        """Right-multiplication by a scalar broadcasts too."""
        scaled = 2 * Xector([1, 2, 3])
        np.testing.assert_array_equal(scaled.to_numpy(), [2, 4, 6])
|
||||
|
||||
|
||||
class TestAlphaOperations:
    """Test α (element-wise) operations."""

    def test_alpha_add(self):
        """alpha_add sums two xectors element-wise."""
        out = alpha_add(Xector([1, 2, 3]), Xector([4, 5, 6]))
        np.testing.assert_array_equal(out.to_numpy(), [5, 7, 9])

    def test_alpha_add_multi(self):
        """alpha_add accepts more than two operands."""
        out = alpha_add(Xector([1, 2, 3]), Xector([1, 1, 1]), Xector([10, 10, 10]))
        np.testing.assert_array_equal(out.to_numpy(), [12, 13, 14])

    def test_alpha_mul_scalar(self):
        """alpha_mul broadcasts a scalar factor."""
        out = alpha_mul(Xector([1, 2, 3]), 2)
        np.testing.assert_array_equal(out.to_numpy(), [2, 4, 6])

    def test_alpha_sqrt(self):
        """alpha_sqrt takes element-wise square roots."""
        out = alpha_sqrt(Xector([1, 4, 9, 16]))
        np.testing.assert_array_equal(out.to_numpy(), [1, 2, 3, 4])

    def test_alpha_clamp(self):
        """alpha_clamp limits every element to the [lo, hi] range."""
        out = alpha_clamp(Xector([-5, 0, 5, 10, 15]), 0, 10)
        np.testing.assert_array_equal(out.to_numpy(), [0, 0, 5, 10, 10])

    def test_alpha_sin_cos(self):
        """alpha_sin / alpha_cos evaluate trig element-wise."""
        angles = Xector([0, np.pi/2, np.pi])
        np.testing.assert_array_almost_equal(
            alpha_sin(angles).to_numpy(), [0, 1, 0], decimal=5)
        np.testing.assert_array_almost_equal(
            alpha_cos(angles).to_numpy(), [1, 0, -1], decimal=5)

    def test_alpha_sq(self):
        """alpha_sq squares every element."""
        out = alpha_sq(Xector([1, 2, 3, 4]))
        np.testing.assert_array_equal(out.to_numpy(), [1, 4, 9, 16])

    def test_alpha_comparison(self):
        """alpha_lt / alpha_gt / alpha_eq compare element-wise."""
        left = Xector([1, 2, 3, 4])
        right = Xector([2, 2, 2, 2])
        np.testing.assert_array_equal(
            alpha_lt(left, right).to_numpy(), [True, False, False, False])
        np.testing.assert_array_equal(
            alpha_gt(left, right).to_numpy(), [False, False, True, True])
        np.testing.assert_array_equal(
            alpha_eq(left, right).to_numpy(), [False, True, False, False])
|
||||
|
||||
|
||||
class TestBetaOperations:
    """Test β (reduction) operations."""

    def test_beta_add(self):
        """beta_add reduces by summation."""
        assert beta_add(Xector([1, 2, 3, 4])) == 10

    def test_beta_mul(self):
        """beta_mul reduces by product."""
        assert beta_mul(Xector([1, 2, 3, 4])) == 24

    def test_beta_min_max(self):
        """beta_min / beta_max pick out the extremes."""
        digits = Xector([3, 1, 4, 1, 5, 9, 2, 6])
        assert beta_min(digits) == 1
        assert beta_max(digits) == 9

    def test_beta_mean(self):
        """beta_mean averages the elements."""
        assert beta_mean(Xector([1, 2, 3, 4])) == 2.5

    def test_beta_count(self):
        """beta_count returns the element count."""
        assert beta_count(Xector([1, 2, 3, 4, 5])) == 5
|
||||
|
||||
|
||||
class TestFrameConversion:
    """Test frame/xector conversion."""

    def test_extract_channels(self):
        """Channel extractors flatten a frame into per-channel xectors."""
        rgb_frame = np.array([
            [[255, 0, 0], [0, 255, 0]],
            [[0, 0, 255], [128, 128, 128]]
        ], dtype=np.uint8)

        red = xector_red(rgb_frame)
        green = xector_green(rgb_frame)
        blue = xector_blue(rgb_frame)

        assert len(red) == 4
        np.testing.assert_array_equal(red.to_numpy(), [255, 0, 0, 128])
        np.testing.assert_array_equal(green.to_numpy(), [0, 255, 0, 128])
        np.testing.assert_array_equal(blue.to_numpy(), [0, 0, 255, 128])

    def test_rgb_roundtrip(self):
        """Splitting into channels and recombining reproduces the frame."""
        rgb_frame = np.array([
            [[100, 150, 200], [50, 75, 100]],
            [[200, 100, 50], [25, 50, 75]]
        ], dtype=np.uint8)

        rebuilt = xector_rgb(xector_red(rgb_frame),
                             xector_green(rgb_frame),
                             xector_blue(rgb_frame))
        np.testing.assert_array_equal(rebuilt, rgb_frame)

    def test_modify_and_reconstruct(self):
        """A modified channel shows up in the rebuilt frame; others don't."""
        gray_frame = np.array([
            [[100, 100, 100], [100, 100, 100]],
            [[100, 100, 100], [100, 100, 100]]
        ], dtype=np.uint8)

        red = xector_red(gray_frame)
        green = xector_green(gray_frame)
        blue = xector_blue(gray_frame)

        # Double only the red channel, then rebuild the frame.
        rebuilt = xector_rgb(red * 2, green, blue)

        assert rebuilt[0, 0, 0] == 200
        assert rebuilt[0, 0, 1] == 100
        assert rebuilt[0, 0, 2] == 100
|
||||
|
||||
|
||||
class TestCoordinates:
    """Test coordinate generation."""

    def test_x_coords(self):
        """x coordinates repeat 0..width-1 for every row."""
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        np.testing.assert_array_equal(
            xector_x_coords(frame).to_numpy(), [0, 1, 2, 0, 1, 2])

    def test_y_coords(self):
        """y coordinates give the row index of each pixel."""
        frame = np.zeros((2, 3, 3), dtype=np.uint8)  # 2 rows, 3 cols
        np.testing.assert_array_equal(
            xector_y_coords(frame).to_numpy(), [0, 0, 0, 1, 1, 1])

    def test_normalized_coords(self):
        """Normalised coordinates span 0..1 across width and height."""
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        xs = xector_x_norm(frame).to_numpy()
        ys = xector_y_norm(frame).to_numpy()

        # x runs 0 -> 1 left to right within a row...
        assert xs[0] == 0
        assert xs[2] == 1
        # ...and y runs 0 -> 1 top to bottom.
        assert ys[0] == 0
        assert ys[3] == 1
|
||||
|
||||
|
||||
class TestConditional:
    """Test conditional operations."""

    def test_where(self):
        """xector_where selects per element based on the condition."""
        mask = Xector([True, False, True, False])
        picked = xector_where(mask, Xector([1, 1, 1, 1]), Xector([0, 0, 0, 0]))
        np.testing.assert_array_equal(picked.to_numpy(), [1, 0, 1, 0])

    def test_where_with_comparison(self):
        """xector_where composes with alpha comparisons and scalars."""
        values = Xector([1, 5, 3, 7])
        # Elements > 4 become 255, everything else becomes 0.
        picked = xector_where(alpha_gt(values, 4), 255, 0)
        np.testing.assert_array_equal(picked.to_numpy(), [0, 255, 0, 255])

    def test_fill(self):
        """xector_fill makes a constant xector sized like the frame."""
        frame = np.zeros((2, 3, 3), dtype=np.uint8)
        filled = xector_fill(42, frame)
        assert len(filled) == 6
        assert all(v == 42 for v in filled.to_numpy())

    def test_zeros_ones(self):
        """xector_zeros / xector_ones build constant 0/1 xectors."""
        frame = np.zeros((2, 2, 3), dtype=np.uint8)
        assert all(v == 0 for v in xector_zeros(frame).to_numpy())
        assert all(v == 1 for v in xector_ones(frame).to_numpy())
|
||||
|
||||
|
||||
class TestInterpreterIntegration:
    """Test xector operations through the interpreter."""

    def test_xector_vignette_effect(self):
        """The xector vignette leaves the centre brighter than the corners."""
        from sexp_effects.interpreter import Interpreter

        interp = Interpreter(minimal_primitives=True)
        interp.load_effect('sexp_effects/effects/xector_vignette.sexp')

        # An all-white input frame makes the darkening easy to measure.
        white = np.full((100, 100, 3), 255, dtype=np.uint8)
        result, state = interp.run_effect('xector_vignette', white, {'strength': 0.5}, {})

        centre_pixel = result[50, 50]
        corner_pixel = result[0, 0]
        assert centre_pixel.mean() > corner_pixel.mean(), "Center should be brighter than corners"
        assert corner_pixel.mean() < 255, "Corners should be darkened"

    def test_implicit_elementwise(self):
        """Test that regular + works element-wise on xectors."""
        from sexp_effects.interpreter import Interpreter
        from sexp_effects.parser import parse
        from sexp_effects.primitive_libs.xector import PRIMITIVES

        interp = Interpreter(minimal_primitives=True)
        # Register the full xector primitive set in the global environment.
        for prim_name, prim_fn in PRIMITIVES.items():
            interp.global_env.set(prim_name, prim_fn)

        # Bind a constant 2x2 frame, then evaluate a xector expression on it.
        interp.global_env.set('frame', np.full((2, 2, 3), 100, dtype=np.uint8))
        outcome = interp.eval(parse('(+ (red frame) 10)'))

        # The result is a xector with every element shifted by 10.
        assert is_xector(outcome)
        np.testing.assert_array_equal(outcome.to_numpy(), [110, 110, 110, 110])
|
||||
|
||||
|
||||
# Allow running this test module directly (verbose pytest on just this file).
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user