Implements HybridStateManager, providing fast local Redis operations with background IPNS sync for eventual consistency across L1 nodes.

- hybrid_state.py: centralized state management (cache, claims, analysis, plans, runs)
- Updated execute_cid.py, analyze_cid.py, and orchestrate_cid.py to use the state manager
- Background IPNS sync (configurable interval, disabled by default)
- Atomic claiming with Redis SETNX to prevent duplicate work

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
295 lines · 9.8 KiB · Python
"""
|
|
Hybrid State Manager: Local Redis + IPNS Sync.
|
|
|
|
Provides fast local operations with eventual consistency across L1 nodes.
|
|
|
|
- Local Redis: Fast reads/writes (microseconds)
|
|
- IPNS Sync: Background sync with other nodes (every N seconds)
|
|
- Duplicate work: Accepted, idempotent (same inputs → same CID)
|
|
|
|
Usage:
|
|
from hybrid_state import get_state_manager
|
|
|
|
state = get_state_manager()
|
|
|
|
# Fast local lookup
|
|
cid = state.get_cached_cid(cache_id)
|
|
|
|
# Fast local write (synced in background)
|
|
state.set_cached_cid(cache_id, output_cid)
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import threading
|
|
import time
|
|
from typing import Dict, Optional
|
|
|
|
import redis
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Configuration — every value is overridable via environment variables.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/5")
CLUSTER_KEY = os.environ.get("ARTDAG_CLUSTER_KEY", "default")
# Seconds between background IPNS sync passes.
IPNS_SYNC_INTERVAL = int(os.environ.get("ARTDAG_IPNS_SYNC_INTERVAL", "30"))
# Background IPNS sync is opt-in: disabled unless explicitly set truthy.
IPNS_ENABLED = os.environ.get("ARTDAG_IPNS_SYNC", "").lower() in ("true", "1", "yes")

# Redis keys (local node state)
CACHE_KEY = "artdag:cid_cache"  # hash: cache_id → output CID
ANALYSIS_KEY = "artdag:analysis_cache"  # hash: input_hash:features → analysis CID
PLAN_KEY = "artdag:plan_cache"  # hash: plan_id → plan CID
RUN_KEY = "artdag:run_cache"  # hash: run_id → output CID
CLAIM_KEY_PREFIX = "artdag:claim:"  # string: cache_id → worker (with TTL)

# IPNS names (relative to cluster key)
IPNS_CACHE_NAME = "cache"
IPNS_ANALYSIS_NAME = "analysis"
IPNS_PLAN_NAME = "plans"
|
|
|
|
|
|
class HybridStateManager:
    """
    Local Redis + async IPNS sync for distributed L1 coordination.

    Fast path (local Redis, microseconds):
        - get_cached_cid / set_cached_cid
        - try_claim / release_claim

    Slow path (background IPNS sync):
        - Periodically syncs local state with global IPNS state
        - Merges remote state into local (pulls new entries)
        - Publishes local state to IPNS (pushes updates)
    """

    def __init__(
        self,
        redis_url: str = REDIS_URL,
        cluster_key: str = CLUSTER_KEY,
        sync_interval: int = IPNS_SYNC_INTERVAL,
        ipns_enabled: bool = IPNS_ENABLED,
    ):
        """
        Args:
            redis_url: Connection URL for the local Redis instance.
            cluster_key: Namespace prefix for the cluster's shared IPNS names.
            sync_interval: Seconds between background IPNS sync passes.
            ipns_enabled: When True, start the background sync thread now.
        """
        self.cluster_key = cluster_key
        self.sync_interval = sync_interval
        self.ipns_enabled = ipns_enabled

        # decode_responses=True so hash/string reads return str, not bytes.
        self._redis = redis.from_url(redis_url, decode_responses=True)

        # IPFS client module, imported lazily via the `ipfs` property.
        # None = not yet attempted; False = import failed (don't retry).
        self._ipfs = None

        # Background sync thread and its shutdown signal.
        self._sync_thread: Optional[threading.Thread] = None
        self._stop_sync = threading.Event()

        if self.ipns_enabled:
            self._start_background_sync()

    @property
    def ipfs(self):
        """Lazily imported IPFS client module, or None if unavailable."""
        if self._ipfs is None:
            try:
                import ipfs_client
                self._ipfs = ipfs_client
            except ImportError:
                logger.warning("ipfs_client not available, IPNS sync disabled")
                # False (not None) marks the import as permanently failed so
                # we do not retry it on every property access.
                self._ipfs = False
        return self._ipfs if self._ipfs else None

    # ========== CID Cache ==========

    def get_cached_cid(self, cache_id: str) -> Optional[str]:
        """Get output CID for a cache_id. Fast local lookup."""
        return self._redis.hget(CACHE_KEY, cache_id)

    def set_cached_cid(self, cache_id: str, cid: str) -> None:
        """Set output CID for a cache_id. Fast local write."""
        self._redis.hset(CACHE_KEY, cache_id, cid)

    def get_all_cached_cids(self) -> Dict[str, str]:
        """Get all cached CIDs as a cache_id → CID mapping."""
        return self._redis.hgetall(CACHE_KEY)

    # ========== Analysis Cache ==========

    def get_analysis_cid(self, input_hash: str, features: list) -> Optional[str]:
        """Get analysis CID for input + features."""
        key = f"{input_hash}:{','.join(sorted(features))}"
        return self._redis.hget(ANALYSIS_KEY, key)

    def set_analysis_cid(self, input_hash: str, features: list, cid: str) -> None:
        """Set analysis CID for input + features."""
        # Features are sorted so the cache key is order-independent.
        key = f"{input_hash}:{','.join(sorted(features))}"
        self._redis.hset(ANALYSIS_KEY, key, cid)

    def get_all_analysis_cids(self) -> Dict[str, str]:
        """Get all analysis CIDs."""
        return self._redis.hgetall(ANALYSIS_KEY)

    # ========== Plan Cache ==========

    def get_plan_cid(self, plan_id: str) -> Optional[str]:
        """Get plan CID for a plan_id."""
        return self._redis.hget(PLAN_KEY, plan_id)

    def set_plan_cid(self, plan_id: str, cid: str) -> None:
        """Set plan CID for a plan_id."""
        self._redis.hset(PLAN_KEY, plan_id, cid)

    def get_all_plan_cids(self) -> Dict[str, str]:
        """Get all plan CIDs."""
        return self._redis.hgetall(PLAN_KEY)

    # ========== Run Cache ==========

    def get_run_cid(self, run_id: str) -> Optional[str]:
        """Get output CID for a run_id."""
        return self._redis.hget(RUN_KEY, run_id)

    def set_run_cid(self, run_id: str, cid: str) -> None:
        """Set output CID for a run_id."""
        self._redis.hset(RUN_KEY, run_id, cid)

    # ========== Claiming ==========

    def try_claim(self, cache_id: str, worker_id: str, ttl: int = 300) -> bool:
        """
        Try to claim a cache_id for execution.

        Returns True if claimed, False if already claimed by another worker.
        Uses Redis SET NX for an atomic claim; the TTL makes claims from
        crashed workers expire automatically.
        """
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        # redis-py returns True on success but None (not False) when NX
        # fails; coerce so the annotated bool return type actually holds.
        return bool(self._redis.set(key, worker_id, nx=True, ex=ttl))

    def release_claim(self, cache_id: str) -> None:
        """Release a claim (no-op if it does not exist)."""
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        self._redis.delete(key)

    def get_claim(self, cache_id: str) -> Optional[str]:
        """Get current claim holder (worker_id) for a cache_id, or None."""
        key = f"{CLAIM_KEY_PREFIX}{cache_id}"
        return self._redis.get(key)

    # ========== IPNS Sync ==========

    def _start_background_sync(self):
        """Start the background IPNS sync thread (idempotent)."""
        if self._sync_thread is not None:
            return

        def sync_loop():
            logger.info(f"IPNS sync started (interval={self.sync_interval}s)")
            # Event.wait doubles as both the sleep and the shutdown check:
            # it returns True (stop) as soon as stop_sync() sets the event.
            while not self._stop_sync.wait(timeout=self.sync_interval):
                try:
                    self._sync_with_ipns()
                except Exception as e:
                    # Best-effort: a failed pass is retried next interval.
                    logger.warning(f"IPNS sync failed: {e}")

        # Daemon thread so it never blocks interpreter shutdown.
        self._sync_thread = threading.Thread(target=sync_loop, daemon=True)
        self._sync_thread.start()

    def stop_sync(self):
        """Stop the background sync thread (waits up to 5s for it to exit)."""
        self._stop_sync.set()
        if self._sync_thread:
            self._sync_thread.join(timeout=5)

    def _sync_with_ipns(self):
        """Sync local state with IPNS global state (pull-merge-push)."""
        if not self.ipfs:
            return

        logger.debug("Starting IPNS sync...")

        # Sync each cache type; runs are intentionally node-local.
        self._sync_hash(CACHE_KEY, IPNS_CACHE_NAME)
        self._sync_hash(ANALYSIS_KEY, IPNS_ANALYSIS_NAME)
        self._sync_hash(PLAN_KEY, IPNS_PLAN_NAME)

        logger.debug("IPNS sync complete")

    def _sync_hash(self, redis_key: str, ipns_name: str):
        """
        Sync a Redis hash with its IPNS-published JSON counterpart.

        Pull: resolve the IPNS name, fetch the global state, merge any
        entries we do not have (hsetnx never overwrites local values).
        Push: if local adds entries beyond the global set, publish the
        merged mapping back to IPNS.
        """
        ipns_full_name = f"{self.cluster_key}/{ipns_name}"

        # Pull: resolve IPNS → get global state
        global_state = {}
        try:
            global_cid = self.ipfs.name_resolve(ipns_full_name)
            if global_cid:
                global_bytes = self.ipfs.get_bytes(global_cid)
                if global_bytes:
                    global_state = json.loads(global_bytes.decode('utf-8'))
                    logger.debug(f"Pulled {len(global_state)} entries from {ipns_name}")
        except Exception as e:
            # Expected on first run (name not yet published); debug only.
            logger.debug(f"Could not resolve {ipns_full_name}: {e}")

        # Merge global into local (add entries we don't have)
        if global_state:
            pipe = self._redis.pipeline()
            for key, value in global_state.items():
                pipe.hsetnx(redis_key, key, value)
            results = pipe.execute()
            added = sum(1 for r in results if r)
            if added:
                logger.info(f"Merged {added} new entries from IPNS/{ipns_name}")

        # Push: get local state, merge with global, publish
        local_state = self._redis.hgetall(redis_key)
        if local_state:
            # Local values win on key conflicts.
            merged = {**global_state, **local_state}

            # Only publish if we have new entries
            if len(merged) > len(global_state):
                try:
                    new_cid = self.ipfs.add_json(merged)
                    if new_cid:
                        # Note: name_publish can be slow
                        self.ipfs.name_publish(ipns_full_name, new_cid)
                        logger.info(f"Published {len(merged)} entries to IPNS/{ipns_name}")
                except Exception as e:
                    logger.warning(f"Failed to publish to {ipns_full_name}: {e}")

    def force_sync(self):
        """Force an immediate IPNS sync (blocking)."""
        self._sync_with_ipns()

    # ========== Stats ==========

    def get_stats(self) -> Dict:
        """Get cache statistics for monitoring/debugging."""
        return {
            "cached_cids": self._redis.hlen(CACHE_KEY),
            "analysis_cids": self._redis.hlen(ANALYSIS_KEY),
            "plan_cids": self._redis.hlen(PLAN_KEY),
            "run_cids": self._redis.hlen(RUN_KEY),
            "ipns_enabled": self.ipns_enabled,
            # Truncate long cluster keys to avoid leaking/log-spamming them.
            "cluster_key": self.cluster_key[:16] + "..." if len(self.cluster_key) > 16 else self.cluster_key,
        }
|
|
|
|
|
|
# Module-level singleton, created lazily by get_state_manager().
_state_manager: Optional[HybridStateManager] = None
|
|
|
|
|
|
def get_state_manager() -> HybridStateManager:
    """Return the process-wide state manager, creating it on first use."""
    global _state_manager
    manager = _state_manager
    if manager is None:
        # Lazy construction with default (environment-driven) settings.
        manager = HybridStateManager()
        _state_manager = manager
    return manager
|
|
|
|
|
|
def reset_state_manager():
    """Tear down the singleton so the next get_state_manager() rebuilds it (for testing)."""
    global _state_manager
    # Detach the instance first, then stop its sync thread.
    manager, _state_manager = _state_manager, None
    if manager:
        manager.stop_sync()
|