Files
rose-ash/artdag/l1/claiming.py
giles 1a74d811f7
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 2m33s
Incorporate art-dag-mono repo into artdag/ subfolder
Merges full history from art-dag/mono.git into the monorepo
under the artdag/ directory. Contains: core (DAG engine),
l1 (Celery rendering server), l2 (ActivityPub registry),
common (shared templates/middleware), client (CLI), test (e2e).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

git-subtree-dir: artdag
git-subtree-mainline: 1a179de547
git-subtree-split: 4c2e716558
2026-02-27 09:07:23 +00:00

422 lines
12 KiB
Python

"""
Hash-based task claiming for distributed execution.
Prevents duplicate work when multiple workers process the same plan.
Uses Redis Lua scripts for atomic claim operations.
"""
import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
import redis
logger = logging.getLogger(__name__)

# Redis connection URL; overridable via the REDIS_URL environment variable.
# Defaults to local Redis, logical database 5.
REDIS_URL = os.environ.get('REDIS_URL', 'redis://localhost:6379/5')

# Key prefix for task claims (full key = prefix + cache_id)
CLAIM_PREFIX = "artdag:claim:"

# Default TTL for claims (5 minutes) — a crashed worker's claim expires
# automatically so the task can be re-claimed.
DEFAULT_CLAIM_TTL = 300

# TTL for completed results (1 hour)
COMPLETED_TTL = 3600
class ClaimStatus(Enum):
    """Status of a task claim.

    The claiming Lua script treats CLAIMED, RUNNING, COMPLETED and CACHED
    as "taken" (claim attempts return 0); PENDING and FAILED are claimable.
    """
    PENDING = "pending"      # not written by this module's code; claims start at CLAIMED
    CLAIMED = "claimed"      # a worker holds the claim
    RUNNING = "running"      # worker is actively processing
    COMPLETED = "completed"  # finished successfully (output_path set)
    CACHED = "cached"        # result already existed; no processing was needed
    FAILED = "failed"        # processing failed (error set); may be re-claimed
@dataclass
class ClaimInfo:
    """Information about a task claim."""
    cache_id: str
    status: ClaimStatus
    worker_id: Optional[str] = None
    task_id: Optional[str] = None
    claimed_at: Optional[str] = None
    completed_at: Optional[str] = None
    output_path: Optional[str] = None
    error: Optional[str] = None

    # Optional fields serialized verbatim by to_dict / read back by from_dict.
    _OPTIONAL_FIELDS = (
        "worker_id",
        "task_id",
        "claimed_at",
        "completed_at",
        "output_path",
        "error",
    )

    def to_dict(self) -> dict:
        """Serialize to a plain dict, flattening the status enum to its value."""
        payload = {
            "cache_id": self.cache_id,
            "status": self.status.value,
        }
        for name in self._OPTIONAL_FIELDS:
            payload[name] = getattr(self, name)
        return payload

    @classmethod
    def from_dict(cls, data: dict) -> "ClaimInfo":
        """Rebuild a ClaimInfo from a dict produced by to_dict()."""
        extras = {name: data.get(name) for name in cls._OPTIONAL_FIELDS}
        return cls(
            cache_id=data["cache_id"],
            status=ClaimStatus(data["status"]),
            **extras,
        )
# Lua script for atomic task claiming.
# KEYS[1] = claim key; ARGV[1] = serialized ClaimInfo JSON; ARGV[2] = TTL (seconds).
# Returns 1 if claim successful, 0 if already claimed/completed.
# Executed server-side, so check-and-set is atomic: two workers can never
# both observe "unclaimed" and both claim the same task.
CLAIM_TASK_SCRIPT = """
local key = KEYS[1]
local data = redis.call('GET', key)
if data then
local status = cjson.decode(data)
local s = status['status']
-- Already claimed, running, completed, or cached - don't claim
if s == 'claimed' or s == 'running' or s == 'completed' or s == 'cached' then
return 0
end
end
-- Claim the task
local claim_data = ARGV[1]
local ttl = tonumber(ARGV[2])
redis.call('SETEX', key, ttl, claim_data)
return 1
"""

# Lua script for releasing a claim (e.g., on failure).
# KEYS[1] = claim key; ARGV[1] = worker_id.
# Deletes the claim only when the caller's worker_id matches the stored
# owner; returns 1 on release, 0 otherwise.
RELEASE_CLAIM_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local data = redis.call('GET', key)
if data then
local status = cjson.decode(data)
-- Only release if we own the claim
if status['worker_id'] == worker_id then
redis.call('DEL', key)
return 1
end
end
return 0
"""

# Lua script for updating claim status (claimed -> running -> completed).
# KEYS[1] = claim key; ARGV[1] = worker_id; ARGV[2] = new status value
# (unused by the script itself; ownership gates the write); ARGV[3] = new
# serialized ClaimInfo JSON; ARGV[4] = TTL (seconds).
# Returns 1 on success, 0 when the claim is missing or owned by another worker.
UPDATE_STATUS_SCRIPT = """
local key = KEYS[1]
local worker_id = ARGV[1]
local new_status = ARGV[2]
local new_data = ARGV[3]
local ttl = tonumber(ARGV[4])
local data = redis.call('GET', key)
if not data then
return 0
end
local status = cjson.decode(data)
-- Only update if we own the claim
if status['worker_id'] ~= worker_id then
return 0
end
redis.call('SETEX', key, ttl, new_data)
return 1
"""
class TaskClaimer:
    """
    Manages hash-based task claiming for distributed execution.

    Uses Redis for coordination between workers. Each task is identified
    by its cache_id (content-addressed), and claim transitions run through
    server-side Lua scripts so check-and-set is atomic across workers.
    """

    def __init__(self, redis_url: Optional[str] = None):
        """
        Initialize the claimer.

        Args:
            redis_url: Redis connection URL. Falls back to the module-level
                REDIS_URL default when not given.
        """
        self.redis_url = redis_url or REDIS_URL
        self._redis: Optional[redis.Redis] = None
        # Lua script handles; registered lazily, together with the
        # connection, on first access of the `redis` property.
        self._claim_script = None
        self._release_script = None
        self._update_script = None

    @property
    def redis(self) -> redis.Redis:
        """Get Redis connection (lazy initialization).

        Registering the Lua scripts happens here as a side effect, so any
        method that invokes a script attribute must read this property
        first to guarantee the scripts are non-None.
        """
        if self._redis is None:
            self._redis = redis.from_url(self.redis_url, decode_responses=True)
            # Register Lua scripts
            self._claim_script = self._redis.register_script(CLAIM_TASK_SCRIPT)
            self._release_script = self._redis.register_script(RELEASE_CLAIM_SCRIPT)
            self._update_script = self._redis.register_script(UPDATE_STATUS_SCRIPT)
        return self._redis

    def _key(self, cache_id: str) -> str:
        """Get Redis key for a cache_id."""
        return f"{CLAIM_PREFIX}{cache_id}"

    def claim(
        self,
        cache_id: str,
        worker_id: str,
        task_id: Optional[str] = None,
        ttl: int = DEFAULT_CLAIM_TTL,
    ) -> bool:
        """
        Attempt to claim a task.

        Args:
            cache_id: The cache ID of the task to claim
            worker_id: Identifier for the claiming worker
            task_id: Optional Celery task ID
            ttl: Time-to-live for the claim in seconds

        Returns:
            True if claim successful, False if already claimed
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CLAIMED,
            worker_id=worker_id,
            task_id=task_id,
            claimed_at=datetime.now(timezone.utc).isoformat(),
        )
        # BUG FIX: read the `redis` property *before* referencing
        # `self._claim_script`. Python evaluates the callee before its
        # arguments, so the previous `self._claim_script(..., client=self.redis)`
        # fetched a still-None script attribute on a fresh instance and the
        # first claim() raised "TypeError: 'NoneType' object is not callable".
        client = self.redis
        result = self._claim_script(
            keys=[self._key(cache_id)],
            args=[json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )
        if result == 1:
            logger.debug(f"Claimed task {cache_id[:16]}... for worker {worker_id}")
            return True
        else:
            logger.debug(f"Task {cache_id[:16]}... already claimed")
            return False

    def update_status(
        self,
        cache_id: str,
        worker_id: str,
        status: ClaimStatus,
        output_path: Optional[str] = None,
        error: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Update the status of a claimed task.

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim
            status: New status
            output_path: Path to output (for completed)
            error: Error message (for failed)
            ttl: New TTL (defaults based on status)

        Returns:
            True if update successful
        """
        if ttl is None:
            # Finished results live longer so other workers can reuse them.
            if status in (ClaimStatus.COMPLETED, ClaimStatus.CACHED):
                ttl = COMPLETED_TTL
            else:
                ttl = DEFAULT_CLAIM_TTL
        # Get existing claim info (also ensures the connection and Lua
        # scripts are initialized before the script call below).
        existing = self.get_status(cache_id)
        if not existing:
            logger.warning(f"No claim found for {cache_id[:16]}...")
            return False
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=status,
            worker_id=worker_id,
            task_id=existing.task_id,
            claimed_at=existing.claimed_at,
            # Stamp completion time only for terminal states.
            completed_at=datetime.now(timezone.utc).isoformat() if status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ) else None,
            output_path=output_path,
            error=error,
        )
        client = self.redis
        result = self._update_script(
            keys=[self._key(cache_id)],
            args=[worker_id, status.value, json.dumps(claim_info.to_dict()), ttl],
            client=client,
        )
        if result == 1:
            logger.debug(f"Updated task {cache_id[:16]}... to {status.value}")
            return True
        else:
            logger.warning(f"Failed to update task {cache_id[:16]}... (not owner?)")
            return False

    def release(self, cache_id: str, worker_id: str) -> bool:
        """
        Release a claim (e.g., on task failure before completion).

        Args:
            cache_id: The cache ID of the task
            worker_id: Worker ID that owns the claim

        Returns:
            True if release successful
        """
        # BUG FIX: same callee-evaluation-order issue as claim(); ensure
        # the scripts are registered before the attribute is fetched.
        client = self.redis
        result = self._release_script(
            keys=[self._key(cache_id)],
            args=[worker_id],
            client=client,
        )
        if result == 1:
            logger.debug(f"Released claim on {cache_id[:16]}...")
            return True
        return False

    def get_status(self, cache_id: str) -> Optional[ClaimInfo]:
        """
        Get the current status of a task.

        Args:
            cache_id: The cache ID of the task

        Returns:
            ClaimInfo if task has been claimed, None otherwise
        """
        data = self.redis.get(self._key(cache_id))
        if data:
            return ClaimInfo.from_dict(json.loads(data))
        return None

    def is_completed(self, cache_id: str) -> bool:
        """Check if a task is completed or cached."""
        info = self.get_status(cache_id)
        return info is not None and info.status in (
            ClaimStatus.COMPLETED, ClaimStatus.CACHED
        )

    def wait_for_completion(
        self,
        cache_id: str,
        timeout: float = 300,
        poll_interval: float = 0.5,
    ) -> Optional[ClaimInfo]:
        """
        Wait for a task to complete (polling).

        Args:
            cache_id: The cache ID of the task
            timeout: Maximum time to wait in seconds
            poll_interval: How often to check status

        Returns:
            ClaimInfo if the task reached a terminal state
            (completed/cached/failed), None on timeout
        """
        # time.monotonic() is immune to wall-clock adjustments (NTP steps,
        # DST), unlike time.time(), so the deadline cannot jump.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            info = self.get_status(cache_id)
            if info and info.status in (
                ClaimStatus.COMPLETED, ClaimStatus.CACHED, ClaimStatus.FAILED
            ):
                return info
            time.sleep(poll_interval)
        logger.warning(f"Timeout waiting for {cache_id[:16]}...")
        return None

    def mark_cached(self, cache_id: str, output_path: str) -> None:
        """
        Mark a task as already cached (no processing needed).

        This is used when we discover the result already exists
        before attempting to claim.

        Args:
            cache_id: The cache ID of the task
            output_path: Path to the cached output
        """
        claim_info = ClaimInfo(
            cache_id=cache_id,
            status=ClaimStatus.CACHED,
            output_path=output_path,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )
        # Plain SETEX is sufficient: a 'cached' entry blocks claims and
        # needs no ownership check.
        self.redis.setex(
            self._key(cache_id),
            COMPLETED_TTL,
            json.dumps(claim_info.to_dict()),
        )

    def clear_all(self) -> int:
        """
        Clear all claims (for testing/reset).

        Returns:
            Number of claims cleared
        """
        pattern = f"{CLAIM_PREFIX}*"
        # scan_iter avoids blocking Redis the way KEYS would on large keyspaces.
        keys = list(self.redis.scan_iter(match=pattern))
        if keys:
            return self.redis.delete(*keys)
        return 0
# Process-wide claimer singleton, created lazily by get_claimer().
_claimer: Optional[TaskClaimer] = None

def get_claimer() -> TaskClaimer:
    """Return the global TaskClaimer, constructing it on first use."""
    global _claimer
    if _claimer is not None:
        return _claimer
    _claimer = TaskClaimer()
    return _claimer
def claim_task(cache_id: str, worker_id: str, task_id: Optional[str] = None) -> bool:
    """Convenience function to claim a task via the global claimer."""
    return get_claimer().claim(cache_id, worker_id, task_id)
def complete_task(cache_id: str, worker_id: str, output_path: str) -> bool:
    """Convenience function to mark a task as completed."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.COMPLETED,
        output_path=output_path,
    )
def fail_task(cache_id: str, worker_id: str, error: str) -> bool:
    """Convenience function to mark a task as failed."""
    claimer = get_claimer()
    return claimer.update_status(
        cache_id,
        worker_id,
        ClaimStatus.FAILED,
        error=error,
    )