"""
Hash-based task claiming for distributed execution.

Prevents duplicate work when multiple workers process the same plan.
Uses Redis Lua scripts for atomic claim operations.
"""

import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional

import redis

logger = logging.getLogger(__name__)

REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')

# Key prefix for task claims
CLAIM_PREFIX = "artdag:claim:"

# Default TTL for claims (5 minutes)
DEFAULT_CLAIM_TTL = 300

# TTL for completed results (1 hour)
COMPLETED_TTL = 3600


class ClaimStatus(Enum):
    """Status of a task claim."""

    PENDING = "pending"
    CLAIMED = "claimed"
    RUNNING = "running"
    COMPLETED = "completed"
    CACHED = "cached"
    FAILED = "failed"


@dataclass
class ClaimInfo:
    """Information about a task claim, as stored (JSON-encoded) in Redis."""

    cache_id: str
    status: ClaimStatus
    worker_id: Optional[str] = None
    task_id: Optional[str] = None
    claimed_at: Optional[str] = None
    completed_at: Optional[str] = None
    output_path: Optional[str] = None
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict; ``status`` is stored as its string value."""
        return {
            "cache_id": self.cache_id,
            "status": self.status.value,
            "worker_id": self.worker_id,
            "task_id": self.task_id,
            "claimed_at": self.claimed_at,
            "completed_at": self.completed_at,
            "output_path": self.output_path,
            "error": self.error,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """
        Build a ClaimInfo from a dict produced by ``to_dict``.

        Raises:
            KeyError: if ``cache_id`` or ``status`` is missing
            ValueError: if ``status`` is not a valid ClaimStatus value
        """
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            worker_id=data.get("worker_id"),
            task_id=data.get("task_id"),
            claimed_at=data.get("claimed_at"),
            completed_at=data.get("completed_at"),
            output_path=data.get("output_path"),
            error=data.get("error"),
        )


# Lua script for atomic task claiming
# Returns 1 if claim successful, 0 if already claimed/completed
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    local s = status['status']
    -- Already claimed, running, completed, or cached - don't claim
    if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
        return 0
    end
end

-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""

# Lua script for releasing a claim (e.g., on failure)
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    -- Only release if we own the claim
    if status['worker_id'] == worker_id then
        redis.call('DEL', key)
        return 1
    end
end
return 0
"""

# Lua script for updating claim status (claimed -> running -> completed)
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
if not data then
    return 0
end

local status = cjson.decode(data)

-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
    return 0
end

redis.call('SETEX', key, ttl, new_data)
return 1
"""


class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers. Each task is identified
    by its cache_id (content-addressed).
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize the claimer.

        No connection is opened here; the client and the Lua script handles
        are created on first access of the ``redis`` property.

        Args:
            redis_url: Redis connection URL (defaults to the module-level
                REDIS_URL when omitted or empty)
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization)."""
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        The claim succeeds when no entry exists for the task, or when the
        existing entry's status is anything other than claimed, running,
        completed, or cached (for example, a failed task can be re-claimed).

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        # Resolve the client first: the script handles are only registered by
        # the lazy ``redis`` property. Python evaluates the callable before
        # its arguments, so reading self._claim_script before self.redis on a
        # fresh instance would yield None and raise TypeError.
        client = self.redis

        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )

        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug(
                "Claimed task %s... for worker %s", cache_id[:16], worker_id
            )
            return True

        logger.debug("Task %s... already claimed", cache_id[:16])
        return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        The stored entry is replaced as a whole; ``task_id`` and
        ``claimed_at`` are carried over from the existing entry, and
        ``completed_at`` is set for completed, cached, and failed statuses.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful, False if no claim exists or the claim
            is owned by a different worker
        """
        # Make sure the client and script handles exist before using them.
        client = self.redis

        if ttl is None:
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL

        # Get existing claim info
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning("No claim found for %s...", cache_id[:16])
            return False

        is_terminal = status in (
            ClaimStatus.COMPLETED,
            ClaimStatus.CACHED,
            ClaimStatus.FAILED,
        )
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            completed_at=(
                datetime.now(timezone.utc).isoformat() if is_terminal else None
            ),
            output_path=output_path,
            error=error,
        )

        # Ownership is re-checked atomically inside the Lua script, so a claim
        # taken over by another worker since the read above is not overwritten.
        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug("Updated task %s... to %s", cache_id[:16], status.value)
            return True

        logger.warning("Failed to update task %s... (not owner?)", cache_id[:16])
        return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        The entry is deleted only when its stored worker_id matches; the
        entry's status is not examined.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful
        """
        # Resolve the client first so the release script is registered
        # (see the note in ``claim``).
        client = self.redis

        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )

        if result == 1:
            logger.debug("Released claim on %s...", cache_id[:16])
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED,
            ClaimStatus.CACHED,
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to reach a terminal status by polling Redis.

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo once the task is completed, cached, or failed;
            None if timeout
        """
        # A monotonic clock keeps the timeout correct even if the wall clock
        # is adjusted while we are waiting.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED,
                ClaimStatus.CACHED,
                ClaimStatus.FAILED,
            ):
                return info
            time.sleep(poll_interval)

        logger.warning("Timeout waiting for %s...", cache_id[:16])
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim. The entry is written unconditionally,
        replacing whatever is currently stored for the task.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0


# Global claimer instance
_claimer: Optional[TaskClaimer] = None


def get_claimer() -> TaskClaimer:
    """Get the global TaskClaimer instance, creating it on first use."""
    global _claimer
    if _claimer is None:
        _claimer = TaskClaimer()
    return _claimer


def claim_task(cache_id: str, worker_id: str, task_id: Optional[str] = None) -> bool:
    """Convenience function to claim a task."""
    return get_claimer().claim(cache_id, worker_id, task_id)


def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Convenience function to mark a task as completed."""
    return get_claimer().update_status(
        cache_id, worker_id, ClaimStatus.COMPLETED, output_path=output_path
    )


def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Convenience function to mark a task as failed."""
    return get_claimer().update_status(
        cache_id, worker_id, ClaimStatus.FAILED, error=error
    )