All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 2m33s
Merges full history from art-dag/mono.git into the monorepo under the artdag/ directory. Contains: core (DAG engine), l1 (Celery rendering server), l2 (ActivityPub registry), common (shared templates/middleware), client (CLI), test (e2e). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> git-subtree-dir: artdag git-subtree-mainline:1a179de547git-subtree-split:4c2e716558
422 lines
12 KiB
Python
422 lines
12 KiB
Python
"""
|
|
Hash-based task claiming for distributed execution.
|
|
|
|
Prevents duplicate work when multiple workers process the same plan.
|
|
Uses Redis Lua scripts for atomic claim operations.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
import redis
|
|
|
|
# Module-level logger (configured by the host application)
logger = logging.getLogger(__name__)

# Redis connection URL; database 5 by default, overridable via REDIS_URL env var
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')

# Key prefix for task claims (full key is CLAIM_PREFIX + cache_id)
CLAIM_PREFIX = "artdag:claim:"

# Default TTL for claims (5 minutes) — an in-flight claim that is never
# updated expires automatically, so a crashed worker cannot wedge a task
DEFAULT_CLAIM_TTL = 300

# TTL for completed results (1 hour)
COMPLETED_TTL = 3600
|
|
|
|
|
|
class ClaimStatus(Enum):
    """Lifecycle states of a task claim record."""

    PENDING = "pending"      # known but not yet reserved by any worker
    CLAIMED = "claimed"      # a worker has reserved the task
    RUNNING = "running"      # the owning worker is executing it
    COMPLETED = "completed"  # finished successfully; output available
    CACHED = "cached"        # result pre-existed; no execution was needed
    FAILED = "failed"        # execution failed; claim may be retaken
|
|
|
|
|
|
@dataclass
class ClaimInfo:
    """Snapshot of a task claim as stored (JSON-encoded) in Redis."""

    cache_id: str                        # content-addressed task identity
    status: ClaimStatus                  # current lifecycle state
    worker_id: Optional[str] = None      # owner of the claim, if any
    task_id: Optional[str] = None        # Celery task id, if known
    claimed_at: Optional[str] = None     # ISO-8601 UTC timestamp
    completed_at: Optional[str] = None   # ISO-8601 UTC timestamp
    output_path: Optional[str] = None    # where the result lives
    error: Optional[str] = None          # failure message, if failed

    # Field names in serialization order; `status` is flattened to its
    # string value so the payload round-trips through JSON/cjson.
    _FIELDS = (
        "cache_id", "status", "worker_id", "task_id",
        "claimed_at", "completed_at", "output_path", "error",
    )

    def to_dict(self) -> dict:
        """Serialize to a JSON-compatible dict (enum flattened to its value)."""
        out = {}
        for name in self._FIELDS:
            value = getattr(self, name)
            out[name] = value.value if name == "status" else value
        return out

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """Rebuild a ClaimInfo from a dict produced by to_dict()."""
        optional = {
            name: data.get(name)
            for name in cls._FIELDS
            if name not in ("cache_id", "status")
        }
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            **optional,
        )
|
|
|
|
|
|
# Lua script for atomic task claiming.
# KEYS[1] = claim key; ARGV[1] = serialized ClaimInfo JSON; ARGV[2] = TTL (s).
# Returns 1 if claim successful, 0 if already claimed/completed.
# A record left in 'pending' or 'failed' state may be re-claimed.
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)

if data then
local status = cjson.decode(data)
local s = status['status']
-- Already claimed, running, completed, or cached - don't claim
if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
return 0
end
end

-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""
|
|
|
|
# Lua script for releasing a claim (e.g., on failure).
# KEYS[1] = claim key; ARGV[1] = worker_id of the caller.
# Deletes the claim only when the caller owns it; returns 1 on release, 0
# when no claim exists or it belongs to a different worker.
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)

if data then
local status = cjson.decode(data)
-- Only release if we own the claim
if status['worker_id'] == worker_id then
redis.call('DEL', key)
return 1
end
end
return 0
"""
|
|
|
|
# Lua script for updating claim status (claimed -> running -> completed).
# KEYS[1] = claim key; ARGV[1] = worker_id; ARGV[2] = new status string
# (currently unused inside the script -- the full replacement record
# arrives pre-serialized in ARGV[3]); ARGV[3] = ClaimInfo JSON;
# ARGV[4] = TTL (s). Overwrites the record only when the caller owns the
# claim; returns 1 on success, 0 when missing or owned by another worker.
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])

local data = redis.call('GET', key)
if not data then
return 0
end

local status = cjson.decode(data)

-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
return 0
end

redis.call('SETEX', key, ttl, new_data)
return 1
"""
|
|
|
|
|
|
class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers.
    Each task is identified by its cache_id (content-addressed).

    NOTE: the Lua scripts are registered as a side effect of the first
    evaluation of the ``redis`` property, so every method that invokes a
    script reads ``self.redis`` *before* dereferencing the script
    attribute (which is None until then).
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize the claimer.

        Args:
            redis_url: Redis connection URL (falls back to the module-level
                REDIS_URL default when omitted)
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        # Registered lazily in the `redis` property; None until first use.
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization) and register Lua scripts."""
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts once per connection
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )

        # BUGFIX: evaluate the `redis` property before dereferencing the
        # script attribute. Python resolves the callee before its arguments,
        # so `self._claim_script(..., client=self.redis)` on a fresh instance
        # called None (TypeError) — the lazy init in `client=` ran too late.
        client = self.redis
        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}")
            return True
        else:
            logger.debug(f"Task {cache_id[:16]}... already claimed")
            return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful
        """
        if ttl is None:
            # Terminal successful states live longer so other workers can
            # discover the result; everything else keeps the claim TTL.
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL

        # Read the existing claim to carry over task_id/claimed_at. This
        # read-then-write is not atomic, but the Lua script re-checks
        # ownership, so a concurrent owner change cannot be clobbered.
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning(f"No claim found for {cache_id[:16]}...")
            return False

        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            completed_at=datetime.now(timezone.utc).isoformat() if status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ) else None,
            output_path=output_path,
            error=error,
        )

        # get_status() above already initialized the connection/scripts,
        # but access the property explicitly for the same reason as claim().
        client = self.redis
        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )

        if result == 1:
            logger.debug(f"Updated task {cache_id[:16]}... to {status.value}")
            return True
        else:
            logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)")
            return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful
        """
        # BUGFIX: same callee-before-arguments ordering issue as claim();
        # initialize the connection (and scripts) before dereferencing.
        client = self.redis
        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )

        if result == 1:
            logger.debug(f"Released claim on {cache_id[:16]}...")
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED, ClaimStatus.CACHED
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to reach a terminal state (completed/cached/failed).

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo if terminal, None if timeout
        """
        # Use the monotonic clock so wall-clock adjustments (NTP steps,
        # DST) cannot stretch or cut short the timeout.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ):
                return info
            time.sleep(poll_interval)

        logger.warning(f"Timeout waiting for {cache_id[:16]}...")
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

        # Plain SETEX (no ownership check): a cached result is
        # authoritative regardless of who holds a claim.
        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        # SCAN (not KEYS) to avoid blocking Redis on large keyspaces.
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0
|
|
|
|
|
|
# Process-wide singleton, created on first use by get_claimer()
_claimer: Optional[TaskClaimer] = None


def get_claimer() -> TaskClaimer:
    """Get the global TaskClaimer instance (lazily constructed)."""
    global _claimer
    if _claimer is not None:
        return _claimer
    _claimer = TaskClaimer()
    return _claimer
|
|
|
|
|
|
def claim_task(cache_id: str, worker_id: str, task_id: Optional[str] = None) -> bool:
    """Convenience function to claim a task via the global TaskClaimer."""
    return get_claimer().claim(cache_id, worker_id, task_id)
|
|
|
|
|
|
def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Mark a task as completed via the global TaskClaimer."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.COMPLETED,
        output_path=output_path,
    )
|
|
|
|
|
|
def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Mark a task as failed via the global TaskClaimer."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.FAILED,
        error=error,
    )
|