""" Celery scheduler for S-expression execution plans. Distributes plan steps to workers as S-expressions. The S-expression is the canonical format - workers receive serialized S-expressions and can verify cache_ids by hashing them. Usage: from artdag.sexp import compile_string, create_plan from artdag.sexp.scheduler import schedule_plan recipe = compile_string(sexp_content) plan = create_plan(recipe, inputs={'video': 'abc123...'}) result = schedule_plan(plan) """ import hashlib import json import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Callable from .parser import Symbol, Keyword, serialize, parse from .planner import ExecutionPlanSexp, PlanStep logger = logging.getLogger(__name__) @dataclass class StepResult: """Result from executing a step.""" step_id: str cache_id: str status: str # 'completed', 'cached', 'failed', 'pending' output_path: Optional[str] = None error: Optional[str] = None ipfs_cid: Optional[str] = None @dataclass class PlanResult: """Result from executing a complete plan.""" plan_id: str status: str # 'completed', 'failed', 'partial' steps_completed: int = 0 steps_cached: int = 0 steps_failed: int = 0 output_cache_id: Optional[str] = None output_path: Optional[str] = None output_ipfs_cid: Optional[str] = None step_results: Dict[str, StepResult] = field(default_factory=dict) error: Optional[str] = None def step_to_sexp(step: PlanStep) -> List: """ Convert a PlanStep to minimal S-expression for worker. This is the canonical form that workers receive. Workers can verify cache_id by hashing this S-expression. """ sexp = [Symbol(step.node_type.lower())] # Add config as keywords for key, value in step.config.items(): sexp.extend([Keyword(key.replace('_', '-')), value]) # Add inputs as cache IDs (not step IDs) if step.inputs: sexp.extend([Keyword("inputs"), step.inputs]) return sexp def step_sexp_to_string(step: PlanStep) -> str: """Serialize step to S-expression string for Celery task.""" return serialize(step_to_sexp(step)) def verify_step_cache_id(step_sexp: str, expected_cache_id: str, cluster_key: str = None) -> bool: """ Verify that a step's cache_id matches its S-expression. Workers should call this to verify they're executing the correct task. """ data = {"sexp": step_sexp} if cluster_key: data = {"_cluster_key": cluster_key, "_data": data} json_str = json.dumps(data, sort_keys=True, separators=(",", ":")) computed = hashlib.sha3_256(json_str.encode()).hexdigest() return computed == expected_cache_id class PlanScheduler: """ Schedules execution of S-expression plans on Celery workers. The scheduler: 1. Groups steps by dependency level 2. Checks cache for already-computed results 3. Dispatches uncached steps to workers 4. Waits for completion before proceeding to next level """ def __init__( self, cache_manager=None, celery_app=None, execute_task_name: str = 'tasks.execute_step_sexp', ): """ Initialize the scheduler. Args: cache_manager: L1 cache manager for checking cached results celery_app: Celery application instance execute_task_name: Name of the Celery task for step execution """ self.cache_manager = cache_manager self.celery_app = celery_app self.execute_task_name = execute_task_name def schedule( self, plan: ExecutionPlanSexp, timeout: int = 3600, ) -> PlanResult: """ Schedule and execute a plan. Args: plan: The execution plan (S-expression format) timeout: Timeout in seconds for the entire plan Returns: PlanResult with execution results """ from celery import group logger.info(f"Scheduling plan {plan.plan_id[:16]}... ({len(plan.steps)} steps)") # Build step lookup and group by level steps_by_id = {s.step_id: s for s in plan.steps} steps_by_level = self._group_by_level(plan.steps) max_level = max(steps_by_level.keys()) if steps_by_level else 0 # Track results result = PlanResult( plan_id=plan.plan_id, status="pending", ) # Map step_id -> cache_id for resolving inputs cache_ids = dict(plan.inputs) # Start with input hashes for step in plan.steps: cache_ids[step.step_id] = step.cache_id # Execute level by level for level in range(max_level + 1): level_steps = steps_by_level.get(level, []) if not level_steps: continue logger.info(f"Level {level}: {len(level_steps)} steps") # Check cache for each step steps_to_run = [] for step in level_steps: if self._is_cached(step.cache_id): result.steps_cached += 1 result.step_results[step.step_id] = StepResult( step_id=step.step_id, cache_id=step.cache_id, status="cached", output_path=self._get_cached_path(step.cache_id), ) else: steps_to_run.append(step) if not steps_to_run: logger.info(f"Level {level}: all {len(level_steps)} steps cached") continue # Dispatch uncached steps to workers logger.info(f"Level {level}: dispatching {len(steps_to_run)} steps") tasks = [] for step in steps_to_run: # Build task arguments step_sexp = step_sexp_to_string(step) input_cache_ids = { inp: cache_ids.get(inp, inp) for inp in step.inputs } task = self._get_execute_task().s( step_sexp=step_sexp, step_id=step.step_id, cache_id=step.cache_id, plan_id=plan.plan_id, input_cache_ids=input_cache_ids, ) tasks.append(task) # Execute in parallel job = group(tasks) async_result = job.apply_async() try: step_results = async_result.get(timeout=timeout) except Exception as e: logger.error(f"Level {level} failed: {e}") result.status = "failed" result.error = f"Level {level} failed: {e}" return result # Process results for step_result in step_results: step_id = step_result.get("step_id") status = step_result.get("status") result.step_results[step_id] = StepResult( step_id=step_id, cache_id=step_result.get("cache_id"), status=status, output_path=step_result.get("output_path"), error=step_result.get("error"), ipfs_cid=step_result.get("ipfs_cid"), ) if status in ("completed", "cached", "completed_by_other"): result.steps_completed += 1 elif status == "failed": result.steps_failed += 1 result.status = "failed" result.error = step_result.get("error") return result # Get final output output_step = steps_by_id.get(plan.output_step_id) if output_step: output_result = result.step_results.get(output_step.step_id) if output_result: result.output_cache_id = output_step.cache_id result.output_path = output_result.output_path result.output_ipfs_cid = output_result.ipfs_cid result.status = "completed" logger.info( f"Plan {plan.plan_id[:16]}... completed: " f"{result.steps_completed} executed, {result.steps_cached} cached" ) return result def _group_by_level(self, steps: List[PlanStep]) -> Dict[int, List[PlanStep]]: """Group steps by dependency level.""" by_level = {} for step in steps: by_level.setdefault(step.level, []).append(step) return by_level def _is_cached(self, cache_id: str) -> bool: """Check if a cache_id exists in cache.""" if self.cache_manager is None: return False path = self.cache_manager.get_by_cid(cache_id) return path is not None def _get_cached_path(self, cache_id: str) -> Optional[str]: """Get the path for a cached item.""" if self.cache_manager is None: return None path = self.cache_manager.get_by_cid(cache_id) return str(path) if path else None def _get_execute_task(self): """Get the Celery task for step execution.""" if self.celery_app is None: raise RuntimeError("Celery app not configured") return self.celery_app.tasks[self.execute_task_name] def create_scheduler(cache_manager=None, celery_app=None) -> PlanScheduler: """ Create a scheduler with the given dependencies. If not provided, attempts to import from art-celery. """ if celery_app is None: try: from celery_app import app as celery_app except ImportError: pass if cache_manager is None: try: from cache_manager import get_cache_manager cache_manager = get_cache_manager() except ImportError: pass return PlanScheduler( cache_manager=cache_manager, celery_app=celery_app, ) def schedule_plan( plan: ExecutionPlanSexp, cache_manager=None, celery_app=None, timeout: int = 3600, ) -> PlanResult: """ Convenience function to schedule a plan. Args: plan: The execution plan cache_manager: Optional cache manager celery_app: Optional Celery app timeout: Execution timeout Returns: PlanResult """ scheduler = create_scheduler(cache_manager, celery_app) return scheduler.schedule(plan, timeout=timeout) # Stage-aware scheduling @dataclass class StageResult: """Result from executing a stage.""" stage_name: str cache_id: str status: str # 'completed', 'cached', 'failed', 'pending' step_results: Dict[str, StepResult] = field(default_factory=dict) outputs: Dict[str, str] = field(default_factory=dict) # binding_name -> cache_id error: Optional[str] = None @dataclass class StagePlanResult: """Result from executing a plan with stages.""" plan_id: str status: str # 'completed', 'failed', 'partial' stages_completed: int = 0 stages_cached: int = 0 stages_failed: int = 0 steps_completed: int = 0 steps_cached: int = 0 steps_failed: int = 0 stage_results: Dict[str, StageResult] = field(default_factory=dict) output_cache_id: Optional[str] = None output_path: Optional[str] = None error: Optional[str] = None class StagePlanScheduler: """ Stage-aware scheduler for S-expression plans. The scheduler: 1. Groups stages by level (parallel groups) 2. For each stage level: - Check stage cache, skip entire stage if hit - Execute stage steps (grouped by level within stage) - Cache stage outputs 3. Stages at same level can run in parallel """ def __init__( self, cache_manager=None, stage_cache=None, celery_app=None, execute_task_name: str = 'tasks.execute_step_sexp', ): """ Initialize the stage-aware scheduler. Args: cache_manager: L1 cache manager for step-level caching stage_cache: StageCache instance for stage-level caching celery_app: Celery application instance execute_task_name: Name of the Celery task for step execution """ self.cache_manager = cache_manager self.stage_cache = stage_cache self.celery_app = celery_app self.execute_task_name = execute_task_name def schedule( self, plan: ExecutionPlanSexp, timeout: int = 3600, ) -> StagePlanResult: """ Schedule and execute a plan with stage awareness. If the plan has stages, uses stage-level scheduling. Otherwise, falls back to step-level scheduling. Args: plan: The execution plan (S-expression format) timeout: Timeout in seconds for the entire plan Returns: StagePlanResult with execution results """ # If no stages, use regular scheduling if not plan.stage_plans: logger.info("Plan has no stages, using step-level scheduling") regular_scheduler = PlanScheduler( cache_manager=self.cache_manager, celery_app=self.celery_app, execute_task_name=self.execute_task_name, ) step_result = regular_scheduler.schedule(plan, timeout) return StagePlanResult( plan_id=step_result.plan_id, status=step_result.status, steps_completed=step_result.steps_completed, steps_cached=step_result.steps_cached, steps_failed=step_result.steps_failed, output_cache_id=step_result.output_cache_id, output_path=step_result.output_path, error=step_result.error, ) logger.info( f"Scheduling plan {plan.plan_id[:16]}... " f"({len(plan.stage_plans)} stages, {len(plan.steps)} steps)" ) result = StagePlanResult( plan_id=plan.plan_id, status="pending", ) # Group stages by level stages_by_level = self._group_stages_by_level(plan.stage_plans) max_level = max(stages_by_level.keys()) if stages_by_level else 0 # Track stage outputs for data flow stage_outputs = {} # stage_name -> {binding_name -> cache_id} # Execute stage by stage level for level in range(max_level + 1): level_stages = stages_by_level.get(level, []) if not level_stages: continue logger.info(f"Stage level {level}: {len(level_stages)} stages") # Check stage cache for each stage stages_to_run = [] for stage_plan in level_stages: if self._is_stage_cached(stage_plan.cache_id): result.stages_cached += 1 cached_entry = self._load_cached_stage(stage_plan.cache_id) if cached_entry: stage_outputs[stage_plan.stage_name] = { name: out.cache_id for name, out in cached_entry.outputs.items() } result.stage_results[stage_plan.stage_name] = StageResult( stage_name=stage_plan.stage_name, cache_id=stage_plan.cache_id, status="cached", outputs=stage_outputs[stage_plan.stage_name], ) logger.info(f"Stage {stage_plan.stage_name}: cached") else: stages_to_run.append(stage_plan) if not stages_to_run: logger.info(f"Stage level {level}: all {len(level_stages)} stages cached") continue # Execute uncached stages # For now, execute sequentially; L1 Celery will add parallel execution for stage_plan in stages_to_run: logger.info(f"Executing stage: {stage_plan.stage_name}") stage_result = self._execute_stage( stage_plan, plan, stage_outputs, timeout, ) result.stage_results[stage_plan.stage_name] = stage_result if stage_result.status == "completed": result.stages_completed += 1 stage_outputs[stage_plan.stage_name] = stage_result.outputs # Cache the stage result self._cache_stage(stage_plan, stage_result) elif stage_result.status == "failed": result.stages_failed += 1 result.status = "failed" result.error = stage_result.error return result # Accumulate step counts for sr in stage_result.step_results.values(): if sr.status == "completed": result.steps_completed += 1 elif sr.status == "cached": result.steps_cached += 1 elif sr.status == "failed": result.steps_failed += 1 # Get final output if plan.stage_plans: last_stage = plan.stage_plans[-1] if last_stage.stage_name in result.stage_results: stage_res = result.stage_results[last_stage.stage_name] result.output_cache_id = last_stage.cache_id # Find the output step's path from step results for step_res in stage_res.step_results.values(): if step_res.output_path: result.output_path = step_res.output_path result.status = "completed" logger.info( f"Plan {plan.plan_id[:16]}... completed: " f"{result.stages_completed} stages executed, " f"{result.stages_cached} stages cached" ) return result def _group_stages_by_level(self, stage_plans: List) -> Dict[int, List]: """Group stage plans by their level.""" by_level = {} for stage_plan in stage_plans: by_level.setdefault(stage_plan.level, []).append(stage_plan) return by_level def _is_stage_cached(self, cache_id: str) -> bool: """Check if a stage is cached.""" if self.stage_cache is None: return False return self.stage_cache.has_stage(cache_id) def _load_cached_stage(self, cache_id: str): """Load a cached stage entry.""" if self.stage_cache is None: return None return self.stage_cache.load_stage(cache_id) def _cache_stage(self, stage_plan, stage_result: StageResult) -> None: """Cache a stage result.""" if self.stage_cache is None: return from .stage_cache import StageCacheEntry, StageOutput outputs = {} for name, cache_id in stage_result.outputs.items(): outputs[name] = StageOutput( cache_id=cache_id, output_type="artifact", ) entry = StageCacheEntry( stage_name=stage_plan.stage_name, cache_id=stage_plan.cache_id, outputs=outputs, ) self.stage_cache.save_stage(entry) def _execute_stage( self, stage_plan, plan: ExecutionPlanSexp, stage_outputs: Dict, timeout: int, ) -> StageResult: """ Execute a single stage. Uses the PlanScheduler to execute the stage's steps. """ # Create a mini-plan with just this stage's steps stage_steps = stage_plan.steps # Build step lookup steps_by_id = {s.step_id: s for s in stage_steps} steps_by_level = {} for step in stage_steps: steps_by_level.setdefault(step.level, []).append(step) max_level = max(steps_by_level.keys()) if steps_by_level else 0 # Track step results step_results = {} cache_ids = dict(plan.inputs) # Start with input hashes for step in plan.steps: cache_ids[step.step_id] = step.cache_id # Include outputs from previous stages for stage_name, outputs in stage_outputs.items(): for binding_name, binding_cache_id in outputs.items(): cache_ids[binding_name] = binding_cache_id # Execute steps level by level for level in range(max_level + 1): level_steps = steps_by_level.get(level, []) if not level_steps: continue # Check cache for each step steps_to_run = [] for step in level_steps: if self._is_step_cached(step.cache_id): step_results[step.step_id] = StepResult( step_id=step.step_id, cache_id=step.cache_id, status="cached", output_path=self._get_cached_path(step.cache_id), ) else: steps_to_run.append(step) if not steps_to_run: continue # Execute steps (for now, sequentially - L1 will add Celery dispatch) for step in steps_to_run: # In a full implementation, this would dispatch to Celery # For now, mark as pending step_results[step.step_id] = StepResult( step_id=step.step_id, cache_id=step.cache_id, status="pending", ) # If Celery is configured, dispatch the task if self.celery_app: try: task_result = self._dispatch_step(step, cache_ids, timeout) step_results[step.step_id] = StepResult( step_id=step.step_id, cache_id=step.cache_id, status=task_result.get("status", "completed"), output_path=task_result.get("output_path"), error=task_result.get("error"), ipfs_cid=task_result.get("ipfs_cid"), ) except Exception as e: step_results[step.step_id] = StepResult( step_id=step.step_id, cache_id=step.cache_id, status="failed", error=str(e), ) return StageResult( stage_name=stage_plan.stage_name, cache_id=stage_plan.cache_id, status="failed", step_results=step_results, error=str(e), ) # Build output bindings outputs = {} for out_name, node_id in stage_plan.output_bindings.items(): outputs[out_name] = cache_ids.get(node_id, node_id) return StageResult( stage_name=stage_plan.stage_name, cache_id=stage_plan.cache_id, status="completed", step_results=step_results, outputs=outputs, ) def _is_step_cached(self, cache_id: str) -> bool: """Check if a step is cached.""" if self.cache_manager is None: return False path = self.cache_manager.get_by_cid(cache_id) return path is not None def _get_cached_path(self, cache_id: str) -> Optional[str]: """Get the path for a cached step.""" if self.cache_manager is None: return None path = self.cache_manager.get_by_cid(cache_id) return str(path) if path else None def _dispatch_step(self, step, cache_ids: Dict, timeout: int) -> Dict: """Dispatch a step to Celery for execution.""" if self.celery_app is None: raise RuntimeError("Celery app not configured") task = self.celery_app.tasks[self.execute_task_name] step_sexp = step_sexp_to_string(step) input_cache_ids = { inp: cache_ids.get(inp, inp) for inp in step.inputs } async_result = task.apply_async( kwargs={ "step_sexp": step_sexp, "step_id": step.step_id, "cache_id": step.cache_id, "input_cache_ids": input_cache_ids, } ) return async_result.get(timeout=timeout) def create_stage_scheduler( cache_manager=None, stage_cache=None, celery_app=None, ) -> StagePlanScheduler: """ Create a stage-aware scheduler with the given dependencies. Args: cache_manager: L1 cache manager for step-level caching stage_cache: StageCache instance for stage-level caching celery_app: Celery application instance Returns: StagePlanScheduler """ if celery_app is None: try: from celery_app import app as celery_app except ImportError: pass if cache_manager is None: try: from cache_manager import get_cache_manager cache_manager = get_cache_manager() except ImportError: pass return StagePlanScheduler( cache_manager=cache_manager, stage_cache=stage_cache, celery_app=celery_app, ) def schedule_staged_plan( plan: ExecutionPlanSexp, cache_manager=None, stage_cache=None, celery_app=None, timeout: int = 3600, ) -> StagePlanResult: """ Convenience function to schedule a plan with stage awareness. Args: plan: The execution plan cache_manager: Optional step-level cache manager stage_cache: Optional stage-level cache celery_app: Optional Celery app timeout: Execution timeout Returns: StagePlanResult """ scheduler = create_stage_scheduler(cache_manager, stage_cache, celery_app) return scheduler.schedule(plan, timeout=timeout)