Add hybrid state manager for distributed L1 coordination
Implements HybridStateManager, providing fast local Redis operations with background IPNS sync for eventual consistency across L1 nodes.

- hybrid_state.py: centralized state management (cache, claims, analysis, plans, runs)
- Updated execute_cid.py, analyze_cid.py, orchestrate_cid.py to use the state manager
- Background IPNS sync (configurable interval, disabled by default)
- Atomic claiming with Redis SETNX to prevent duplicate work

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
294
hybrid_state.py
Normal file
294
hybrid_state.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Hybrid State Manager: Local Redis + IPNS Sync.
|
||||
|
||||
Provides fast local operations with eventual consistency across L1 nodes.
|
||||
|
||||
- Local Redis: Fast reads/writes (microseconds)
|
||||
- IPNS Sync: Background sync with other nodes (every N seconds)
|
||||
- Duplicate work: Accepted, idempotent (same inputs → same CID)
|
||||
|
||||
Usage:
|
||||
from hybrid_state import get_state_manager
|
||||
|
||||
state = get_state_manager()
|
||||
|
||||
# Fast local lookup
|
||||
cid = state.get_cached_cid(cache_id)
|
||||
|
||||
# Fast local write (synced in background)
|
||||
state.set_cached_cid(cache_id, output_cid)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration — every knob is overridable via environment variables.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/5")
CLUSTER_KEY = os.environ.get("ARTDAG_CLUSTER_KEY", "default")  # namespaces IPNS names per cluster
IPNS_SYNC_INTERVAL = int(os.environ.get("ARTDAG_IPNS_SYNC_INTERVAL", "30"))  # seconds between sync passes
# IPNS sync is opt-in: only "true"/"1"/"yes" (case-insensitive) enable it.
IPNS_ENABLED = os.environ.get("ARTDAG_IPNS_SYNC", "").lower() in ("true", "1", "yes")

# Redis keys (all hashes except the per-cache_id claim strings)
CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → output CID
ANALYSIS_KEY = "artdag:analysis_cache"  # hash: input_hash:features → analysis CID
PLAN_KEY = "artdag:plan_cache"  # hash: plan_id → plan CID
RUN_KEY = "artdag:run_cache"  # hash: run_id → output CID
CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker (with TTL)

# IPNS names (relative to cluster key, i.e. "<cluster_key>/<name>")
IPNS_CACHE_NAME = "cache"
IPNS_ANALYSIS_NAME = "analysis"
IPNS_PLAN_NAME = "plans"
|
||||
|
||||
|
||||
class HybridStateManager:
    """
    Local Redis + async IPNS sync for distributed L1 coordination.

    Fast path (local Redis, microseconds):
      - get_cached_cid / set_cached_cid
      - try_claim / release_claim

    Slow path (background IPNS sync thread):
      - Periodically merges the global IPNS state into local Redis
        (pulls entries we don't have) and publishes the merged state
        back to IPNS (pushes our updates).

    Duplicate work across nodes is accepted: steps are idempotent
    (same inputs → same CID), so a lost race only costs compute.
    """

    def __init__(
        self,
        redis_url: str = REDIS_URL,
        cluster_key: str = CLUSTER_KEY,
        sync_interval: int = IPNS_SYNC_INTERVAL,
        ipns_enabled: bool = IPNS_ENABLED,
    ):
        """
        Args:
            redis_url: Redis connection URL for the local fast path.
            cluster_key: Namespace prefix for IPNS names shared by the cluster.
            sync_interval: Seconds between background IPNS sync passes.
            ipns_enabled: When True, start the background sync thread.
        """
        self.cluster_key = cluster_key
        self.sync_interval = sync_interval
        self.ipns_enabled = ipns_enabled

        # Local Redis connection; decode_responses=True so all values are str.
        self._redis = redis.from_url(redis_url, decode_responses=True)

        # IPFS client module, lazily imported. None = not tried yet,
        # False = import failed (IPNS sync permanently unavailable).
        self._ipfs = None

        # Background sync thread and its stop signal.
        self._sync_thread = None
        self._stop_sync = threading.Event()

        if self.ipns_enabled:
            self._start_background_sync()

    @property
    def ipfs(self):
        """Lazily imported ipfs_client module, or None if unavailable."""
        if self._ipfs is None:
            try:
                import ipfs_client
                self._ipfs = ipfs_client
            except ImportError:
                logger.warning("ipfs_client not available, IPNS sync disabled")
                self._ipfs = False  # sentinel so we don't retry the import
        return self._ipfs if self._ipfs else None

    # ========== CID Cache ==========

    def get_cached_cid(self, cache_id: str) -> Optional[str]:
        """Get output CID for a cache_id. Fast local lookup."""
        return self._redis.hget(CACHE_KEY, cache_id)

    def set_cached_cid(self, cache_id: str, cid: str) -> None:
        """Set output CID for a cache_id. Fast local write (synced in background)."""
        self._redis.hset(CACHE_KEY, cache_id, cid)

    def get_all_cached_cids(self) -> Dict[str, str]:
        """Get the full cache_id → CID mapping."""
        return self._redis.hgetall(CACHE_KEY)

    # ========== Analysis Cache ==========

    @staticmethod
    def _analysis_key(input_hash: str, features: list) -> str:
        """Build the hash-field key for an (input, features) pair.

        Features are sorted so the key is independent of request order.
        """
        return f"{input_hash}:{','.join(sorted(features))}"

    def get_analysis_cid(self, input_hash: str, features: list) -> Optional[str]:
        """Get analysis CID for input + features, or None if not cached."""
        return self._redis.hget(ANALYSIS_KEY, self._analysis_key(input_hash, features))

    def set_analysis_cid(self, input_hash: str, features: list, cid: str) -> None:
        """Set analysis CID for input + features."""
        self._redis.hset(ANALYSIS_KEY, self._analysis_key(input_hash, features), cid)

    def get_all_analysis_cids(self) -> Dict[str, str]:
        """Get all analysis CIDs."""
        return self._redis.hgetall(ANALYSIS_KEY)

    # ========== Plan Cache ==========

    def get_plan_cid(self, plan_id: str) -> Optional[str]:
        """Get plan CID for a plan_id, or None if not cached."""
        return self._redis.hget(PLAN_KEY, plan_id)

    def set_plan_cid(self, plan_id: str, cid: str) -> None:
        """Set plan CID for a plan_id."""
        self._redis.hset(PLAN_KEY, plan_id, cid)

    def get_all_plan_cids(self) -> Dict[str, str]:
        """Get all plan CIDs."""
        return self._redis.hgetall(PLAN_KEY)

    # ========== Run Cache ==========

    def get_run_cid(self, run_id: str) -> Optional[str]:
        """Get output CID for a run_id, or None if not cached."""
        return self._redis.hget(RUN_KEY, run_id)

    def set_run_cid(self, run_id: str, cid: str) -> None:
        """Set output CID for a run_id."""
        self._redis.hset(RUN_KEY, run_id, cid)

    # ========== Claiming ==========

    def try_claim(self, cache_id: str, worker_id: str, ttl: int = 300) -> bool:
        """
        Try to claim a cache_id for execution.

        Uses Redis SET NX for an atomic claim; the TTL guarantees the claim
        expires even if the worker dies without releasing it.

        Returns True if claimed, False if already claimed by another worker.
        """
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        # redis-py returns True on success but None (not False) when the key
        # already exists; coerce so we honor the declared bool return type.
        return bool(self._redis.set(key, worker_id, nx=True, ex=ttl))

    def release_claim(self, cache_id: str) -> None:
        """Release a claim."""
        self._redis.delete(f"{CLAIM_KEY_PREFIX}{cache_id}")

    def get_claim(self, cache_id: str) -> Optional[str]:
        """Get current claim holder for a cache_id, or None if unclaimed."""
        return self._redis.get(f"{CLAIM_KEY_PREFIX}{cache_id}")

    # ========== IPNS Sync ==========

    def _start_background_sync(self):
        """Start the daemon thread that periodically syncs with IPNS."""
        if self._sync_thread is not None:
            return  # already running

        def sync_loop():
            logger.info(f"IPNS sync started (interval={self.sync_interval}s)")
            # Event.wait doubles as sleep and stop signal: it returns True
            # (ending the loop) as soon as stop_sync() sets the event.
            while not self._stop_sync.wait(timeout=self.sync_interval):
                try:
                    self._sync_with_ipns()
                except Exception as e:
                    # Sync is best-effort; keep the loop alive on failure.
                    logger.warning(f"IPNS sync failed: {e}")

        self._sync_thread = threading.Thread(target=sync_loop, daemon=True)
        self._sync_thread.start()

    def stop_sync(self):
        """Stop the background sync thread (waits up to 5s for it to exit)."""
        self._stop_sync.set()
        if self._sync_thread:
            self._sync_thread.join(timeout=5)

    def _sync_with_ipns(self):
        """Sync every cache type between local Redis and global IPNS state."""
        if not self.ipfs:
            return  # ipfs_client unavailable; nothing to do

        logger.debug("Starting IPNS sync...")
        self._sync_hash(CACHE_KEY, IPNS_CACHE_NAME)
        self._sync_hash(ANALYSIS_KEY, IPNS_ANALYSIS_NAME)
        self._sync_hash(PLAN_KEY, IPNS_PLAN_NAME)
        logger.debug("IPNS sync complete")

    def _sync_hash(self, redis_key: str, ipns_name: str):
        """
        Two-way sync of one Redis hash with its IPNS-published JSON blob.

        Pull: resolve the IPNS name, fetch the JSON state, and merge entries
        we don't have into local Redis (HSETNX, so local values win).
        Push: if the merged state has entries the global one lacks, publish
        it back to IPNS.
        """
        ipns_full_name = f"{self.cluster_key}/{ipns_name}"

        # Pull: resolve IPNS → get global state.
        global_state = {}
        try:
            global_cid = self.ipfs.name_resolve(ipns_full_name)
            if global_cid:
                global_bytes = self.ipfs.get_bytes(global_cid)
                if global_bytes:
                    global_state = json.loads(global_bytes.decode('utf-8'))
                    logger.debug(f"Pulled {len(global_state)} entries from {ipns_name}")
        except Exception as e:
            # Expected on the first sync for a fresh cluster key.
            logger.debug(f"Could not resolve {ipns_full_name}: {e}")

        # Merge global into local (HSETNX only adds entries we don't have).
        if global_state:
            pipe = self._redis.pipeline()
            for key, value in global_state.items():
                pipe.hsetnx(redis_key, key, value)
            results = pipe.execute()
            added = sum(1 for r in results if r)
            if added:
                logger.info(f"Merged {added} new entries from IPNS/{ipns_name}")

        # Push: overlay local on global and publish only if we add new entries.
        local_state = self._redis.hgetall(redis_key)
        if local_state:
            merged = {**global_state, **local_state}
            if len(merged) > len(global_state):
                try:
                    new_cid = self.ipfs.add_json(merged)
                    if new_cid:
                        # Note: name_publish can be slow (tens of seconds).
                        self.ipfs.name_publish(ipns_full_name, new_cid)
                        logger.info(f"Published {len(merged)} entries to IPNS/{ipns_name}")
                except Exception as e:
                    logger.warning(f"Failed to publish to {ipns_full_name}: {e}")

    def force_sync(self):
        """Force an immediate IPNS sync (blocking)."""
        self._sync_with_ipns()

    # ========== Stats ==========

    def get_stats(self) -> Dict:
        """Get cache statistics: entry counts plus sync configuration."""
        return {
            "cached_cids": self._redis.hlen(CACHE_KEY),
            "analysis_cids": self._redis.hlen(ANALYSIS_KEY),
            "plan_cids": self._redis.hlen(PLAN_KEY),
            "run_cids": self._redis.hlen(RUN_KEY),
            "ipns_enabled": self.ipns_enabled,
            # Truncate long cluster keys so stats output stays readable.
            "cluster_key": (self.cluster_key[:16] + "...") if len(self.cluster_key) > 16 else self.cluster_key,
        }
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_state_manager: Optional[HybridStateManager] = None
|
||||
|
||||
|
||||
def get_state_manager() -> HybridStateManager:
    """Return the process-wide HybridStateManager, creating it on first use."""
    global _state_manager
    manager = _state_manager
    if manager is None:
        manager = HybridStateManager()
        _state_manager = manager
    return manager
|
||||
|
||||
|
||||
def reset_state_manager():
    """Tear down the singleton so tests start from a clean slate."""
    global _state_manager
    if _state_manager is not None:
        _state_manager.stop_sync()
    _state_manager = None
|
||||
@@ -2,6 +2,10 @@
|
||||
IPFS-primary analysis tasks.
|
||||
|
||||
Fetches inputs from IPFS, stores analysis results on IPFS.
|
||||
|
||||
Uses HybridStateManager for:
|
||||
- Fast local Redis operations
|
||||
- Background IPNS sync with other L1 nodes
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -18,17 +22,7 @@ import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from celery_app import app
|
||||
import ipfs_client
|
||||
|
||||
# Redis for caching analysis CIDs
|
||||
import redis
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
global _redis
|
||||
if _redis is None:
|
||||
_redis = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis
|
||||
from hybrid_state import get_state_manager
|
||||
|
||||
# Import artdag analysis module
|
||||
try:
|
||||
@@ -39,26 +33,15 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key for analysis cache
|
||||
ANALYSIS_CACHE_KEY = "artdag:analysis_cid" # hash: input_hash:features → analysis CID
|
||||
|
||||
|
||||
def get_analysis_cache_key(input_hash: str, features: List[str]) -> str:
|
||||
"""Generate cache key for analysis results."""
|
||||
features_str = ",".join(sorted(features))
|
||||
return f"{input_hash}:{features_str}"
|
||||
|
||||
|
||||
def get_cached_analysis_cid(input_hash: str, features: List[str]) -> Optional[str]:
|
||||
"""Check if analysis is already cached."""
|
||||
key = get_analysis_cache_key(input_hash, features)
|
||||
return get_redis().hget(ANALYSIS_CACHE_KEY, key)
|
||||
return get_state_manager().get_analysis_cid(input_hash, features)
|
||||
|
||||
|
||||
def set_cached_analysis_cid(input_hash: str, features: List[str], cid: str) -> None:
|
||||
"""Store analysis CID in cache."""
|
||||
key = get_analysis_cache_key(input_hash, features)
|
||||
get_redis().hset(ANALYSIS_CACHE_KEY, key, cid)
|
||||
get_state_manager().set_analysis_cid(input_hash, features, cid)
|
||||
|
||||
|
||||
@app.task(bind=True, name='tasks.analyze_input_cid')
|
||||
|
||||
@@ -3,6 +3,10 @@ Simplified step execution with IPFS-primary architecture.
|
||||
|
||||
Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
|
||||
IPFS nodes form a distributed cache automatically.
|
||||
|
||||
Uses HybridStateManager for:
|
||||
- Fast local Redis operations
|
||||
- Background IPNS sync with other L1 nodes
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -19,17 +23,7 @@ import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from celery_app import app
|
||||
import ipfs_client
|
||||
|
||||
# Redis for claiming and cache_id → CID mapping
|
||||
import redis
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
global _redis
|
||||
if _redis is None:
|
||||
_redis = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis
|
||||
from hybrid_state import get_state_manager
|
||||
|
||||
# Import artdag
|
||||
try:
|
||||
@@ -44,10 +38,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis keys
|
||||
CACHE_KEY = "artdag:cid_cache" # hash: cache_id → CID
|
||||
CLAIM_KEY_PREFIX = "artdag:claim:" # string: cache_id → worker_id
|
||||
|
||||
|
||||
def get_worker_id() -> str:
|
||||
"""Get unique worker identifier."""
|
||||
@@ -56,24 +46,22 @@ def get_worker_id() -> str:
|
||||
|
||||
def get_cached_cid(cache_id: str) -> Optional[str]:
|
||||
"""Check if cache_id has a known CID."""
|
||||
return get_redis().hget(CACHE_KEY, cache_id)
|
||||
return get_state_manager().get_cached_cid(cache_id)
|
||||
|
||||
|
||||
def set_cached_cid(cache_id: str, cid: str) -> None:
|
||||
"""Store cache_id → CID mapping."""
|
||||
get_redis().hset(CACHE_KEY, cache_id, cid)
|
||||
get_state_manager().set_cached_cid(cache_id, cid)
|
||||
|
||||
|
||||
def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
|
||||
"""Try to claim a cache_id for execution. Returns True if claimed."""
|
||||
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
|
||||
return get_redis().set(key, worker_id, nx=True, ex=ttl)
|
||||
return get_state_manager().try_claim(cache_id, worker_id, ttl)
|
||||
|
||||
|
||||
def release_claim(cache_id: str) -> None:
|
||||
"""Release a claim."""
|
||||
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
|
||||
get_redis().delete(key)
|
||||
get_state_manager().release_claim(cache_id)
|
||||
|
||||
|
||||
def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:
|
||||
|
||||
@@ -8,33 +8,25 @@ Everything on IPFS:
|
||||
- Step outputs (media files)
|
||||
|
||||
The entire pipeline just passes CIDs around.
|
||||
|
||||
Uses HybridStateManager for:
|
||||
- Fast local Redis operations
|
||||
- Background IPNS sync with other L1 nodes
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from celery import current_task, group
|
||||
from celery import group
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from celery_app import app
|
||||
import ipfs_client
|
||||
|
||||
# Redis for caching
|
||||
import redis
|
||||
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
|
||||
_redis: Optional[redis.Redis] = None
|
||||
|
||||
def get_redis() -> redis.Redis:
|
||||
global _redis
|
||||
if _redis is None:
|
||||
_redis = redis.from_url(REDIS_URL, decode_responses=True)
|
||||
return _redis
|
||||
from hybrid_state import get_state_manager
|
||||
|
||||
# Import artdag modules
|
||||
try:
|
||||
@@ -53,11 +45,6 @@ from .execute_cid import execute_step_cid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis keys
|
||||
PLAN_CACHE_KEY = "artdag:plan_cid" # hash: plan_id → plan CID
|
||||
RECIPE_CACHE_KEY = "artdag:recipe_cid" # hash: recipe_hash → recipe CID
|
||||
RUN_CACHE_KEY = "artdag:run_cid" # hash: run_id → output CID
|
||||
|
||||
|
||||
def compute_run_id(recipe_cid: str, input_cids: Dict[str, str]) -> str:
|
||||
"""Compute deterministic run ID from recipe and inputs."""
|
||||
@@ -203,7 +190,7 @@ def generate_plan_cid(
|
||||
return {"status": "failed", "error": "Failed to store plan on IPFS"}
|
||||
|
||||
# Cache plan_id → plan_cid mapping
|
||||
get_redis().hset(PLAN_CACHE_KEY, plan.plan_id, plan_cid)
|
||||
get_state_manager().set_plan_cid(plan.plan_id, plan_cid)
|
||||
|
||||
logger.info(f"[CID] Generated plan: {plan.plan_id[:16]}... → {plan_cid}")
|
||||
|
||||
@@ -327,7 +314,7 @@ def run_recipe_cid(
|
||||
run_id = compute_run_id(recipe_cid, input_cids)
|
||||
|
||||
# Check if run is already cached
|
||||
cached_output = get_redis().hget(RUN_CACHE_KEY, run_id)
|
||||
cached_output = get_state_manager().get_run_cid(run_id)
|
||||
if cached_output:
|
||||
logger.info(f"[CID] Run cache hit: {run_id[:16]}... → {cached_output}")
|
||||
return {
|
||||
@@ -385,7 +372,7 @@ def run_recipe_cid(
|
||||
output_cid = exec_result["output_cid"]
|
||||
|
||||
# Cache the run
|
||||
get_redis().hset(RUN_CACHE_KEY, run_id, output_cid)
|
||||
get_state_manager().set_run_cid(run_id, output_cid)
|
||||
|
||||
logger.info(f"[CID] Run complete: {run_id[:16]}... → {output_cid}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user