From ca8bfd87054a7bd879d6b968af8becc6f6a27e7e Mon Sep 17 00:00:00 2001
From: gilesb
Date: Sun, 11 Jan 2026 09:35:50 +0000
Subject: [PATCH] Add hybrid state manager for distributed L1 coordination

Implements HybridStateManager providing fast local Redis operations with
background IPNS sync for eventual consistency across L1 nodes.

- hybrid_state.py: Centralized state management (cache, claims, analysis, plans, runs)
- Updated execute_cid.py, analyze_cid.py, orchestrate_cid.py to use state manager
- Background IPNS sync (configurable interval, disabled by default)
- Atomic claiming with Redis SETNX for preventing duplicate work
- try_claim coerces the redis-py SET NX result (True/None) to a real bool

Co-Authored-By: Claude Opus 4.5
---
 hybrid_state.py          | 294 +++++++++++++++++++++++++++++++++++++++
 tasks/analyze_cid.py     |  31 +----
 tasks/execute_cid.py     |  30 ++--
 tasks/orchestrate_cid.py |  31 ++---
 4 files changed, 319 insertions(+), 67 deletions(-)
 create mode 100644 hybrid_state.py

diff --git a/hybrid_state.py b/hybrid_state.py
new file mode 100644
index 0000000..b351a7c
--- /dev/null
+++ b/hybrid_state.py
@@ -0,0 +1,294 @@
+"""
+Hybrid State Manager: Local Redis + IPNS Sync.
+
+Provides fast local operations with eventual consistency across L1 nodes.
+
+- Local Redis: Fast reads/writes (microseconds)
+- IPNS Sync: Background sync with other nodes (every N seconds)
+- Duplicate work: Accepted, idempotent (same inputs → same CID)
+
+Usage:
+    from hybrid_state import get_state_manager
+
+    state = get_state_manager()
+
+    # Fast local lookup
+    cid = state.get_cached_cid(cache_id)
+
+    # Fast local write (synced in background)
+    state.set_cached_cid(cache_id, output_cid)
+"""
+
+import json
+import logging
+import os
+import threading
+import time
+from typing import Dict, Optional
+
+import redis
+
+logger = logging.getLogger(__name__)
+
+# Configuration
+REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/5")
+CLUSTER_KEY = os.environ.get("ARTDAG_CLUSTER_KEY", "default")
+IPNS_SYNC_INTERVAL = int(os.environ.get("ARTDAG_IPNS_SYNC_INTERVAL", "30"))
+IPNS_ENABLED = os.environ.get("ARTDAG_IPNS_SYNC", "").lower() in ("true", "1", "yes")
+
+# Redis keys
+CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → output CID
+ANALYSIS_KEY = "artdag:analysis_cache"  # hash: input_hash:features → analysis CID
+PLAN_KEY = "artdag:plan_cache"  # hash: plan_id → plan CID
+RUN_KEY = "artdag:run_cache"  # hash: run_id → output CID
+CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker (with TTL)
+
+# IPNS names (relative to cluster key)
+IPNS_CACHE_NAME = "cache"
+IPNS_ANALYSIS_NAME = "analysis"
+IPNS_PLAN_NAME = "plans"
+
+
+class HybridStateManager:
+    """
+    Local Redis + async IPNS sync for distributed L1 coordination.
+
+    Fast path (local Redis):
+    - get_cached_cid / set_cached_cid
+    - try_claim / release_claim
+
+    Slow path (background IPNS sync):
+    - Periodically syncs local state with global IPNS state
+    - Merges remote state into local (pulls new entries)
+    - Publishes local state to IPNS (pushes updates)
+    """
+
+    def __init__(
+        self,
+        redis_url: str = REDIS_URL,
+        cluster_key: str = CLUSTER_KEY,
+        sync_interval: int = IPNS_SYNC_INTERVAL,
+        ipns_enabled: bool = IPNS_ENABLED,
+    ):
+        self.cluster_key = cluster_key
+        self.sync_interval = sync_interval
+        self.ipns_enabled = ipns_enabled
+
+        # Connect to Redis
+        self._redis = redis.from_url(redis_url, decode_responses=True)
+
+        # IPNS client (lazy import)
+        self._ipfs = None
+
+        # Sync thread
+        self._sync_thread = None
+        self._stop_sync = threading.Event()
+
+        # Start background sync if enabled
+        if self.ipns_enabled:
+            self._start_background_sync()
+
+    @property
+    def ipfs(self):
+        """Lazy import of IPFS client."""
+        if self._ipfs is None:
+            try:
+                import ipfs_client
+                self._ipfs = ipfs_client
+            except ImportError:
+                logger.warning("ipfs_client not available, IPNS sync disabled")
+                self._ipfs = False
+        return self._ipfs if self._ipfs else None
+
+    # ========== CID Cache ==========
+
+    def get_cached_cid(self, cache_id: str) -> Optional[str]:
+        """Get output CID for a cache_id. Fast local lookup."""
+        return self._redis.hget(CACHE_KEY, cache_id)
+
+    def set_cached_cid(self, cache_id: str, cid: str) -> None:
+        """Set output CID for a cache_id. Fast local write."""
+        self._redis.hset(CACHE_KEY, cache_id, cid)
+
+    def get_all_cached_cids(self) -> Dict[str, str]:
+        """Get all cached CIDs."""
+        return self._redis.hgetall(CACHE_KEY)
+
+    # ========== Analysis Cache ==========
+
+    def get_analysis_cid(self, input_hash: str, features: list) -> Optional[str]:
+        """Get analysis CID for input + features."""
+        key = f"{input_hash}:{','.join(sorted(features))}"
+        return self._redis.hget(ANALYSIS_KEY, key)
+
+    def set_analysis_cid(self, input_hash: str, features: list, cid: str) -> None:
+        """Set analysis CID for input + features."""
+        key = f"{input_hash}:{','.join(sorted(features))}"
+        self._redis.hset(ANALYSIS_KEY, key, cid)
+
+    def get_all_analysis_cids(self) -> Dict[str, str]:
+        """Get all analysis CIDs."""
+        return self._redis.hgetall(ANALYSIS_KEY)
+
+    # ========== Plan Cache ==========
+
+    def get_plan_cid(self, plan_id: str) -> Optional[str]:
+        """Get plan CID for a plan_id."""
+        return self._redis.hget(PLAN_KEY, plan_id)
+
+    def set_plan_cid(self, plan_id: str, cid: str) -> None:
+        """Set plan CID for a plan_id."""
+        self._redis.hset(PLAN_KEY, plan_id, cid)
+
+    def get_all_plan_cids(self) -> Dict[str, str]:
+        """Get all plan CIDs."""
+        return self._redis.hgetall(PLAN_KEY)
+
+    # ========== Run Cache ==========
+
+    def get_run_cid(self, run_id: str) -> Optional[str]:
+        """Get output CID for a run_id."""
+        return self._redis.hget(RUN_KEY, run_id)
+
+    def set_run_cid(self, run_id: str, cid: str) -> None:
+        """Set output CID for a run_id."""
+        self._redis.hset(RUN_KEY, run_id, cid)
+
+    # ========== Claiming ==========
+
+    def try_claim(self, cache_id: str, worker_id: str, ttl: int = 300) -> bool:
+        """
+        Try to claim a cache_id for execution.
+
+        Returns True if claimed, False if already claimed by another worker.
+        Uses Redis SETNX for atomic claim.
+        """
+        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
+        return bool(self._redis.set(key, worker_id, nx=True, ex=ttl))
+
+    def release_claim(self, cache_id: str) -> None:
+        """Release a claim."""
+        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
+        self._redis.delete(key)
+
+    def get_claim(self, cache_id: str) -> Optional[str]:
+        """Get current claim holder for a cache_id."""
+        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
+        return self._redis.get(key)
+
+    # ========== IPNS Sync ==========
+
+    def _start_background_sync(self):
+        """Start background IPNS sync thread."""
+        if self._sync_thread is not None:
+            return
+
+        def sync_loop():
+            logger.info(f"IPNS sync started (interval={self.sync_interval}s)")
+            while not self._stop_sync.wait(timeout=self.sync_interval):
+                try:
+                    self._sync_with_ipns()
+                except Exception as e:
+                    logger.warning(f"IPNS sync failed: {e}")
+
+        self._sync_thread = threading.Thread(target=sync_loop, daemon=True)
+        self._sync_thread.start()
+
+    def stop_sync(self):
+        """Stop background sync thread."""
+        self._stop_sync.set()
+        if self._sync_thread:
+            self._sync_thread.join(timeout=5)
+
+    def _sync_with_ipns(self):
+        """Sync local state with IPNS global state."""
+        if not self.ipfs:
+            return
+
+        logger.debug("Starting IPNS sync...")
+
+        # Sync each cache type
+        self._sync_hash(CACHE_KEY, IPNS_CACHE_NAME)
+        self._sync_hash(ANALYSIS_KEY, IPNS_ANALYSIS_NAME)
+        self._sync_hash(PLAN_KEY, IPNS_PLAN_NAME)
+
+        logger.debug("IPNS sync complete")
+
+    def _sync_hash(self, redis_key: str, ipns_name: str):
+        """Sync a Redis hash with IPNS."""
+        ipns_full_name = f"{self.cluster_key}/{ipns_name}"
+
+        # Pull: resolve IPNS → get global state
+        global_state = {}
+        try:
+            global_cid = self.ipfs.name_resolve(ipns_full_name)
+            if global_cid:
+                global_bytes = self.ipfs.get_bytes(global_cid)
+                if global_bytes:
+                    global_state = json.loads(global_bytes.decode('utf-8'))
+                    logger.debug(f"Pulled {len(global_state)} entries from {ipns_name}")
+        except Exception as e:
+            logger.debug(f"Could not resolve {ipns_full_name}: {e}")
+
+        # Merge global into local (add entries we don't have)
+        if global_state:
+            pipe = self._redis.pipeline()
+            for key, value in global_state.items():
+                pipe.hsetnx(redis_key, key, value)
+            results = pipe.execute()
+            added = sum(1 for r in results if r)
+            if added:
+                logger.info(f"Merged {added} new entries from IPNS/{ipns_name}")
+
+        # Push: get local state, merge with global, publish
+        local_state = self._redis.hgetall(redis_key)
+        if local_state:
+            merged = {**global_state, **local_state}
+
+            # Only publish if we have new entries
+            if len(merged) > len(global_state):
+                try:
+                    new_cid = self.ipfs.add_json(merged)
+                    if new_cid:
+                        # Note: name_publish can be slow
+                        self.ipfs.name_publish(ipns_full_name, new_cid)
+                        logger.info(f"Published {len(merged)} entries to IPNS/{ipns_name}")
+                except Exception as e:
+                    logger.warning(f"Failed to publish to {ipns_full_name}: {e}")
+
+    def force_sync(self):
+        """Force an immediate IPNS sync (blocking)."""
+        self._sync_with_ipns()
+
+    # ========== Stats ==========
+
+    def get_stats(self) -> Dict:
+        """Get cache statistics."""
+        return {
+            "cached_cids": self._redis.hlen(CACHE_KEY),
+            "analysis_cids": self._redis.hlen(ANALYSIS_KEY),
+            "plan_cids": self._redis.hlen(PLAN_KEY),
+            "run_cids": self._redis.hlen(RUN_KEY),
+            "ipns_enabled": self.ipns_enabled,
+            "cluster_key": self.cluster_key[:16] + "..." if len(self.cluster_key) > 16 else self.cluster_key,
+        }
+
+
+# Singleton instance
+_state_manager: Optional[HybridStateManager] = None
+
+
+def get_state_manager() -> HybridStateManager:
+    """Get the singleton state manager instance."""
+    global _state_manager
+    if _state_manager is None:
+        _state_manager = HybridStateManager()
+    return _state_manager
+
+
+def reset_state_manager():
+    """Reset the singleton (for testing)."""
+    global _state_manager
+    if _state_manager:
+        _state_manager.stop_sync()
+        _state_manager = None
diff --git a/tasks/analyze_cid.py b/tasks/analyze_cid.py
index 673c902..72e4313 100644
--- a/tasks/analyze_cid.py
+++ b/tasks/analyze_cid.py
@@ -2,6 +2,10 @@
 IPFS-primary analysis tasks.
 
 Fetches inputs from IPFS, stores analysis results on IPFS.
+
+Uses HybridStateManager for:
+- Fast local Redis operations
+- Background IPNS sync with other L1 nodes
 """
 
 import json
@@ -18,17 +22,7 @@ import sys
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from celery_app import app
 import ipfs_client
-
-# Redis for caching analysis CIDs
-import redis
-REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
-_redis: Optional[redis.Redis] = None
-
-def get_redis() -> redis.Redis:
-    global _redis
-    if _redis is None:
-        _redis = redis.from_url(REDIS_URL, decode_responses=True)
-    return _redis
+from hybrid_state import get_state_manager
 
 # Import artdag analysis module
 try:
@@ -39,26 +33,15 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
-# Redis key for analysis cache
-ANALYSIS_CACHE_KEY = "artdag:analysis_cid"  # hash: input_hash:features → analysis CID
-
-
-def get_analysis_cache_key(input_hash: str, features: List[str]) -> str:
-    """Generate cache key for analysis results."""
-    features_str = ",".join(sorted(features))
-    return f"{input_hash}:{features_str}"
-
 
 def get_cached_analysis_cid(input_hash: str, features: List[str]) -> Optional[str]:
     """Check if analysis is already cached."""
-    key = get_analysis_cache_key(input_hash, features)
-    return get_redis().hget(ANALYSIS_CACHE_KEY, key)
+    return get_state_manager().get_analysis_cid(input_hash, features)
 
 
 def set_cached_analysis_cid(input_hash: str, features: List[str], cid: str) -> None:
     """Store analysis CID in cache."""
-    key = get_analysis_cache_key(input_hash, features)
-    get_redis().hset(ANALYSIS_CACHE_KEY, key, cid)
+    get_state_manager().set_analysis_cid(input_hash, features, cid)
 
 
 @app.task(bind=True, name='tasks.analyze_input_cid')
diff --git a/tasks/execute_cid.py b/tasks/execute_cid.py
index 094bd0e..f4c9699 100644
--- a/tasks/execute_cid.py
+++ b/tasks/execute_cid.py
@@ -3,6 +3,10 @@ Simplified step execution with IPFS-primary architecture.
 
 Steps receive CIDs, produce CIDs. No file paths cross machine boundaries.
 IPFS nodes form a distributed cache automatically.
+
+Uses HybridStateManager for:
+- Fast local Redis operations
+- Background IPNS sync with other L1 nodes
 """
 
 import logging
@@ -19,17 +23,7 @@ import sys
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from celery_app import app
 import ipfs_client
-
-# Redis for claiming and cache_id → CID mapping
-import redis
-REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
-_redis: Optional[redis.Redis] = None
-
-def get_redis() -> redis.Redis:
-    global _redis
-    if _redis is None:
-        _redis = redis.from_url(REDIS_URL, decode_responses=True)
-    return _redis
+from hybrid_state import get_state_manager
 
 # Import artdag
 try:
@@ -44,10 +38,6 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
-# Redis keys
-CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → CID
-CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker_id
-
 
 def get_worker_id() -> str:
     """Get unique worker identifier."""
@@ -56,24 +46,22 @@ def get_worker_id() -> str:
 
 def get_cached_cid(cache_id: str) -> Optional[str]:
     """Check if cache_id has a known CID."""
-    return get_redis().hget(CACHE_KEY, cache_id)
+    return get_state_manager().get_cached_cid(cache_id)
 
 
 def set_cached_cid(cache_id: str, cid: str) -> None:
     """Store cache_id → CID mapping."""
-    get_redis().hset(CACHE_KEY, cache_id, cid)
+    get_state_manager().set_cached_cid(cache_id, cid)
 
 
 def try_claim(cache_id: str, worker_id: str, ttl: int = 300) -> bool:
     """Try to claim a cache_id for execution. Returns True if claimed."""
-    key = f"{CLAIM_KEY_PREFIX}{cache_id}"
-    return get_redis().set(key, worker_id, nx=True, ex=ttl)
+    return get_state_manager().try_claim(cache_id, worker_id, ttl)
 
 
 def release_claim(cache_id: str) -> None:
     """Release a claim."""
-    key = f"{CLAIM_KEY_PREFIX}{cache_id}"
-    get_redis().delete(key)
+    get_state_manager().release_claim(cache_id)
 
 
 def wait_for_cid(cache_id: str, timeout: int = 600, poll_interval: float = 0.5) -> Optional[str]:
diff --git a/tasks/orchestrate_cid.py b/tasks/orchestrate_cid.py
index da4c724..c43264c 100644
--- a/tasks/orchestrate_cid.py
+++ b/tasks/orchestrate_cid.py
@@ -8,33 +8,25 @@ Everything on IPFS:
 - Step outputs (media files)
 
 The entire pipeline just passes CIDs around.
+
+Uses HybridStateManager for:
+- Fast local Redis operations
+- Background IPNS sync with other L1 nodes
 """
 
 import json
 import logging
 import os
-import shutil
-import tempfile
 from pathlib import Path
 from typing import Dict, List, Optional
 
-from celery import current_task, group
+from celery import group
 
 import sys
 sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from celery_app import app
 import ipfs_client
-
-# Redis for caching
-import redis
-REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/5")
-_redis: Optional[redis.Redis] = None
-
-def get_redis() -> redis.Redis:
-    global _redis
-    if _redis is None:
-        _redis = redis.from_url(REDIS_URL, decode_responses=True)
-    return _redis
+from hybrid_state import get_state_manager
 
 # Import artdag modules
 try:
@@ -53,11 +45,6 @@ from .execute_cid import execute_step_cid
 
 logger = logging.getLogger(__name__)
 
-# Redis keys
-PLAN_CACHE_KEY = "artdag:plan_cid"  # hash: plan_id → plan CID
-RECIPE_CACHE_KEY = "artdag:recipe_cid"  # hash: recipe_hash → recipe CID
-RUN_CACHE_KEY = "artdag:run_cid"  # hash: run_id → output CID
-
 
 def compute_run_id(recipe_cid: str, input_cids: Dict[str, str]) -> str:
     """Compute deterministic run ID from recipe and inputs."""
@@ -203,7 +190,7 @@ def generate_plan_cid(
         return {"status": "failed", "error": "Failed to store plan on IPFS"}
 
     # Cache plan_id → plan_cid mapping
-    get_redis().hset(PLAN_CACHE_KEY, plan.plan_id, plan_cid)
+    get_state_manager().set_plan_cid(plan.plan_id, plan_cid)
 
     logger.info(f"[CID] Generated plan: {plan.plan_id[:16]}... → {plan_cid}")
 
@@ -327,7 +314,7 @@ def run_recipe_cid(
     run_id = compute_run_id(recipe_cid, input_cids)
 
     # Check if run is already cached
-    cached_output = get_redis().hget(RUN_CACHE_KEY, run_id)
+    cached_output = get_state_manager().get_run_cid(run_id)
     if cached_output:
         logger.info(f"[CID] Run cache hit: {run_id[:16]}... → {cached_output}")
         return {
@@ -385,7 +372,7 @@ def run_recipe_cid(
     output_cid = exec_result["output_cid"]
 
     # Cache the run
-    get_redis().hset(RUN_CACHE_KEY, run_id, output_cid)
+    get_state_manager().set_run_cid(run_id, output_cid)
 
     logger.info(f"[CID] Run complete: {run_id[:16]}... → {output_cid}")