Add hybrid state manager for distributed L1 coordination

Implements HybridStateManager providing fast local Redis operations
with background IPNS sync for eventual consistency across L1 nodes.

- hybrid_state.py: Centralized state management (cache, claims, analysis, plans, runs)
- Updated execute_cid.py, analyze_cid.py, orchestrate_cid.py to use state manager
- Background IPNS sync (configurable interval, disabled by default)
- Atomic claiming with Redis SETNX for preventing duplicate work

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
gilesb
2026-01-11 09:35:50 +00:00
parent f11cec9d48
commit ca8bfd8705
4 changed files with 319 additions and 67 deletions

294
hybrid_state.py Normal file
View File

@@ -0,0 +1,294 @@
"""
Hybrid State Manager: Local Redis + IPNS Sync.
Provides fast local operations with eventual consistency across L1 nodes.
- Local Redis: Fast reads/writes (microseconds)
- IPNS Sync: Background sync with other nodes (every N seconds)
- Duplicate work: Accepted, idempotent (same inputs → same CID)
Usage:
from hybrid_state import get_state_manager
state = get_state_manager()
# Fast local lookup
cid = state.get_cached_cid(cache_id)
# Fast local write (synced in background)
state.set_cached_cid(cache_id, output_cid)
"""
import json
import logging
import os
import threading
import time
from typing import Dict, Optional
import redis
# Module-level logger; handlers and level are configured by the application.
logger = logging.getLogger(__name__)
# Configuration — every value can be overridden via environment variables.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/5")
CLUSTER_KEY = os.environ.get("ARTDAG_CLUSTER_KEY", "default")
IPNS_SYNC_INTERVAL = int(os.environ.get("ARTDAG_IPNS_SYNC_INTERVAL", "30"))  # seconds between background sync passes
IPNS_ENABLED = os.environ.get("ARTDAG_IPNS_SYNC", "").lower() in ("true", "1", "yes")  # sync disabled unless explicitly enabled
# Redis keys
CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → output CID
ANALYSIS_KEY = "artdag:analysis_cache"  # hash: input_hash:features → analysis CID
PLAN_KEY = "artdag:plan_cache"  # hash: plan_id → plan CID
RUN_KEY = "artdag:run_cache"  # hash: run_id → output CID
CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker (with TTL)
# IPNS names (relative to cluster key)
IPNS_CACHE_NAME = "cache"
IPNS_ANALYSIS_NAME = "analysis"
IPNS_PLAN_NAME = "plans"
class HybridStateManager:
    """
    Local Redis + async IPNS sync for distributed L1 coordination.

    Fast path (local Redis):
        - get_cached_cid / set_cached_cid
        - try_claim / release_claim

    Slow path (background IPNS sync):
        - Periodically syncs local state with global IPNS state
        - Merges remote state into local (pulls new entries)
        - Publishes local state to IPNS (pushes updates)
    """

    def __init__(
        self,
        redis_url: str = REDIS_URL,
        cluster_key: str = CLUSTER_KEY,
        sync_interval: int = IPNS_SYNC_INTERVAL,
        ipns_enabled: bool = IPNS_ENABLED,
    ):
        """
        Args:
            redis_url: URL of the local Redis instance backing the fast path.
            cluster_key: Namespace prefix for IPNS names shared by the cluster.
            sync_interval: Seconds between background IPNS sync passes.
            ipns_enabled: If True, start the background sync thread immediately.
        """
        self.cluster_key = cluster_key
        self.sync_interval = sync_interval
        self.ipns_enabled = ipns_enabled
        # decode_responses=True so all reads return str rather than bytes.
        self._redis = redis.from_url(redis_url, decode_responses=True)
        # IPNS client module; lazily imported on first use (False == unavailable).
        self._ipfs = None
        # Background sync thread and its shutdown signal.
        self._sync_thread = None
        self._stop_sync = threading.Event()
        if self.ipns_enabled:
            self._start_background_sync()

    @property
    def ipfs(self):
        """Lazily imported IPFS client module, or None when unavailable."""
        if self._ipfs is None:
            try:
                import ipfs_client
                self._ipfs = ipfs_client
            except ImportError:
                logger.warning("ipfs_client not available, IPNS sync disabled")
                # Sentinel: False means "tried and failed", so the import is
                # not retried on every property access.
                self._ipfs = False
        return self._ipfs if self._ipfs else None

    # ========== CID Cache ==========
    def get_cached_cid(self, cache_id: str) -> Optional[str]:
        """Get output CID for a cache_id. Fast local lookup."""
        return self._redis.hget(CACHE_KEY, cache_id)

    def set_cached_cid(self, cache_id: str, cid: str) -> None:
        """Set output CID for a cache_id. Fast local write."""
        self._redis.hset(CACHE_KEY, cache_id, cid)

    def get_all_cached_cids(self) -> Dict[str, str]:
        """Get all cached CIDs."""
        return self._redis.hgetall(CACHE_KEY)

    # ========== Analysis Cache ==========
    def get_analysis_cid(self, input_hash: str, features: list) -> Optional[str]:
        """Get analysis CID for input + features."""
        # Features are sorted so the key is independent of caller ordering.
        key = f"{input_hash}:{','.join(sorted(features))}"
        return self._redis.hget(ANALYSIS_KEY, key)

    def set_analysis_cid(self, input_hash: str, features: list, cid: str) -> None:
        """Set analysis CID for input + features."""
        key = f"{input_hash}:{','.join(sorted(features))}"
        self._redis.hset(ANALYSIS_KEY, key, cid)

    def get_all_analysis_cids(self) -> Dict[str, str]:
        """Get all analysis CIDs."""
        return self._redis.hgetall(ANALYSIS_KEY)

    # ========== Plan Cache ==========
    def get_plan_cid(self, plan_id: str) -> Optional[str]:
        """Get plan CID for a plan_id."""
        return self._redis.hget(PLAN_KEY, plan_id)

    def set_plan_cid(self, plan_id: str, cid: str) -> None:
        """Set plan CID for a plan_id."""
        self._redis.hset(PLAN_KEY, plan_id, cid)

    def get_all_plan_cids(self) -> Dict[str, str]:
        """Get all plan CIDs."""
        return self._redis.hgetall(PLAN_KEY)

    # ========== Run Cache ==========
    def get_run_cid(self, run_id: str) -> Optional[str]:
        """Get output CID for a run_id."""
        return self._redis.hget(RUN_KEY, run_id)

    def set_run_cid(self, run_id: str, cid: str) -> None:
        """Set output CID for a run_id."""
        self._redis.hset(RUN_KEY, run_id, cid)

    # ========== Claiming ==========
    def try_claim(self, cache_id: str, worker_id: str, ttl: int = 300) -> bool:
        """
        Try to claim a cache_id for execution.
        Returns True if claimed, False if already claimed by another worker.
        Uses Redis SET NX for an atomic claim; the TTL (seconds) lets a claim
        expire on its own if a worker crashes without releasing it.
        """
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        # redis-py returns True on success but None (not False) when NX fails,
        # so coerce to bool to honor the documented True/False contract.
        return bool(self._redis.set(key, worker_id, nx=True, ex=ttl))

    def release_claim(self, cache_id: str) -> None:
        """Release a claim."""
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        self._redis.delete(key)

    def get_claim(self, cache_id: str) -> Optional[str]:
        """Get current claim holder for a cache_id."""
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        return self._redis.get(key)

    # ========== IPNS Sync ==========
    def _start_background_sync(self):
        """Start background IPNS sync thread (no-op if already started)."""
        if self._sync_thread is not None:
            return

        def sync_loop():
            logger.info(f"IPNS sync started (interval={self.sync_interval}s)")
            # Event.wait doubles as the inter-sync sleep and the shutdown
            # check: it returns True (exiting the loop) once stop_sync() sets
            # the event.
            while not self._stop_sync.wait(timeout=self.sync_interval):
                try:
                    self._sync_with_ipns()
                except Exception as e:
                    # Best-effort: a failed pass is retried on the next tick.
                    logger.warning(f"IPNS sync failed: {e}")

        self._sync_thread = threading.Thread(target=sync_loop, daemon=True)
        self._sync_thread.start()

    def stop_sync(self):
        """Stop background sync thread."""
        self._stop_sync.set()
        if self._sync_thread:
            self._sync_thread.join(timeout=5)

    def _sync_with_ipns(self):
        """Sync local state with IPNS global state (no-op without an IPFS client)."""
        if not self.ipfs:
            return
        logger.debug("Starting IPNS sync...")
        # Sync each cache type
        self._sync_hash(CACHE_KEY, IPNS_CACHE_NAME)
        self._sync_hash(ANALYSIS_KEY, IPNS_ANALYSIS_NAME)
        self._sync_hash(PLAN_KEY, IPNS_PLAN_NAME)
        logger.debug("IPNS sync complete")

    def _sync_hash(self, redis_key: str, ipns_name: str):
        """
        Sync one Redis hash with IPNS: pull the published global state, merge
        any entries we lack into local Redis, then publish the merged state
        back to IPNS if it has grown.
        """
        ipns_full_name = f"{self.cluster_key}/{ipns_name}"
        # Pull: resolve IPNS → get global state
        global_state = {}
        try:
            global_cid = self.ipfs.name_resolve(ipns_full_name)
            if global_cid:
                global_bytes = self.ipfs.get_bytes(global_cid)
                if global_bytes:
                    global_state = json.loads(global_bytes.decode('utf-8'))
                    logger.debug(f"Pulled {len(global_state)} entries from {ipns_name}")
        except Exception as e:
            # Expected before the first publish (nothing to resolve yet).
            logger.debug(f"Could not resolve {ipns_full_name}: {e}")
        # Merge global into local; hsetnx never overwrites a local entry.
        if global_state:
            pipe = self._redis.pipeline()
            for key, value in global_state.items():
                pipe.hsetnx(redis_key, key, value)
            results = pipe.execute()
            added = sum(1 for r in results if r)
            if added:
                logger.info(f"Merged {added} new entries from IPNS/{ipns_name}")
        # Push: get local state, merge with global, publish
        local_state = self._redis.hgetall(redis_key)
        if local_state:
            merged = {**global_state, **local_state}
            # Only publish if we have entries the global state lacks
            if len(merged) > len(global_state):
                try:
                    new_cid = self.ipfs.add_json(merged)
                    if new_cid:
                        # Note: name_publish can be slow
                        self.ipfs.name_publish(ipns_full_name, new_cid)
                        logger.info(f"Published {len(merged)} entries to IPNS/{ipns_name}")
                except Exception as e:
                    logger.warning(f"Failed to publish to {ipns_full_name}: {e}")

    def force_sync(self):
        """Force an immediate IPNS sync (blocking)."""
        self._sync_with_ipns()

    # ========== Stats ==========
    def get_stats(self) -> Dict:
        """Get cache statistics."""
        return {
            "cached_cids": self._redis.hlen(CACHE_KEY),
            "analysis_cids": self._redis.hlen(ANALYSIS_KEY),
            "plan_cids": self._redis.hlen(PLAN_KEY),
            "run_cids": self._redis.hlen(RUN_KEY),
            "ipns_enabled": self.ipns_enabled,
            # Truncate long cluster keys so stats output stays readable.
            "cluster_key": self.cluster_key[:16] + "..." if len(self.cluster_key) > 16 else self.cluster_key,
        }
# Singleton instance
# Process-wide singleton, created lazily by get_state_manager().
_state_manager: Optional[HybridStateManager] = None


def get_state_manager() -> HybridStateManager:
    """Return the process-wide state manager, creating it on first call."""
    global _state_manager
    manager = _state_manager
    if manager is None:
        manager = HybridStateManager()
        _state_manager = manager
    return manager
def reset_state_manager():
    """Stop the singleton's sync thread and discard it (for testing)."""
    global _state_manager
    manager = _state_manager
    if manager:
        manager.stop_sync()
    _state_manager = None

View File

@@ -2,6 +2,10 @@
IPFS-primary analysis tasks.
Fetches inputs from IPFS, stores analysis results on IPFS.
Uses HybridStateManager for:
- Fast local Redis operations
- Background IPNS sync with other L1 nodes
"""
import json
@@ -18,17 +22,7 @@ import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
import ipfs_client
# Redis for caching analysis CIDs
import redis
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
_redis: Optional[redis.Redis] = None
def get_redis() -> redis.Redis:
global _redis
if _redis is None:
_redis = redis.from_url(REDIS_URL, decode_responses=True)
return _redis
from hybrid_state import get_state_manager
# Import artdag analysis module
try:
@@ -39,26 +33,15 @@ except ImportError:
logger = logging.getLogger(__name__)
# Redis key for analysis cache
ANALYSIS_CACHE_KEY = "artdag:analysis_cid" # hash: input_hash:features → analysis CID
def get_analysis_cache_key(input_hash: str, features: List[str]) -> str:
"""Generate cache key for analysis results."""
features_str = ",".join(sorted(features))
return f"{input_hash}:{features_str}"
def get_cached_analysis_cid(input_hash: str, features: List[str]) -> Optional[str]:
"""Check if analysis is already cached."""
key = get_analysis_cache_key(input_hash, features)
return get_redis().hget(ANALYSIS_CACHE_KEY, key)
return get_state_manager().get_analysis_cid(input_hash, features)
def set_cached_analysis_cid(input_hash: str, features: List[str], cid: str) -> None:
"""Store analysis CID in cache."""
key = get_analysis_cache_key(input_hash, features)
get_redis().hset(ANALYSIS_CACHE_KEY, key, cid)
get_state_manager().set_analysis_cid(input_hash, features, cid)
@app.task(bind=True, name='tasks.analyze_input_cid')

View File

@@ -3,6 +3,10 @@ Simplified step execution with IPFS-primary architecture.
Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
IPFS nodes form a distributed cache automatically.
Uses HybridStateManager for:
- Fast local Redis operations
- Background IPNS sync with other L1 nodes
"""
import logging
@@ -19,17 +23,7 @@ import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
import ipfs_client
# Redis for claiming and cache_id → CID mapping
import redis
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
_redis: Optional[redis.Redis] = None
def get_redis() -> redis.Redis:
global _redis
if _redis is None:
_redis = redis.from_url(REDIS_URL, decode_responses=True)
return _redis
from hybrid_state import get_state_manager
# Import artdag
try:
@@ -44,10 +38,6 @@ except ImportError:
logger = logging.getLogger(__name__)
# Redis keys
CACHE_KEY = "artdag:cid_cache" # hash: cache_id → CID
CLAIM_KEY_PREFIX = "artdag:claim:" # string: cache_id → worker_id
def get_worker_id() -> str:
"""Get unique worker identifier."""
@@ -56,24 +46,22 @@ def get_worker_id() -> str:
def get_cached_cid(cache_id: str) -> Optional[str]:
"""Check if cache_id has a known CID."""
return get_redis().hget(CACHE_KEY, cache_id)
return get_state_manager().get_cached_cid(cache_id)
def set_cached_cid(cache_id: str, cid: str) -> None:
"""Store cache_id → CID mapping."""
get_redis().hset(CACHE_KEY, cache_id, cid)
get_state_manager().set_cached_cid(cache_id, cid)
def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
"""Try to claim a cache_id for execution. Returns True if claimed."""
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
return get_redis().set(key, worker_id, nx=True, ex=ttl)
return get_state_manager().try_claim(cache_id, worker_id, ttl)
def release_claim(cache_id: str) -> None:
"""Release a claim."""
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
get_redis().delete(key)
get_state_manager().release_claim(cache_id)
def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:

View File

@@ -8,33 +8,25 @@ Everything on IPFS:
- Step outputs (media files)
The entire pipeline just passes CIDs around.
Uses HybridStateManager for:
- Fast local Redis operations
- Background IPNS sync with other L1 nodes
"""
import json
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Dict, List, Optional
from celery import current_task, group
from celery import group
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from celery_app import app
import ipfs_client
# Redis for caching
import redis
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
_redis: Optional[redis.Redis] = None
def get_redis() -> redis.Redis:
global _redis
if _redis is None:
_redis = redis.from_url(REDIS_URL, decode_responses=True)
return _redis
from hybrid_state import get_state_manager
# Import artdag modules
try:
@@ -53,11 +45,6 @@ from .execute_cid import execute_step_cid
logger = logging.getLogger(__name__)
# Redis keys
PLAN_CACHE_KEY = "artdag:plan_cid" # hash: plan_id → plan CID
RECIPE_CACHE_KEY = "artdag:recipe_cid" # hash: recipe_hash → recipe CID
RUN_CACHE_KEY = "artdag:run_cid" # hash: run_id → output CID
def compute_run_id(recipe_cid: str, input_cids: Dict[str, str]) -> str:
"""Compute deterministic run ID from recipe and inputs."""
@@ -203,7 +190,7 @@ def generate_plan_cid(
return {"status": "failed", "error": "Failed to store plan on IPFS"}
# Cache plan_id → plan_cid mapping
get_redis().hset(PLAN_CACHE_KEY, plan.plan_id, plan_cid)
get_state_manager().set_plan_cid(plan.plan_id, plan_cid)
logger.info(f"[CID] Generated plan: {plan.plan_id[:16]}... → {plan_cid}")
@@ -327,7 +314,7 @@ def run_recipe_cid(
run_id = compute_run_id(recipe_cid, input_cids)
# Check if run is already cached
cached_output = get_redis().hget(RUN_CACHE_KEY, run_id)
cached_output = get_state_manager().get_run_cid(run_id)
if cached_output:
logger.info(f"[CID] Run cache hit: {run_id[:16]}... → {cached_output}")
return {
@@ -385,7 +372,7 @@ def run_recipe_cid(
output_cid = exec_result["output_cid"]
# Cache the run
get_redis().hset(RUN_CACHE_KEY, run_id, output_cid)
get_state_manager().set_run_cid(run_id, output_cid)
logger.info(f"[CID] Run complete: {run_id[:16]}... → {output_cid}")