Files
celery/hybrid_state.py
gilesb ca8bfd8705 Add hybrid state manager for distributed L1 coordination
Implements HybridStateManager providing fast local Redis operations
with background IPNS sync for eventual consistency across L1 nodes.

- hybrid_state.py: Centralized state management (cache, claims, analysis, plans, runs)
- Updated execute_cid.py, analyze_cid.py, orchestrate_cid.py to use state manager
- Background IPNS sync (configurable interval, disabled by default)
- Atomic claiming with Redis SETNX for preventing duplicate work

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-11 09:36:14 +00:00

295 lines
9.8 KiB
Python

"""
Hybrid State Manager: Local Redis + IPNS Sync.
Provides fast local operations with eventual consistency across L1 nodes.
- Local Redis: Fast reads/writes (microseconds)
- IPNS Sync: Background sync with other nodes (every N seconds)
- Duplicate work: Accepted, idempotent (same inputs → same CID)
Usage:
from hybrid_state import get_state_manager
state = get_state_manager()
# Fast local lookup
cid = state.get_cached_cid(cache_id)
# Fast local write (synced in background)
state.set_cached_cid(cache_id, output_cid)
"""
import json
import logging
import os
import threading
import time
from typing import Dict, Optional
import redis
logger = logging.getLogger(__name__)
# Configuration — each value can be overridden via environment variable
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/5")  # local Redis instance (DB 5)
CLUSTER_KEY = os.environ.get("ARTDAG_CLUSTER_KEY", "default")  # namespace prefix for IPNS names
IPNS_SYNC_INTERVAL = int(os.environ.get("ARTDAG_IPNS_SYNC_INTERVAL", "30"))  # seconds between background sync passes
IPNS_ENABLED = os.environ.get("ARTDAG_IPNS_SYNC", "").lower() in ("true", "1", "yes")  # sync is opt-in; off by default
# Redis keys (hashes, except claims which are plain string keys with a TTL)
CACHE_KEY = "artdag:cid_cache" # hash: cache_id → output CID
ANALYSIS_KEY = "artdag:analysis_cache" # hash: input_hash:features → analysis CID
PLAN_KEY = "artdag:plan_cache" # hash: plan_id → plan CID
RUN_KEY = "artdag:run_cache" # hash: run_id → output CID (local-only: not part of IPNS sync)
CLAIM_KEY_PREFIX = "artdag:claim:" # string: cache_id → worker (with TTL)
# IPNS names (published relative to the cluster key, e.g. "<cluster_key>/cache")
IPNS_CACHE_NAME = "cache"
IPNS_ANALYSIS_NAME = "analysis"
IPNS_PLAN_NAME = "plans"
class HybridStateManager:
"""
Local Redis + async IPNS sync for distributed L1 coordination.
Fast path (local Redis):
- get_cached_cid / set_cached_cid
- try_claim / release_claim
Slow path (background IPNS sync):
- Periodically syncs local state with global IPNS state
- Merges remote state into local (pulls new entries)
- Publishes local state to IPNS (pushes updates)
"""
def __init__(
self,
redis_url: str = REDIS_URL,
cluster_key: str = CLUSTER_KEY,
sync_interval: int = IPNS_SYNC_INTERVAL,
ipns_enabled: bool = IPNS_ENABLED,
):
self.cluster_key = cluster_key
self.sync_interval = sync_interval
self.ipns_enabled = ipns_enabled
# Connect to Redis
self._redis = redis.from_url(redis_url, decode_responses=True)
# IPNS client (lazy import)
self._ipfs = None
# Sync thread
self._sync_thread = None
self._stop_sync = threading.Event()
# Start background sync if enabled
if self.ipns_enabled:
self._start_background_sync()
@property
def ipfs(self):
"""Lazy import of IPFS client."""
if self._ipfs is None:
try:
import ipfs_client
self._ipfs = ipfs_client
except ImportError:
logger.warning("ipfs_client not available, IPNS sync disabled")
self._ipfs = False
return self._ipfs if self._ipfs else None
# ========== CID Cache ==========
def get_cached_cid(self, cache_id: str) -> Optional[str]:
"""Get output CID for a cache_id. Fast local lookup."""
return self._redis.hget(CACHE_KEY, cache_id)
def set_cached_cid(self, cache_id: str, cid: str) -> None:
"""Set output CID for a cache_id. Fast local write."""
self._redis.hset(CACHE_KEY, cache_id, cid)
def get_all_cached_cids(self) -> Dict[str, str]:
"""Get all cached CIDs."""
return self._redis.hgetall(CACHE_KEY)
# ========== Analysis Cache ==========
def get_analysis_cid(self, input_hash: str, features: list) -> Optional[str]:
"""Get analysis CID for input + features."""
key = f"{input_hash}:{','.join(sorted(features))}"
return self._redis.hget(ANALYSIS_KEY, key)
def set_analysis_cid(self, input_hash: str, features: list, cid: str) -> None:
"""Set analysis CID for input + features."""
key = f"{input_hash}:{','.join(sorted(features))}"
self._redis.hset(ANALYSIS_KEY, key, cid)
def get_all_analysis_cids(self) -> Dict[str, str]:
"""Get all analysis CIDs."""
return self._redis.hgetall(ANALYSIS_KEY)
# ========== Plan Cache ==========
def get_plan_cid(self, plan_id: str) -> Optional[str]:
"""Get plan CID for a plan_id."""
return self._redis.hget(PLAN_KEY, plan_id)
def set_plan_cid(self, plan_id: str, cid: str) -> None:
"""Set plan CID for a plan_id."""
self._redis.hset(PLAN_KEY, plan_id, cid)
def get_all_plan_cids(self) -> Dict[str, str]:
"""Get all plan CIDs."""
return self._redis.hgetall(PLAN_KEY)
# ========== Run Cache ==========
def get_run_cid(self, run_id: str) -> Optional[str]:
"""Get output CID for a run_id."""
return self._redis.hget(RUN_KEY, run_id)
def set_run_cid(self, run_id: str, cid: str) -> None:
"""Set output CID for a run_id."""
self._redis.hset(RUN_KEY, run_id, cid)
# ========== Claiming ==========
def try_claim(self, cache_id: str, worker_id: str, ttl: int = 300) -> bool:
"""
Try to claim a cache_id for execution.
Returns True if claimed, False if already claimed by another worker.
Uses Redis SETNX for atomic claim.
"""
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
return self._redis.set(key, worker_id, nx=True, ex=ttl)
def release_claim(self, cache_id: str) -> None:
"""Release a claim."""
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
self._redis.delete(key)
def get_claim(self, cache_id: str) -> Optional[str]:
"""Get current claim holder for a cache_id."""
key = f"{CLAIM_KEY_PREFIX}{cache_id}"
return self._redis.get(key)
# ========== IPNS Sync ==========
def _start_background_sync(self):
"""Start background IPNS sync thread."""
if self._sync_thread is not None:
return
def sync_loop():
logger.info(f"IPNS sync started (interval={self.sync_interval}s)")
while not self._stop_sync.wait(timeout=self.sync_interval):
try:
self._sync_with_ipns()
except Exception as e:
logger.warning(f"IPNS sync failed: {e}")
self._sync_thread = threading.Thread(target=sync_loop, daemon=True)
self._sync_thread.start()
def stop_sync(self):
"""Stop background sync thread."""
self._stop_sync.set()
if self._sync_thread:
self._sync_thread.join(timeout=5)
def _sync_with_ipns(self):
"""Sync local state with IPNS global state."""
if not self.ipfs:
return
logger.debug("Starting IPNS sync...")
# Sync each cache type
self._sync_hash(CACHE_KEY, IPNS_CACHE_NAME)
self._sync_hash(ANALYSIS_KEY, IPNS_ANALYSIS_NAME)
self._sync_hash(PLAN_KEY, IPNS_PLAN_NAME)
logger.debug("IPNS sync complete")
def _sync_hash(self, redis_key: str, ipns_name: str):
"""Sync a Redis hash with IPNS."""
ipns_full_name = f"{self.cluster_key}/{ipns_name}"
# Pull: resolve IPNS → get global state
global_state = {}
try:
global_cid = self.ipfs.name_resolve(ipns_full_name)
if global_cid:
global_bytes = self.ipfs.get_bytes(global_cid)
if global_bytes:
global_state = json.loads(global_bytes.decode('utf-8'))
logger.debug(f"Pulled {len(global_state)} entries from {ipns_name}")
except Exception as e:
logger.debug(f"Could not resolve {ipns_full_name}: {e}")
# Merge global into local (add entries we don't have)
if global_state:
pipe = self._redis.pipeline()
for key, value in global_state.items():
pipe.hsetnx(redis_key, key, value)
results = pipe.execute()
added = sum(1 for r in results if r)
if added:
logger.info(f"Merged {added} new entries from IPNS/{ipns_name}")
# Push: get local state, merge with global, publish
local_state = self._redis.hgetall(redis_key)
if local_state:
merged = {**global_state, **local_state}
# Only publish if we have new entries
if len(merged) > len(global_state):
try:
new_cid = self.ipfs.add_json(merged)
if new_cid:
# Note: name_publish can be slow
self.ipfs.name_publish(ipns_full_name, new_cid)
logger.info(f"Published {len(merged)} entries to IPNS/{ipns_name}")
except Exception as e:
logger.warning(f"Failed to publish to {ipns_full_name}: {e}")
def force_sync(self):
"""Force an immediate IPNS sync (blocking)."""
self._sync_with_ipns()
# ========== Stats ==========
def get_stats(self) -> Dict:
"""Get cache statistics."""
return {
"cached_cids": self._redis.hlen(CACHE_KEY),
"analysis_cids": self._redis.hlen(ANALYSIS_KEY),
"plan_cids": self._redis.hlen(PLAN_KEY),
"run_cids": self._redis.hlen(RUN_KEY),
"ipns_enabled": self.ipns_enabled,
"cluster_key": self.cluster_key[:16] + "..." if len(self.cluster_key) > 16 else self.cluster_key,
}
# Singleton instance, lazily created by get_state_manager()
_state_manager: Optional[HybridStateManager] = None
def get_state_manager() -> HybridStateManager:
    """Return the process-wide state manager, creating it on first use."""
    global _state_manager
    manager = _state_manager
    if manager is None:
        manager = HybridStateManager()
        _state_manager = manager
    return manager
def reset_state_manager():
    """Tear down the singleton so the next get_state_manager() starts fresh (for testing)."""
    global _state_manager
    manager = _state_manager
    if manager is not None:
        manager.stop_sync()
    _state_manager = None