Add 3-phase execution with IPFS cache and hash-based task claiming
New files:
- claiming.py - Redis Lua scripts for atomic task claiming
- tasks/analyze.py - Analysis Celery task
- tasks/execute.py - Step execution with IPFS-backed cache
- tasks/orchestrate.py - Plan orchestration (run_plan, run_recipe)
New API endpoints (/api/v2/):
- POST /api/v2/plan - Generate execution plan
- POST /api/v2/execute - Execute a plan
- POST /api/v2/run-recipe - Full 3-phase pipeline
- GET /api/v2/run/{run_id} - Get run status
Features:
- Hash-based task claiming prevents duplicate work
- Parallel execution within dependency levels
- IPFS-backed cache for durability
- Integration with artdag planning module
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
421
claiming.py
Normal file
421
claiming.py
Normal file
@@ -0,0 +1,421 @@
|
||||
"""
|
||||
Hash-based task claiming for distributed execution.
|
||||
|
||||
Prevents duplicate work when multiple workers process the same plan.
|
||||
Uses Redis Lua scripts for atomic claim operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')
|
||||
|
||||
# Key prefix for task claims
|
||||
CLAIM_PREFIX = "artdag:claim:"
|
||||
|
||||
# Default TTL for claims (5 minutes)
|
||||
DEFAULT_CLAIM_TTL = 300
|
||||
|
||||
# TTL for completed results (1 hour)
|
||||
COMPLETED_TTL = 3600
|
||||
|
||||
|
||||
class ClaimStatus(Enum):
    """Lifecycle states a task claim moves through in Redis."""

    PENDING = "pending"      # known but not yet claimed by any worker
    CLAIMED = "claimed"      # a worker holds the claim, work not started
    RUNNING = "running"      # the owning worker is executing the task
    COMPLETED = "completed"  # execution finished and output was stored
    CACHED = "cached"        # result already existed; no work performed
    FAILED = "failed"        # execution ended with an error
|
||||
|
||||
|
||||
@dataclass
class ClaimInfo:
    """Snapshot of a task claim as serialized to/from Redis."""

    cache_id: str
    status: ClaimStatus
    worker_id: Optional[str] = None
    task_id: Optional[str] = None
    claimed_at: Optional[str] = None
    completed_at: Optional[str] = None
    output_path: Optional[str] = None
    error: Optional[str] = None

    # Serialization field order; kept stable so the JSON stored in Redis
    # has a deterministic key order.
    _FIELDS = (
        "cache_id", "status", "worker_id", "task_id",
        "claimed_at", "completed_at", "output_path", "error",
    )

    def to_dict(self) -> dict:
        """Return a JSON-compatible dict; the status enum is flattened to its value."""
        raw = {name: getattr(self, name) for name in self._FIELDS}
        raw["status"] = self.status.value
        return raw

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """Rebuild a ClaimInfo from a dict produced by to_dict().

        Requires "cache_id" and "status"; all other fields default to None
        when absent.
        """
        optional = {
            name: data.get(name)
            for name in cls._FIELDS
            if name not in ("cache_id", "status")
        }
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            **optional,
        )
|
||||
|
||||
|
||||
# --- Redis Lua scripts -----------------------------------------------------
# All claim mutations go through server-side Lua so the read-check-write
# sequence is atomic across competing workers; Redis executes each script
# as a single uninterruptible operation.

# Lua script for atomic task claiming
# KEYS[1]: claim key; ARGV[1]: serialized ClaimInfo JSON; ARGV[2]: TTL seconds.
# Returns 1 if claim successful, 0 if already claimed/completed
# (a 'pending' or 'failed' entry, or no entry at all, may be (re)claimed).
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    local s = status['status']
    -- Already claimed, running, completed, or cached - don't claim
    if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
        return 0
    end
end

-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""

# Lua script for releasing a claim (e.g., on failure)
# KEYS[1]: claim key; ARGV[1]: worker_id of the caller.
# Deletes the key only if the caller owns the claim; returns 1 on release,
# 0 if the key is absent or owned by another worker.
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)

if data then
    local status = cjson.decode(data)
    -- Only release if we own the claim
    if status['worker_id'] == worker_id then
        redis.call('DEL', key)
        return 1
    end
end
return 0
"""

# Lua script for updating claim status (claimed -> running -> completed)
# KEYS[1]: claim key; ARGV[1]: worker_id; ARGV[2]: new status string;
# ARGV[3]: full replacement ClaimInfo JSON; ARGV[4]: new TTL seconds.
# Returns 1 on success, 0 if the key is missing or owned by another worker.
# NOTE(review): ARGV[2] (new_status) is read but never used — the status is
# carried inside ARGV[3]'s JSON. Kept as-is to preserve the caller's argument
# layout; confirm before removing.
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
if not data then
    return 0
end

local status = cjson.decode(data)

-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
    return 0
end

redis.call('SETEX', key, ttl, new_data)
return 1
"""
|
||||
|
||||
|
||||
class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers.
    Each task is identified by its cache_id (content-addressed).

    The connection and the three Lua scripts are registered lazily on first
    access of the ``redis`` property; methods that invoke a script handle
    must therefore touch ``self.redis`` *before* dereferencing the handle
    (see the inline notes below).
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize the claimer.

        Args:
            redis_url: Redis connection URL (defaults to the REDIS_URL
                environment setting)
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        # Lua script handles; populated when the connection is first made.
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization).

        Also registers the Lua scripts as a side effect, so accessing this
        property guarantees the script handles are non-None afterwards.
        """
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )

        # BUG FIX: access self.redis BEFORE dereferencing the script handle.
        # Python evaluates the callable expression before its arguments, so
        # the previous `self._claim_script(..., client=self.redis)` called
        # None on the very first use (scripts are only registered inside the
        # `redis` property).
        client = self.redis
        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug("Claimed task %s... for worker %s", cache_id[:16], worker_id)
            return True
        else:
            logger.debug("Task %s... already claimed", cache_id[:16])
            return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful (fails if no claim exists or another
            worker owns it — ownership is enforced server-side in Lua)
        """
        if ttl is None:
            # Finished results are kept longer so other workers can reuse them.
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL

        # Get existing claim info (also forces connection + script registration)
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning("No claim found for %s...", cache_id[:16])
            return False

        # Terminal states record a completion timestamp; transient ones don't.
        terminal = (ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED)
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            completed_at=(
                datetime.now(timezone.utc).isoformat() if status in terminal else None
            ),
            output_path=output_path,
            error=error,
        )

        # Same evaluation-order hazard as in claim(): ensure scripts exist first.
        client = self.redis
        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug("Updated task %s... to %s", cache_id[:16], status.value)
            return True
        else:
            logger.warning("Failed to update task %s... (not owner?)", cache_id[:16])
            return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful (only the owning worker can release)
        """
        # BUG FIX: as in claim(), register scripts before dereferencing.
        client = self.redis
        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )

        if result == 1:
            logger.debug("Released claim on %s...", cache_id[:16])
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED, ClaimStatus.CACHED
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to complete.

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo if completed/cached/failed, None if timeout
        """
        # monotonic() is immune to wall-clock adjustments, which could
        # otherwise stretch or truncate the timeout window.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ):
                return info
            time.sleep(poll_interval)

        logger.warning("Timeout waiting for %s...", cache_id[:16])
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        # Plain SETEX (not the claim script): a cached result unconditionally
        # supersedes whatever claim state may exist.
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        # scan_iter avoids blocking Redis the way KEYS would on large keyspaces.
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0
|
||||
|
||||
|
||||
# Global claimer instance
# Process-wide singleton so all tasks in a worker share one Redis connection.
_claimer: Optional[TaskClaimer] = None


def get_claimer() -> TaskClaimer:
    """Get the global TaskClaimer instance.

    Lazily constructs the claimer on first call and reuses it thereafter.
    NOTE(review): first-call construction is not guarded by a lock; looks
    safe because TaskClaimer connects lazily, but confirm if workers spawn
    threads before first use.
    """
    global _claimer
    if _claimer is None:
        _claimer = TaskClaimer()
    return _claimer
|
||||
|
||||
|
||||
def claim_task(cache_id: str, worker_id: str, task_id: Optional[str] = None) -> bool:
    """Convenience function to claim a task.

    Module-level shortcut for ``get_claimer().claim(...)``; see
    TaskClaimer.claim for full parameter semantics. Uses the default TTL.
    """
    return get_claimer().claim(cache_id, worker_id, task_id)
|
||||
|
||||
|
||||
def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Mark a claimed task as completed, recording where its output lives.

    Module-level shortcut for ``get_claimer().update_status(...)``.
    Returns True only if *worker_id* owns the claim.
    """
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.COMPLETED,
        output_path=output_path,
    )
|
||||
|
||||
|
||||
def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Mark a claimed task as failed, recording the error message.

    Module-level shortcut for ``get_claimer().update_status(...)``.
    Returns True only if *worker_id* owns the claim.
    """
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.FAILED,
        error=error,
    )
|
||||
Reference in New Issue
Block a user