Files
celery/cache_manager.py
gilesb 43788108c0 Fix Celery workers to use Redis for shared cache index
The get_cache_manager() singleton wasn't initializing with Redis,
so workers couldn't see files uploaded via the API server.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 11:25:38 +00:00

785 lines
28 KiB
Python

# art-celery/cache_manager.py
"""
Cache management for Art DAG L1 server.
Integrates artdag's Cache, ActivityStore, and ActivityManager to provide:
- Content-addressed caching with both node_id and content_hash
- Activity tracking for runs (input/output/intermediate relationships)
- Deletion rules enforcement (shared items protected)
- L2 ActivityPub integration for "shared" status checks
- IPFS as durable backing store (local cache as hot storage)
- Redis-backed indexes for multi-worker consistency
"""
import hashlib
import json
import logging
import os
import shutil
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING
import requests
if TYPE_CHECKING:
import redis
from artdag import Cache, CacheEntry, DAG, Node, NodeType
from artdag.activities import Activity, ActivityStore, ActivityManager, make_is_shared_fn
import ipfs_client
logger = logging.getLogger(__name__)
def file_hash(path: Path, algorithm: str = "sha3_256") -> str:
    """Return the hex digest of a file's contents.

    Symlinks are resolved first so the digest reflects the bytes of the
    target file, not the link itself. The file is read in 64 KiB chunks
    to keep memory flat for large assets.

    Args:
        path: File to hash.
        algorithm: Any algorithm name accepted by hashlib.new
            (default: SHA3-256).

    Returns:
        Lowercase hexadecimal digest string.
    """
    digest = hashlib.new(algorithm)
    target = path.resolve() if path.is_symlink() else path
    with open(target, "rb") as fh:
        while chunk := fh.read(65536):
            digest.update(chunk)
    return digest.hexdigest()
@dataclass
class CachedFile:
    """
    A cached file addressed by two independent identifiers.

    - node_id: identity of the computation that produced the file
      (used for DAG result caching)
    - content_hash: identity of the file bytes themselves
      (used for external references, IPFS, L2 publishing)

    For directly uploaded files the two identifiers coincide.
    """
    node_id: str        # computation identity
    content_hash: str   # byte-content identity
    path: Path          # location of the cached file on disk
    size_bytes: int     # file size at cache time
    node_type: str      # e.g. "upload", "source", "effect", "legacy"
    created_at: float   # unix timestamp of cache insertion

    @classmethod
    def from_cache_entry(cls, entry: CacheEntry) -> "CachedFile":
        """Adapt an artdag CacheEntry into this unified view."""
        # Field order mirrors the dataclass declaration above.
        return cls(
            entry.node_id,
            entry.content_hash,
            entry.output_path,
            entry.size_bytes,
            entry.node_type,
            entry.created_at,
        )
class L2SharedChecker:
    """
    Asks the L2 ActivityPub server whether a piece of content has been
    published ("shared").

    Lookups are memoized per content_hash with a TTL so repeated checks
    (e.g. during deletion-rule evaluation) do not hammer the L2 API.
    """

    def __init__(self, l2_server: str, cache_ttl: int = 300):
        self.l2_server = l2_server
        self.cache_ttl = cache_ttl
        # content_hash -> (is_shared, unix timestamp of the lookup)
        self._cache: Dict[str, tuple[bool, float]] = {}

    def is_shared(self, content_hash: str) -> bool:
        """Check if content_hash has been published to L2."""
        import time
        now = time.time()
        # Serve from the memo table while the entry is still fresh.
        cached = self._cache.get(content_hash)
        if cached is not None:
            is_shared, cached_at = cached
            if now - cached_at < self.cache_ttl:
                logger.debug(f"L2 check (cached): {content_hash[:16]}... = {is_shared}")
                return is_shared
        # Miss or stale: ask L2 directly.
        try:
            url = f"{self.l2_server}/assets/by-hash/{content_hash}"
            logger.info(f"L2 check: GET {url}")
            resp = requests.get(url, timeout=5)
            logger.info(f"L2 check response: {resp.status_code}")
            is_shared = resp.status_code == 200
        except Exception as e:
            logger.warning(f"Failed to check L2 for {content_hash}: {e}")
            # On error, assume IS shared (safer - prevents accidental deletion)
            is_shared = True
        self._cache[content_hash] = (is_shared, now)
        return is_shared

    def invalidate(self, content_hash: str):
        """Invalidate cache for a content_hash (call after publishing)."""
        self._cache.pop(content_hash, None)

    def mark_shared(self, content_hash: str):
        """Mark as shared without querying (call after successful publish)."""
        import time
        self._cache[content_hash] = (True, time.time())
class L1CacheManager:
    """
    Unified cache manager for Art DAG L1 server.
    Combines:
    - artdag Cache for file storage
    - ActivityStore for run tracking
    - ActivityManager for deletion rules
    - L2 integration for shared status
    Provides both node_id and content_hash based access.

    Index strategy: two maps (content_hash -> node_id and
    content_hash -> IPFS CID) are mirrored in up to three places:
    Redis (shared across workers, preferred for reads), an
    in-memory dict (local fallback), and a JSON file on disk
    (restart backup). Writes update every available copy.
    """
    def __init__(
        self,
        cache_dir: Path | str,
        l2_server: str = "http://localhost:8200",
        redis_client: Optional["redis.Redis"] = None,
    ):
        """
        Args:
            cache_dir: Root directory for all local cache state.
            l2_server: Base URL of the L2 ActivityPub server.
            redis_client: Optional Redis connection for the shared
                cache indexes; without it, indexes are process-local
                (JSON + in-memory only).
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Redis for shared state between workers
        self._redis = redis_client
        self._redis_content_key = "artdag:content_index"
        self._redis_ipfs_key = "artdag:ipfs_index"
        # artdag components
        self.cache = Cache(self.cache_dir / "nodes")
        self.activity_store = ActivityStore(self.cache_dir / "activities")
        # L2 shared checker
        self.l2_checker = L2SharedChecker(l2_server)
        # Activity manager with L2-based is_shared
        self.activity_manager = ActivityManager(
            cache=self.cache,
            activity_store=self.activity_store,
            is_shared_fn=self._is_shared_by_node_id,
        )
        # Content hash index: content_hash -> node_id
        # Uses Redis if available, falls back to in-memory dict
        self._content_index: Dict[str, str] = {}
        self._load_content_index()
        # IPFS CID index: content_hash -> ipfs_cid
        self._ipfs_cids: Dict[str, str] = {}
        self._load_ipfs_index()
        # Legacy files directory (for files uploaded directly by content_hash)
        self.legacy_dir = self.cache_dir / "legacy"
        self.legacy_dir.mkdir(parents=True, exist_ok=True)

    def _index_path(self) -> Path:
        # JSON backup location for the content_hash -> node_id index.
        return self.cache_dir / "content_index.json"

    def _load_content_index(self):
        """Load content_hash -> node_id index from Redis or JSON file."""
        # If Redis available and has data, use it
        if self._redis:
            try:
                redis_data = self._redis.hgetall(self._redis_content_key)
                if redis_data:
                    # redis-py returns bytes unless decode_responses is set;
                    # normalize both keys and values to str.
                    self._content_index = {
                        k.decode() if isinstance(k, bytes) else k:
                        v.decode() if isinstance(v, bytes) else v
                        for k, v in redis_data.items()
                    }
                    logger.info(f"Loaded {len(self._content_index)} content index entries from Redis")
                    return
            except Exception as e:
                logger.warning(f"Failed to load content index from Redis: {e}")
        # Fall back to JSON file
        if self._index_path().exists():
            try:
                with open(self._index_path()) as f:
                    self._content_index = json.load(f)
            except (json.JSONDecodeError, IOError) as e:
                logger.warning(f"Failed to load content index: {e}")
                self._content_index = {}
        # Also index from existing cache entries
        # (rebuilds mappings for files the JSON backup missed).
        for entry in self.cache.list_entries():
            if entry.content_hash:
                self._content_index[entry.content_hash] = entry.node_id
        # Migrate to Redis if available
        # (only reached when Redis was empty or unreadable above).
        if self._redis and self._content_index:
            try:
                self._redis.hset(self._redis_content_key, mapping=self._content_index)
                logger.info(f"Migrated {len(self._content_index)} content index entries to Redis")
            except Exception as e:
                logger.warning(f"Failed to migrate content index to Redis: {e}")

    def _save_content_index(self):
        """Save content_hash -> node_id index to Redis and JSON file."""
        # Always save to JSON as backup
        # NOTE(review): rewrites the whole file on every mutation and is
        # not locked against concurrent workers — Redis is the shared
        # source of truth; this file is best-effort. Confirm acceptable.
        with open(self._index_path(), "w") as f:
            json.dump(self._content_index, f, indent=2)

    def _set_content_index(self, content_hash: str, node_id: str):
        """Set a single content index entry (Redis + in-memory)."""
        self._content_index[content_hash] = node_id
        if self._redis:
            try:
                self._redis.hset(self._redis_content_key, content_hash, node_id)
            except Exception as e:
                logger.warning(f"Failed to set content index in Redis: {e}")
        self._save_content_index()

    def _get_content_index(self, content_hash: str) -> Optional[str]:
        """Get a content index entry (Redis-first, then in-memory)."""
        if self._redis:
            try:
                val = self._redis.hget(self._redis_content_key, content_hash)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Failed to get content index from Redis: {e}")
        # Redis miss/failure falls through to the local copy.
        return self._content_index.get(content_hash)

    def _del_content_index(self, content_hash: str):
        """Delete a content index entry."""
        if content_hash in self._content_index:
            del self._content_index[content_hash]
        if self._redis:
            try:
                self._redis.hdel(self._redis_content_key, content_hash)
            except Exception as e:
                logger.warning(f"Failed to delete content index from Redis: {e}")
        self._save_content_index()

    def _ipfs_index_path(self) -> Path:
        # JSON backup location for the content_hash -> IPFS CID index.
        return self.cache_dir / "ipfs_index.json"

    def _load_ipfs_index(self):
        """Load content_hash -> ipfs_cid index from Redis or JSON file."""
        # If Redis available and has data, use it
        if self._redis:
            try:
                redis_data = self._redis.hgetall(self._redis_ipfs_key)
                if redis_data:
                    self._ipfs_cids = {
                        k.decode() if isinstance(k, bytes) else k:
                        v.decode() if isinstance(v, bytes) else v
                        for k, v in redis_data.items()
                    }
                    logger.info(f"Loaded {len(self._ipfs_cids)} IPFS index entries from Redis")
                    return
            except Exception as e:
                logger.warning(f"Failed to load IPFS index from Redis: {e}")
        # Fall back to JSON file
        # (unlike the content index, there is no cache-entry scan here:
        # CIDs only exist in the index or in Redis).
        if self._ipfs_index_path().exists():
            try:
                with open(self._ipfs_index_path()) as f:
                    self._ipfs_cids = json.load(f)
            except (json.JSONDecodeError, IOError) as e:
                logger.warning(f"Failed to load IPFS index: {e}")
                self._ipfs_cids = {}
        # Migrate to Redis if available
        if self._redis and self._ipfs_cids:
            try:
                self._redis.hset(self._redis_ipfs_key, mapping=self._ipfs_cids)
                logger.info(f"Migrated {len(self._ipfs_cids)} IPFS index entries to Redis")
            except Exception as e:
                logger.warning(f"Failed to migrate IPFS index to Redis: {e}")

    def _save_ipfs_index(self):
        """Save content_hash -> ipfs_cid index to JSON file (backup)."""
        with open(self._ipfs_index_path(), "w") as f:
            json.dump(self._ipfs_cids, f, indent=2)

    def _set_ipfs_index(self, content_hash: str, ipfs_cid: str):
        """Set a single IPFS index entry (Redis + in-memory)."""
        self._ipfs_cids[content_hash] = ipfs_cid
        if self._redis:
            try:
                self._redis.hset(self._redis_ipfs_key, content_hash, ipfs_cid)
            except Exception as e:
                logger.warning(f"Failed to set IPFS index in Redis: {e}")
        self._save_ipfs_index()

    def _get_ipfs_cid_from_index(self, content_hash: str) -> Optional[str]:
        """Get IPFS CID from index (Redis-first, then in-memory)."""
        if self._redis:
            try:
                val = self._redis.hget(self._redis_ipfs_key, content_hash)
                if val:
                    return val.decode() if isinstance(val, bytes) else val
            except Exception as e:
                logger.warning(f"Failed to get IPFS CID from Redis: {e}")
        return self._ipfs_cids.get(content_hash)

    def get_ipfs_cid(self, content_hash: str) -> Optional[str]:
        """Get IPFS CID for a content hash."""
        return self._get_ipfs_cid_from_index(content_hash)

    def _is_shared_by_node_id(self, content_hash: str) -> bool:
        """Check if a content_hash is shared via L2.

        NOTE(review): the name says node_id but the argument is used as a
        content_hash; this is wired into ActivityManager as is_shared_fn —
        confirm what identifier ActivityManager actually passes.
        """
        return self.l2_checker.is_shared(content_hash)

    def _load_meta(self, content_hash: str) -> dict:
        """Load metadata for a cached file.

        Returns an empty dict when no sidecar metadata file exists.
        """
        meta_path = self.cache_dir / f"{content_hash}.meta.json"
        if meta_path.exists():
            with open(meta_path) as f:
                return json.load(f)
        return {}

    def is_pinned(self, content_hash: str) -> tuple[bool, str]:
        """
        Check if a content_hash is pinned (non-deletable).
        Returns:
            (is_pinned, reason) tuple
        """
        meta = self._load_meta(content_hash)
        if meta.get("pinned"):
            # Default reason covers pins written before pin_reason existed.
            return True, meta.get("pin_reason", "published")
        return False, ""

    def _save_meta(self, content_hash: str, **updates) -> dict:
        """Save/update metadata for a cached file.

        Merges *updates* into the existing sidecar JSON and returns the
        merged dict.
        """
        meta = self._load_meta(content_hash)
        meta.update(updates)
        meta_path = self.cache_dir / f"{content_hash}.meta.json"
        with open(meta_path, "w") as f:
            json.dump(meta, f, indent=2)
        return meta

    def pin(self, content_hash: str, reason: str = "published") -> None:
        """Mark an item as pinned (non-deletable)."""
        self._save_meta(content_hash, pinned=True, pin_reason=reason)

    # ============ File Storage ============
    def put(
        self,
        source_path: Path,
        node_type: str = "upload",
        node_id: Optional[str] = None,
        execution_time: float = 0.0,
        move: bool = False,
    ) -> tuple[CachedFile, Optional[str]]:
        """
        Store a file in the cache and upload to IPFS.
        Args:
            source_path: Path to file to cache
            node_type: Type of node (e.g., "upload", "source", "effect")
            node_id: Optional node_id; if not provided, uses content_hash
            execution_time: How long the operation took
            move: If True, move instead of copy
        Returns:
            Tuple of (CachedFile with both node_id and content_hash, IPFS CID or None)
        """
        # Compute content hash first
        content_hash = file_hash(source_path)
        # Use content_hash as node_id if not provided
        # This is for legacy/uploaded files that don't have a DAG node
        if node_id is None:
            node_id = content_hash
        # Check if already cached (by node_id)
        existing = self.cache.get_entry(node_id)
        if existing and existing.output_path.exists():
            # Already cached - still try to get IPFS CID if we don't have it
            # (covers files cached before IPFS integration existed).
            ipfs_cid = self._get_ipfs_cid_from_index(content_hash)
            if not ipfs_cid:
                ipfs_cid = ipfs_client.add_file(existing.output_path)
                if ipfs_cid:
                    self._set_ipfs_index(content_hash, ipfs_cid)
            return CachedFile.from_cache_entry(existing), ipfs_cid
        # Store in local cache
        self.cache.put(
            node_id=node_id,
            source_path=source_path,
            node_type=node_type,
            execution_time=execution_time,
            move=move,
        )
        entry = self.cache.get_entry(node_id)
        # Update content index (Redis + local)
        self._set_content_index(entry.content_hash, node_id)
        # Upload to IPFS (async in background would be better, but sync for now)
        ipfs_cid = ipfs_client.add_file(entry.output_path)
        if ipfs_cid:
            self._set_ipfs_index(entry.content_hash, ipfs_cid)
            logger.info(f"Uploaded to IPFS: {entry.content_hash[:16]}... -> {ipfs_cid}")
        return CachedFile.from_cache_entry(entry), ipfs_cid

    def get_by_node_id(self, node_id: str) -> Optional[Path]:
        """Get cached file path by node_id."""
        return self.cache.get(node_id)

    def get_by_content_hash(self, content_hash: str) -> Optional[Path]:
        """Get cached file path by content_hash. Falls back to IPFS if not in local cache.

        Lookup chain: (1) content index, (2) node_id == content_hash direct
        lookup, (3) cache scan, (4) legacy flat file, (5) IPFS recovery.
        """
        # (1) Check index first (Redis then local)
        node_id = self._get_content_index(content_hash)
        if node_id:
            path = self.cache.get(node_id)
            if path and path.exists():
                logger.debug(f"  Found via index: {path}")
                return path
        # (2) For uploads, node_id == content_hash, so try direct lookup
        # This works even if cache index hasn't been reloaded
        path = self.cache.get(content_hash)
        logger.debug(f"  cache.get({content_hash[:16]}...) returned: {path}")
        if path and path.exists():
            # Backfill the index so the next lookup hits step (1).
            self._set_content_index(content_hash, content_hash)
            return path
        # (3) Scan cache entries (fallback for new structure)
        entry = self.cache.find_by_content_hash(content_hash)
        if entry and entry.output_path.exists():
            logger.debug(f"  Found via scan: {entry.output_path}")
            self._set_content_index(content_hash, entry.node_id)
            return entry.output_path
        # (4) Check legacy location (files stored directly as CACHE_DIR/{content_hash})
        legacy_path = self.cache_dir / content_hash
        if legacy_path.exists() and legacy_path.is_file():
            return legacy_path
        # (5) Try to recover from IPFS if we have a CID
        ipfs_cid = self._get_ipfs_cid_from_index(content_hash)
        if ipfs_cid:
            logger.info(f"Recovering from IPFS: {content_hash[:16]}... ({ipfs_cid})")
            recovery_path = self.legacy_dir / content_hash
            if ipfs_client.get_file(ipfs_cid, recovery_path):
                logger.info(f"Recovered from IPFS: {recovery_path}")
                return recovery_path
        return None

    def has_content(self, content_hash: str) -> bool:
        """Check if content exists in cache.

        Note: may trigger an IPFS recovery as a side effect (see
        get_by_content_hash).
        """
        return self.get_by_content_hash(content_hash) is not None

    def get_entry_by_content_hash(self, content_hash: str) -> Optional[CacheEntry]:
        """Get cache entry by content_hash."""
        node_id = self._get_content_index(content_hash)
        if node_id:
            return self.cache.get_entry(node_id)
        return self.cache.find_by_content_hash(content_hash)

    def list_all(self) -> List[CachedFile]:
        """List all cached files.

        Includes entries from the structured cache plus legacy flat files
        stored directly under cache_dir, de-duplicated by content_hash.
        """
        files = []
        seen_hashes = set()
        # New cache structure entries
        for entry in self.cache.list_entries():
            files.append(CachedFile.from_cache_entry(entry))
            if entry.content_hash:
                seen_hashes.add(entry.content_hash)
        # Legacy files stored directly in cache_dir (old structure)
        # These are files named by content_hash directly in CACHE_DIR
        for f in self.cache_dir.iterdir():
            # Skip directories and special files
            if not f.is_file():
                continue
            # Skip metadata/auxiliary files
            if f.suffix in ('.json', '.mp4'):
                continue
            # Skip if name doesn't look like a hash (64 hex chars)
            if len(f.name) != 64 or not all(c in '0123456789abcdef' for c in f.name):
                continue
            # Skip if already seen via new cache
            if f.name in seen_hashes:
                continue
            # Legacy files have node_id == content_hash == filename.
            files.append(CachedFile(
                node_id=f.name,
                content_hash=f.name,
                path=f,
                size_bytes=f.stat().st_size,
                node_type="legacy",
                created_at=f.stat().st_mtime,
            ))
            seen_hashes.add(f.name)
        return files

    # ============ Activity Tracking ============
    def record_activity(self, dag: DAG, run_id: Optional[str] = None) -> Activity:
        """
        Record a DAG execution as an activity.
        Args:
            dag: The executed DAG
            run_id: Optional run ID to use as activity_id
        Returns:
            The created Activity
        """
        activity = Activity.from_dag(dag, activity_id=run_id)
        self.activity_store.add(activity)
        return activity

    def record_simple_activity(
        self,
        input_hashes: List[str],
        output_hash: str,
        run_id: Optional[str] = None,
    ) -> Activity:
        """
        Record a simple (non-DAG) execution as an activity.
        For legacy single-effect runs that don't use full DAG execution.
        Uses content_hash as node_id.
        """
        activity = Activity(
            # NOTE(review): hash() on strings is salted per process, so this
            # fallback activity_id is not stable across restarts — confirm
            # callers always pass run_id when stability matters.
            activity_id=run_id or str(hash((tuple(input_hashes), output_hash))),
            input_ids=sorted(input_hashes),
            output_id=output_hash,
            intermediate_ids=[],
            created_at=datetime.now(timezone.utc).timestamp(),
            status="completed",
        )
        self.activity_store.add(activity)
        return activity

    def get_activity(self, activity_id: str) -> Optional[Activity]:
        """Get activity by ID."""
        return self.activity_store.get(activity_id)

    def list_activities(self) -> List[Activity]:
        """List all activities."""
        return self.activity_store.list()

    def find_activities_by_inputs(self, input_hashes: List[str]) -> List[Activity]:
        """Find activities with matching inputs (for UI grouping)."""
        return self.activity_store.find_by_input_ids(input_hashes)

    # ============ Deletion Rules ============
    def can_delete(self, content_hash: str) -> tuple[bool, str]:
        """
        Check if a cached item can be deleted.
        Returns:
            (can_delete, reason) tuple
        """
        # Check if pinned (published or input to published)
        pinned, reason = self.is_pinned(content_hash)
        if pinned:
            return False, f"Item is pinned ({reason})"
        # Find node_id for this content
        # (falls back to content_hash itself, matching the upload convention).
        node_id = self._get_content_index(content_hash) or content_hash
        # Check if it's an input or output of any activity
        # (linear scan over all activities; intermediates are deletable).
        for activity in self.activity_store.list():
            if node_id in activity.input_ids:
                return False, f"Item is input to activity {activity.activity_id}"
            if node_id == activity.output_id:
                return False, f"Item is output of activity {activity.activity_id}"
        return True, "OK"

    def can_discard_activity(self, activity_id: str) -> tuple[bool, str]:
        """
        Check if an activity can be discarded.
        Returns:
            (can_discard, reason) tuple
        """
        activity = self.activity_store.get(activity_id)
        if not activity:
            return False, "Activity not found"
        # Check if any item is pinned
        for node_id in activity.all_node_ids:
            entry = self.cache.get_entry(node_id)
            if entry:
                pinned, reason = self.is_pinned(entry.content_hash)
                if pinned:
                    return False, f"Item {node_id} is pinned ({reason})"
        return True, "OK"

    def delete_by_content_hash(self, content_hash: str) -> tuple[bool, str]:
        """
        Delete a cached item by content_hash.
        Enforces deletion rules.
        Returns:
            (success, message) tuple
        """
        can_delete, reason = self.can_delete(content_hash)
        if not can_delete:
            return False, reason
        # Find and delete
        node_id = self._get_content_index(content_hash)
        if node_id:
            self.cache.remove(node_id)
            self._del_content_index(content_hash)
            return True, "Deleted"
        # Try legacy
        # NOTE(review): this checks legacy_dir, but get_by_content_hash also
        # treats cache_dir/{hash} flat files as legacy — those would not be
        # deleted here. Confirm intended.
        legacy_path = self.legacy_dir / content_hash
        if legacy_path.exists():
            legacy_path.unlink()
            return True, "Deleted (legacy)"
        return False, "Not found"

    def discard_activity(self, activity_id: str) -> tuple[bool, str]:
        """
        Discard an activity and clean up its cache entries.
        Enforces deletion rules.
        Returns:
            (success, message) tuple
        """
        can_discard, reason = self.can_discard_activity(activity_id)
        if not can_discard:
            return False, reason
        # Delegates actual entry cleanup to the artdag ActivityManager.
        success = self.activity_manager.discard_activity(activity_id)
        if success:
            return True, "Activity discarded"
        return False, "Failed to discard"

    def discard_activity_outputs_only(self, activity_id: str) -> tuple[bool, str]:
        """
        Discard an activity, deleting only outputs and intermediates.
        Inputs (cache items, configs) are preserved.
        Returns:
            (success, message) tuple
        """
        activity = self.activity_store.get(activity_id)
        if not activity:
            return False, "Activity not found"
        # Check if output is pinned
        if activity.output_id:
            entry = self.cache.get_entry(activity.output_id)
            if entry:
                pinned, reason = self.is_pinned(entry.content_hash)
                if pinned:
                    return False, f"Output is pinned ({reason})"
        # Delete output
        if activity.output_id:
            entry = self.cache.get_entry(activity.output_id)
            if entry:
                # Remove from cache
                self.cache.remove(activity.output_id)
                # Remove from content index (Redis + local)
                self._del_content_index(entry.content_hash)
                # Delete from legacy dir if exists
                legacy_path = self.legacy_dir / entry.content_hash
                if legacy_path.exists():
                    legacy_path.unlink()
        # Delete intermediates
        for node_id in activity.intermediate_ids:
            entry = self.cache.get_entry(node_id)
            if entry:
                self.cache.remove(node_id)
                self._del_content_index(entry.content_hash)
                legacy_path = self.legacy_dir / entry.content_hash
                if legacy_path.exists():
                    legacy_path.unlink()
        # Remove activity record (inputs remain in cache)
        self.activity_store.remove(activity_id)
        return True, "Activity discarded (outputs only)"

    def cleanup_intermediates(self) -> int:
        """Delete all intermediate cache entries (reconstructible)."""
        return self.activity_manager.cleanup_intermediates()

    def get_deletable_items(self) -> List[CachedFile]:
        """Get all items that can be deleted."""
        deletable = []
        for entry in self.activity_manager.get_deletable_entries():
            deletable.append(CachedFile.from_cache_entry(entry))
        return deletable

    # ============ L2 Integration ============
    def mark_published(self, content_hash: str):
        """Mark a content_hash as published to L2."""
        self.l2_checker.mark_shared(content_hash)

    def invalidate_shared_cache(self, content_hash: str):
        """Invalidate shared status cache (call if item might be unpublished)."""
        self.l2_checker.invalidate(content_hash)

    # ============ Stats ============
    def get_stats(self) -> dict:
        """Get cache statistics.

        Returns a plain dict (JSON-serializable) combining artdag cache
        stats with the activity count.
        """
        stats = self.cache.get_stats()
        return {
            "total_entries": stats.total_entries,
            "total_size_bytes": stats.total_size_bytes,
            "hits": stats.hits,
            "misses": stats.misses,
            "hit_rate": stats.hit_rate,
            "activities": len(self.activity_store),
        }
# Process-wide singleton; created lazily on the first get_cache_manager()
# call (from env vars), cleared by reset_cache_manager() in tests.
_manager: Optional[L1CacheManager] = None
def get_cache_manager() -> L1CacheManager:
    """Return the process-wide L1CacheManager singleton.

    Lazily constructs the manager on first call from environment variables:
    - CACHE_DIR: local cache root (default: ~/.artdag/cache)
    - L2_SERVER: L2 ActivityPub server base URL
    - REDIS_URL: Redis connection URL for the shared cache indexes

    The Redis client is created here so all workers share the content and
    IPFS indexes; L1CacheManager degrades to JSON-file indexes if Redis
    is unreachable at runtime.
    """
    global _manager
    if _manager is None:
        import redis
        cache_dir = Path(os.environ.get("CACHE_DIR", str(Path.home() / ".artdag" / "cache")))
        l2_server = os.environ.get("L2_SERVER", "http://localhost:8200")
        # Initialize Redis client for shared cache index.
        # Use from_url instead of hand-parsing: it honors username/password
        # credentials, rediss:// TLS, and query-string options, which the
        # previous urlparse-based construction silently dropped, while
        # still defaulting host/port/db the same way.
        redis_url = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
        redis_client = redis.Redis.from_url(
            redis_url,
            socket_timeout=5,
            socket_connect_timeout=5,
        )
        _manager = L1CacheManager(cache_dir=cache_dir, l2_server=l2_server, redis_client=redis_client)
    return _manager
def reset_cache_manager() -> None:
    """Clear the cached singleton so the next get_cache_manager() call
    rebuilds it from the current environment (for testing)."""
    global _manager
    _manager = None