Files
mono/artdag/sexp/scheduler.py
giles cc2dcbddd4 Squashed 'core/' content from commit 4957443
git-subtree-dir: core
git-subtree-split: 4957443184ae0eb6323635a90a19acffb3e01d07
2026-02-24 23:09:39 +00:00

780 lines
26 KiB
Python

"""
Celery scheduler for S-expression execution plans.
Distributes plan steps to workers as S-expressions.
The S-expression is the canonical format - workers receive
serialized S-expressions and can verify cache_ids by hashing them.
Usage:
from artdag.sexp import compile_string, create_plan
from artdag.sexp.scheduler import schedule_plan
recipe = compile_string(sexp_content)
plan = create_plan(recipe, inputs={'video': 'abc123...'})
result = schedule_plan(plan)
"""
import hashlib
import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Callable
from .parser import Symbol, Keyword, serialize, parse
from .planner import ExecutionPlanSexp, PlanStep
logger = logging.getLogger(__name__)
@dataclass
class StepResult:
    """Result from executing a step."""
    step_id: str                        # plan-local identifier of the step
    cache_id: str                       # cache key for the step (derived from its S-expression; see verify_step_cache_id)
    status: str                         # 'completed', 'cached', 'failed', 'pending'
    output_path: Optional[str] = None   # local path of the produced artifact, when available
    error: Optional[str] = None         # error message when status == 'failed'
    ipfs_cid: Optional[str] = None      # IPFS CID of the output, when the worker published one
@dataclass
class PlanResult:
    """Result from executing a complete plan."""
    plan_id: str                                # identifier of the executed plan
    status: str                                 # 'completed', 'failed', 'partial'
    steps_completed: int = 0                    # steps executed by workers this run
    steps_cached: int = 0                       # steps skipped due to a local cache hit
    steps_failed: int = 0                       # steps that reported failure
    output_cache_id: Optional[str] = None       # cache_id of the plan's designated output step
    output_path: Optional[str] = None           # local path of the final output, if known
    output_ipfs_cid: Optional[str] = None       # IPFS CID of the final output, if known
    step_results: Dict[str, StepResult] = field(default_factory=dict)  # step_id -> StepResult
    error: Optional[str] = None                 # first error encountered, when status == 'failed'
def step_to_sexp(step: PlanStep) -> List:
    """
    Convert a PlanStep to the minimal S-expression sent to a worker.

    This is the canonical form that workers receive; a worker can verify
    the step's cache_id by hashing this S-expression.

    Args:
        step: The plan step to convert.

    Returns:
        A list of the form [Symbol, Keyword, value, ..., Keyword('inputs'), inputs].
    """
    parts: List = [Symbol(step.node_type.lower())]
    # Config entries become keyword/value pairs; underscores map to
    # Lisp-style hyphens in the keyword names.
    for key, value in step.config.items():
        parts.append(Keyword(key.replace('_', '-')))
        parts.append(value)
    # Inputs are attached as cache IDs (not step IDs).
    if step.inputs:
        parts.append(Keyword("inputs"))
        parts.append(step.inputs)
    return parts
def step_sexp_to_string(step: PlanStep) -> str:
    """Serialize a step's canonical S-expression to a string for the Celery task."""
    sexp = step_to_sexp(step)
    return serialize(sexp)
def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: Optional[str] = None) -> bool:
    """
    Verify that a step's cache_id matches its S-expression.

    Workers should call this to verify they're executing the correct task.
    The cache_id is the SHA3-256 hex digest of the canonical JSON encoding
    (sorted keys, compact separators) of {"sexp": <step_sexp>}; when a
    cluster_key is given, that dict is wrapped as
    {"_cluster_key": <key>, "_data": {...}} before hashing.

    Args:
        step_sexp: Serialized S-expression for the step.
        expected_cache_id: The cache_id the step was dispatched with.
        cluster_key: Optional cluster namespace key mixed into the hash.

    Returns:
        True if the recomputed digest equals expected_cache_id.
    """
    data: Dict[str, Any] = {"sexp": step_sexp}
    if cluster_key:
        data = {"_cluster_key": cluster_key, "_data": data}
    # Canonical JSON: deterministic key order and no whitespace, so the
    # digest is stable across processes.
    json_str = json.dumps(data, sort_keys=True, separators=(",", ":"))
    computed = hashlib.sha3_256(json_str.encode()).hexdigest()
    return computed == expected_cache_id
class PlanScheduler:
    """
    Schedules execution of S-expression plans on Celery workers.

    The scheduler:
    1. Groups steps by dependency level
    2. Checks cache for already-computed results
    3. Dispatches uncached steps to workers
    4. Waits for completion before proceeding to next level
    """
    def __init__(
        self,
        cache_manager=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the scheduler.

        Args:
            cache_manager: L1 cache manager for checking cached results
                (must expose get_by_cid(cache_id) -> path or None)
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name
    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> PlanResult:
        """
        Schedule and execute a plan.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds applied to EACH level's Celery
                group wait (not the plan as a whole), so total wall time
                can exceed this value on multi-level plans.

        Returns:
            PlanResult with execution results
        """
        from celery import group  # local import: keeps Celery optional at module load
        logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)")
        # Build step lookup and group by level
        steps_by_id = {s.step_id: s for s in plan.steps}
        steps_by_level = self._group_by_level(plan.steps)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0
        # Track results
        result = PlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )
        # Map step_id -> cache_id for resolving inputs
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id
        # Execute level by level: steps within one level are independent,
        # so they are dispatched together as a single parallel group.
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue
            logger.info(f"Level {level}: {len(level_steps)} steps")
            # Check cache for each step; cache hits are recorded
            # immediately and never dispatched.
            steps_to_run = []
            for step in level_steps:
                if self._is_cached(step.cache_id):
                    result.steps_cached += 1
                    result.step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)
            if not steps_to_run:
                logger.info(f"Level {level}: all {len(level_steps)} steps cached")
                continue
            # Dispatch uncached steps to workers
            logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps")
            tasks = []
            for step in steps_to_run:
                # Build task arguments: worker receives the canonical
                # S-expression plus the resolved input cache IDs.
                step_sexp = step_sexp_to_string(step)
                input_cache_ids = {
                    inp: cache_ids.get(inp, inp)  # fall back to the raw id when unresolved
                    for inp in step.inputs
                }
                task = self._get_execute_task().s(
                    step_sexp=step_sexp,
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    plan_id=plan.plan_id,
                    input_cache_ids=input_cache_ids,
                )
                tasks.append(task)
            # Execute in parallel and block until the whole level finishes
            job = group(tasks)
            async_result = job.apply_async()
            try:
                step_results = async_result.get(timeout=timeout)
            except Exception as e:
                logger.error(f"Level {level} failed: {e}")
                result.status = "failed"
                result.error = f"Level {level} failed: {e}"
                return result
            # Process results (dicts returned by the worker task)
            for step_result in step_results:
                step_id = step_result.get("step_id")
                status = step_result.get("status")
                result.step_results[step_id] = StepResult(
                    step_id=step_id,
                    cache_id=step_result.get("cache_id"),
                    status=status,
                    output_path=step_result.get("output_path"),
                    error=step_result.get("error"),
                    ipfs_cid=step_result.get("ipfs_cid"),
                )
                if status in ("completed", "cached", "completed_by_other"):
                    result.steps_completed += 1
                elif status == "failed":
                    # Fail fast: abort on the first failed step.
                    # NOTE(review): remaining results from this level are
                    # not recorded into result.step_results.
                    result.steps_failed += 1
                    result.status = "failed"
                    result.error = step_result.get("error")
                    return result
        # Get final output from the plan's designated output step
        output_step = steps_by_id.get(plan.output_step_id)
        if output_step:
            output_result = result.step_results.get(output_step.step_id)
            if output_result:
                result.output_cache_id = output_step.cache_id
                result.output_path = output_result.output_path
                result.output_ipfs_cid = output_result.ipfs_cid
        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.steps_completed} executed, {result.steps_cached} cached"
        )
        return result
    def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]:
        """Group steps by dependency level."""
        by_level: Dict[int, List[PlanStep]] = {}
        for step in steps:
            by_level.setdefault(step.level, []).append(step)
        return by_level
    def _is_cached(self, cache_id: str) -> bool:
        """Check if a cache_id exists in cache (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None
    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached item, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None
    def _get_execute_task(self):
        """Get the Celery task object used for step execution.

        Raises:
            RuntimeError: If no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        return self.celery_app.tasks[self.execute_task_name]
def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler:
    """
    Build a PlanScheduler, filling in missing dependencies.

    When celery_app or cache_manager is not supplied, a best-effort import
    from the art-celery environment is attempted; import failures are
    ignored and the scheduler is created without that dependency.
    """
    if celery_app is None:
        try:
            from celery_app import app as default_app
        except ImportError:
            default_app = None
        celery_app = default_app
    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            cache_manager = None
    return PlanScheduler(cache_manager=cache_manager, celery_app=celery_app)
def schedule_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    celery_app=None,
    timeout: int = 3600,
) -> PlanResult:
    """
    Convenience wrapper: build a scheduler and execute *plan* on it.

    Args:
        plan: The execution plan
        cache_manager: Optional cache manager
        celery_app: Optional Celery app
        timeout: Execution timeout

    Returns:
        PlanResult
    """
    return create_scheduler(cache_manager, celery_app).schedule(plan, timeout=timeout)
# Stage-aware scheduling
@dataclass
class StageResult:
    """Result from executing a stage."""
    stage_name: str                     # name of the stage within the plan
    cache_id: str                       # stage-level cache key
    status: str                         # 'completed', 'cached', 'failed', 'pending'
    step_results: Dict[str, StepResult] = field(default_factory=dict)  # step_id -> StepResult
    outputs: Dict[str, str] = field(default_factory=dict)  # binding_name -> cache_id
    error: Optional[str] = None         # error message when status == 'failed'
@dataclass
class StagePlanResult:
    """Result from executing a plan with stages."""
    plan_id: str                            # identifier of the executed plan
    status: str                             # 'completed', 'failed', 'partial'
    stages_completed: int = 0               # stages executed this run
    stages_cached: int = 0                  # stages skipped via the stage cache
    stages_failed: int = 0                  # stages that reported failure
    steps_completed: int = 0                # step totals, aggregated across stages
    steps_cached: int = 0
    steps_failed: int = 0
    stage_results: Dict[str, StageResult] = field(default_factory=dict)  # stage_name -> StageResult
    output_cache_id: Optional[str] = None   # cache_id of the final stage's output
    output_path: Optional[str] = None       # local path of the final output, if known
    error: Optional[str] = None             # first error encountered, when status == 'failed'
class StagePlanScheduler:
    """
    Stage-aware scheduler for S-expression plans.

    The scheduler:
    1. Groups stages by level (parallel groups)
    2. For each stage level:
       - Check stage cache, skip entire stage if hit
       - Execute stage steps (grouped by level within stage)
       - Cache stage outputs
    3. Stages at same level can run in parallel
    """
    def __init__(
        self,
        cache_manager=None,
        stage_cache=None,
        celery_app=None,
        execute_task_name: str = 'tasks.execute_step_sexp',
    ):
        """
        Initialize the stage-aware scheduler.

        Args:
            cache_manager: L1 cache manager for step-level caching
                (must expose get_by_cid(cache_id) -> path or None)
            stage_cache: StageCache instance for stage-level caching
            celery_app: Celery application instance
            execute_task_name: Name of the Celery task for step execution
        """
        self.cache_manager = cache_manager
        self.stage_cache = stage_cache
        self.celery_app = celery_app
        self.execute_task_name = execute_task_name
    def schedule(
        self,
        plan: ExecutionPlanSexp,
        timeout: int = 3600,
    ) -> StagePlanResult:
        """
        Schedule and execute a plan with stage awareness.

        If the plan has stages, uses stage-level scheduling.
        Otherwise, falls back to step-level scheduling.

        Args:
            plan: The execution plan (S-expression format)
            timeout: Timeout in seconds, applied per dispatch wait
                (per level in the fallback path, per step here)

        Returns:
            StagePlanResult with execution results
        """
        # No stage metadata: delegate to the plain step scheduler and adapt
        # its PlanResult into a StagePlanResult.
        if not plan.stage_plans:
            logger.info("Plan has no stages, using step-level scheduling")
            regular_scheduler = PlanScheduler(
                cache_manager=self.cache_manager,
                celery_app=self.celery_app,
                execute_task_name=self.execute_task_name,
            )
            step_result = regular_scheduler.schedule(plan, timeout)
            return StagePlanResult(
                plan_id=step_result.plan_id,
                status=step_result.status,
                steps_completed=step_result.steps_completed,
                steps_cached=step_result.steps_cached,
                steps_failed=step_result.steps_failed,
                output_cache_id=step_result.output_cache_id,
                output_path=step_result.output_path,
                error=step_result.error,
            )
        logger.info(
            f"Scheduling plan {plan.plan_id[:16]}... "
            f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)"
        )
        result = StagePlanResult(
            plan_id=plan.plan_id,
            status="pending",
        )
        # Group stages by level
        stages_by_level = self._group_stages_by_level(plan.stage_plans)
        max_level = max(stages_by_level.keys()) if stages_by_level else 0
        # Track stage outputs for data flow between stages
        stage_outputs = {}  # stage_name -> {binding_name -> cache_id}
        # Execute stage by stage level
        for level in range(max_level + 1):
            level_stages = stages_by_level.get(level, [])
            if not level_stages:
                continue
            logger.info(f"Stage level {level}: {len(level_stages)} stages")
            # Check stage cache for each stage
            stages_to_run = []
            for stage_plan in level_stages:
                if self._is_stage_cached(stage_plan.cache_id):
                    result.stages_cached += 1
                    cached_entry = self._load_cached_stage(stage_plan.cache_id)
                    if cached_entry:
                        stage_outputs[stage_plan.stage_name] = {
                            name: out.cache_id
                            for name, out in cached_entry.outputs.items()
                        }
                    # BUGFIX: use .get() — has_stage() may report a hit while
                    # load_stage() returns nothing (e.g. a racing eviction),
                    # in which case indexing stage_outputs raised KeyError.
                    result.stage_results[stage_plan.stage_name] = StageResult(
                        stage_name=stage_plan.stage_name,
                        cache_id=stage_plan.cache_id,
                        status="cached",
                        outputs=stage_outputs.get(stage_plan.stage_name, {}),
                    )
                    logger.info(f"Stage {stage_plan.stage_name}: cached")
                else:
                    stages_to_run.append(stage_plan)
            if not stages_to_run:
                logger.info(f"Stage level {level}: all {len(level_stages)} stages cached")
                continue
            # Execute uncached stages
            # For now, execute sequentially; L1 Celery will add parallel execution
            for stage_plan in stages_to_run:
                logger.info(f"Executing stage: {stage_plan.stage_name}")
                stage_result = self._execute_stage(
                    stage_plan,
                    plan,
                    stage_outputs,
                    timeout,
                )
                result.stage_results[stage_plan.stage_name] = stage_result
                if stage_result.status == "completed":
                    result.stages_completed += 1
                    stage_outputs[stage_plan.stage_name] = stage_result.outputs
                    # Cache the stage result for future plans
                    self._cache_stage(stage_plan, stage_result)
                elif stage_result.status == "failed":
                    # Fail fast: abort the plan on the first failed stage
                    result.stages_failed += 1
                    result.status = "failed"
                    result.error = stage_result.error
                    return result
                # Accumulate step counts into the plan-level totals
                for sr in stage_result.step_results.values():
                    if sr.status == "completed":
                        result.steps_completed += 1
                    elif sr.status == "cached":
                        result.steps_cached += 1
                    elif sr.status == "failed":
                        result.steps_failed += 1
        # Get final output from the last stage in the plan
        if plan.stage_plans:
            last_stage = plan.stage_plans[-1]
            if last_stage.stage_name in result.stage_results:
                stage_res = result.stage_results[last_stage.stage_name]
                result.output_cache_id = last_stage.cache_id
                # Find the output step's path from step results
                # NOTE(review): takes whichever step has an output_path
                # (the last one wins) — verify this matches the intended
                # output step when stages have multiple outputs.
                for step_res in stage_res.step_results.values():
                    if step_res.output_path:
                        result.output_path = step_res.output_path
        result.status = "completed"
        logger.info(
            f"Plan {plan.plan_id[:16]}... completed: "
            f"{result.stages_completed} stages executed, "
            f"{result.stages_cached} stages cached"
        )
        return result
    def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]:
        """Group stage plans by their level."""
        by_level: Dict[int, List] = {}
        for stage_plan in stage_plans:
            by_level.setdefault(stage_plan.level, []).append(stage_plan)
        return by_level
    def _is_stage_cached(self, cache_id: str) -> bool:
        """Check if a stage is cached (False when no stage cache is configured)."""
        if self.stage_cache is None:
            return False
        return self.stage_cache.has_stage(cache_id)
    def _load_cached_stage(self, cache_id: str):
        """Load a cached stage entry, or None when unavailable."""
        if self.stage_cache is None:
            return None
        return self.stage_cache.load_stage(cache_id)
    def _cache_stage(self, stage_plan, stage_result: StageResult) -> None:
        """Persist a completed stage's outputs to the stage cache (no-op without one)."""
        if self.stage_cache is None:
            return
        from .stage_cache import StageCacheEntry, StageOutput
        outputs = {}
        for name, cache_id in stage_result.outputs.items():
            outputs[name] = StageOutput(
                cache_id=cache_id,
                output_type="artifact",
            )
        entry = StageCacheEntry(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            outputs=outputs,
        )
        self.stage_cache.save_stage(entry)
    def _execute_stage(
        self,
        stage_plan,
        plan: ExecutionPlanSexp,
        stage_outputs: Dict,
        timeout: int,
    ) -> StageResult:
        """
        Execute a single stage.

        Steps are processed level by level within the stage: cached steps
        are recorded as 'cached'; uncached steps are dispatched to Celery
        one at a time (sequential for now), or left as 'pending' when no
        Celery app is configured.

        Args:
            stage_plan: The stage plan (steps, output_bindings, cache_id)
            plan: The full execution plan (source of input hashes/cache ids)
            stage_outputs: stage_name -> {binding_name -> cache_id} from
                previously completed stages
            timeout: Per-step wait timeout in seconds

        Returns:
            StageResult with status 'completed', or 'failed' on the first
            step dispatch error.
        """
        # Group this stage's steps by their dependency level
        stage_steps = stage_plan.steps
        steps_by_level = {}
        for step in stage_steps:
            steps_by_level.setdefault(step.level, []).append(step)
        max_level = max(steps_by_level.keys()) if steps_by_level else 0
        # Track step results
        step_results = {}
        # id -> cache_id map used to resolve each step's inputs
        cache_ids = dict(plan.inputs)  # Start with input hashes
        for step in plan.steps:
            cache_ids[step.step_id] = step.cache_id
        # Include outputs from previous stages so bindings resolve
        for stage_name, outputs in stage_outputs.items():
            for binding_name, binding_cache_id in outputs.items():
                cache_ids[binding_name] = binding_cache_id
        # Execute steps level by level
        for level in range(max_level + 1):
            level_steps = steps_by_level.get(level, [])
            if not level_steps:
                continue
            # Check cache for each step
            steps_to_run = []
            for step in level_steps:
                if self._is_step_cached(step.cache_id):
                    step_results[step.step_id] = StepResult(
                        step_id=step.step_id,
                        cache_id=step.cache_id,
                        status="cached",
                        output_path=self._get_cached_path(step.cache_id),
                    )
                else:
                    steps_to_run.append(step)
            if not steps_to_run:
                continue
            # Execute steps (for now, sequentially - L1 will add Celery dispatch)
            for step in steps_to_run:
                # Default to 'pending'; overwritten below when Celery runs it
                step_results[step.step_id] = StepResult(
                    step_id=step.step_id,
                    cache_id=step.cache_id,
                    status="pending",
                )
                # If Celery is configured, dispatch the task
                if self.celery_app:
                    try:
                        task_result = self._dispatch_step(step, cache_ids, timeout)
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status=task_result.get("status", "completed"),
                            output_path=task_result.get("output_path"),
                            error=task_result.get("error"),
                            ipfs_cid=task_result.get("ipfs_cid"),
                        )
                    except Exception as e:
                        # Fail the whole stage on the first step error
                        step_results[step.step_id] = StepResult(
                            step_id=step.step_id,
                            cache_id=step.cache_id,
                            status="failed",
                            error=str(e),
                        )
                        return StageResult(
                            stage_name=stage_plan.stage_name,
                            cache_id=stage_plan.cache_id,
                            status="failed",
                            step_results=step_results,
                            error=str(e),
                        )
        # Build output bindings: binding name -> cache_id (falls back to the
        # raw node id when it was never resolved)
        outputs = {}
        for out_name, node_id in stage_plan.output_bindings.items():
            outputs[out_name] = cache_ids.get(node_id, node_id)
        return StageResult(
            stage_name=stage_plan.stage_name,
            cache_id=stage_plan.cache_id,
            status="completed",
            step_results=step_results,
            outputs=outputs,
        )
    def _is_step_cached(self, cache_id: str) -> bool:
        """Check if a step is cached (False when no cache manager)."""
        if self.cache_manager is None:
            return False
        path = self.cache_manager.get_by_cid(cache_id)
        return path is not None
    def _get_cached_path(self, cache_id: str) -> Optional[str]:
        """Get the filesystem path for a cached step, or None."""
        if self.cache_manager is None:
            return None
        path = self.cache_manager.get_by_cid(cache_id)
        return str(path) if path else None
    def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict:
        """Dispatch a single step to Celery and block for its result dict.

        Raises:
            RuntimeError: If no Celery app was configured.
        """
        if self.celery_app is None:
            raise RuntimeError("Celery app not configured")
        task = self.celery_app.tasks[self.execute_task_name]
        step_sexp = step_sexp_to_string(step)
        input_cache_ids = {
            inp: cache_ids.get(inp, inp)
            for inp in step.inputs
        }
        async_result = task.apply_async(
            kwargs={
                "step_sexp": step_sexp,
                "step_id": step.step_id,
                "cache_id": step.cache_id,
                "input_cache_ids": input_cache_ids,
            }
        )
        return async_result.get(timeout=timeout)
def create_stage_scheduler(
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
) -> StagePlanScheduler:
    """
    Build a StagePlanScheduler, filling in missing dependencies.

    When celery_app or cache_manager is not supplied, a best-effort import
    from the art-celery environment is attempted; import failures are
    ignored and the scheduler is created without that dependency.

    Args:
        cache_manager: L1 cache manager for step-level caching
        stage_cache: StageCache instance for stage-level caching
        celery_app: Celery application instance

    Returns:
        StagePlanScheduler
    """
    if celery_app is None:
        try:
            from celery_app import app as default_app
        except ImportError:
            default_app = None
        celery_app = default_app
    if cache_manager is None:
        try:
            from cache_manager import get_cache_manager
            cache_manager = get_cache_manager()
        except ImportError:
            cache_manager = None
    return StagePlanScheduler(
        cache_manager=cache_manager,
        stage_cache=stage_cache,
        celery_app=celery_app,
    )
def schedule_staged_plan(
    plan: ExecutionPlanSexp,
    cache_manager=None,
    stage_cache=None,
    celery_app=None,
    timeout: int = 3600,
) -> StagePlanResult:
    """
    Convenience wrapper: build a stage-aware scheduler and execute *plan*.

    Args:
        plan: The execution plan
        cache_manager: Optional step-level cache manager
        stage_cache: Optional stage-level cache
        celery_app: Optional Celery app
        timeout: Execution timeout

    Returns:
        StagePlanResult
    """
    stage_scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app)
    return stage_scheduler.schedule(plan, timeout=timeout)