""" S-expression interpreter for streaming execution. Evaluates sexp expressions including: - let bindings - lambda definitions and calls - Arithmetic, comparison, logic operators - dict/list operations - Random number generation """ import random from typing import Any, Dict, List, Callable from dataclasses import dataclass @dataclass class Lambda: """Runtime lambda value.""" params: List[str] body: Any closure: Dict[str, Any] class Symbol: """Symbol reference.""" def __init__(self, name: str): self.name = name def __repr__(self): return f"Symbol({self.name})" class SexpInterpreter: """ Interprets S-expressions in real-time. Handles the full sexp language used in recipes. """ def __init__(self, rng: random.Random = None): self.rng = rng or random.Random() self.globals: Dict[str, Any] = {} def eval(self, expr: Any, env: Dict[str, Any] = None) -> Any: """Evaluate an expression in the given environment.""" if env is None: env = {} # Literals if isinstance(expr, (int, float, str, bool)) or expr is None: return expr # Symbol lookup if isinstance(expr, Symbol) or (hasattr(expr, 'name') and hasattr(expr, '__class__') and expr.__class__.__name__ == 'Symbol'): name = expr.name if hasattr(expr, 'name') else str(expr) if name in env: return env[name] if name in self.globals: return self.globals[name] raise NameError(f"Undefined symbol: {name}") # Compiled expression dict (from compiler) if isinstance(expr, dict): if expr.get('_expr'): return self._eval_compiled_expr(expr, env) # Plain dict - evaluate values that might be expressions result = {} for k, v in expr.items(): # Some keys should keep Symbol values as strings (effect names, modes) if k in ('effect', 'mode') and hasattr(v, 'name'): result[k] = v.name else: result[k] = self.eval(v, env) return result # List expression (sexp) if isinstance(expr, (list, tuple)) and len(expr) > 0: return self._eval_list(expr, env) # Empty list if isinstance(expr, (list, tuple)): return [] return expr def _eval_compiled_expr(self, expr: dict, env: Dict[str, Any]) -> Any: """Evaluate a compiled expression dict.""" op = expr.get('op') args = expr.get('args', []) if op == 'var': name = expr.get('name') if name in env: return env[name] if name in self.globals: return self.globals[name] raise NameError(f"Undefined: {name}") elif op == 'dict': keys = expr.get('keys', []) values = [self.eval(a, env) for a in args] return dict(zip(keys, values)) elif op == 'get': obj = self.eval(args[0], env) key = args[1] return obj.get(key) if isinstance(obj, dict) else obj[key] elif op == 'if': cond = self.eval(args[0], env) if cond: return self.eval(args[1], env) elif len(args) > 2: return self.eval(args[2], env) return None # Comparison elif op == '<': return self.eval(args[0], env) < self.eval(args[1], env) elif op == '>': return self.eval(args[0], env) > self.eval(args[1], env) elif op == '<=': return self.eval(args[0], env) <= self.eval(args[1], env) elif op == '>=': return self.eval(args[0], env) >= self.eval(args[1], env) elif op == '=': return self.eval(args[0], env) == self.eval(args[1], env) elif op == '!=': return self.eval(args[0], env) != self.eval(args[1], env) # Arithmetic elif op == '+': return self.eval(args[0], env) + self.eval(args[1], env) elif op == '-': return self.eval(args[0], env) - self.eval(args[1], env) elif op == '*': return self.eval(args[0], env) * self.eval(args[1], env) elif op == '/': return self.eval(args[0], env) / self.eval(args[1], env) elif op == 'mod': return self.eval(args[0], env) % self.eval(args[1], env) # Random elif op == 'rand': return self.rng.random() elif op == 'rand-int': return self.rng.randint(self.eval(args[0], env), self.eval(args[1], env)) elif op == 'rand-range': return self.rng.uniform(self.eval(args[0], env), self.eval(args[1], env)) # Logic elif op == 'and': return all(self.eval(a, env) for a in args) elif op == 'or': return any(self.eval(a, env) for a in args) elif op == 'not': return not self.eval(args[0], env) else: raise ValueError(f"Unknown op: {op}") def _eval_list(self, expr: list, env: Dict[str, Any]) -> Any: """Evaluate a list expression (sexp form).""" if len(expr) == 0: return [] head = expr[0] # Get head name if isinstance(head, Symbol) or (hasattr(head, 'name') and hasattr(head, '__class__')): head_name = head.name if hasattr(head, 'name') else str(head) elif isinstance(head, str): head_name = head else: # Not a symbol - check if it's a data list or function call if isinstance(head, dict): # List of dicts - evaluate each element as data return [self.eval(item, env) for item in expr] # Otherwise evaluate as function call fn = self.eval(head, env) args = [self.eval(a, env) for a in expr[1:]] return self._call(fn, args, env) # Special forms if head_name == 'let': return self._eval_let(expr, env) elif head_name in ('lambda', 'fn'): return self._eval_lambda(expr, env) elif head_name == 'if': return self._eval_if(expr, env) elif head_name == 'dict': return self._eval_dict(expr, env) elif head_name == 'get': obj = self.eval(expr[1], env) key = self.eval(expr[2], env) if len(expr) > 2 else expr[2] if isinstance(key, str): return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None) return obj[key] elif head_name == 'len': return len(self.eval(expr[1], env)) elif head_name == 'range': start = self.eval(expr[1], env) end = self.eval(expr[2], env) if len(expr) > 2 else start if len(expr) == 2: return list(range(end)) return list(range(start, end)) elif head_name == 'map': fn = self.eval(expr[1], env) lst = self.eval(expr[2], env) return [self._call(fn, [x], env) for x in lst] elif head_name == 'mod': return self.eval(expr[1], env) % self.eval(expr[2], env) # Arithmetic elif head_name == '+': return self.eval(expr[1], env) + self.eval(expr[2], env) elif head_name == '-': if len(expr) == 2: return -self.eval(expr[1], env) return self.eval(expr[1], env) - self.eval(expr[2], env) elif head_name == '*': return self.eval(expr[1], env) * self.eval(expr[2], env) elif head_name == '/': return self.eval(expr[1], env) / self.eval(expr[2], env) # Comparison elif head_name == '<': return self.eval(expr[1], env) < self.eval(expr[2], env) elif head_name == '>': return self.eval(expr[1], env) > self.eval(expr[2], env) elif head_name == '<=': return self.eval(expr[1], env) <= self.eval(expr[2], env) elif head_name == '>=': return self.eval(expr[1], env) >= self.eval(expr[2], env) elif head_name == '=': return self.eval(expr[1], env) == self.eval(expr[2], env) # Logic elif head_name == 'and': return all(self.eval(a, env) for a in expr[1:]) elif head_name == 'or': return any(self.eval(a, env) for a in expr[1:]) elif head_name == 'not': return not self.eval(expr[1], env) # Function call else: fn = env.get(head_name) or self.globals.get(head_name) if fn is None: raise NameError(f"Undefined function: {head_name}") args = [self.eval(a, env) for a in expr[1:]] return self._call(fn, args, env) def _eval_let(self, expr: list, env: Dict[str, Any]) -> Any: """Evaluate (let [bindings...] body).""" bindings = expr[1] body = expr[2] # Create new environment with bindings new_env = dict(env) # Process bindings in pairs i = 0 while i < len(bindings): name = bindings[i] if isinstance(name, Symbol) or hasattr(name, 'name'): name = name.name if hasattr(name, 'name') else str(name) value = self.eval(bindings[i + 1], new_env) new_env[name] = value i += 2 return self.eval(body, new_env) def _eval_lambda(self, expr: list, env: Dict[str, Any]) -> Lambda: """Evaluate (lambda [params] body).""" params_expr = expr[1] body = expr[2] # Extract parameter names params = [] for p in params_expr: if isinstance(p, Symbol) or hasattr(p, 'name'): params.append(p.name if hasattr(p, 'name') else str(p)) else: params.append(str(p)) return Lambda(params=params, body=body, closure=dict(env)) def _eval_if(self, expr: list, env: Dict[str, Any]) -> Any: """Evaluate (if cond then else).""" cond = self.eval(expr[1], env) if cond: return self.eval(expr[2], env) elif len(expr) > 3: return self.eval(expr[3], env) return None def _eval_dict(self, expr: list, env: Dict[str, Any]) -> dict: """Evaluate (dict :key val ...).""" result = {} i = 1 while i < len(expr): key = expr[i] # Handle keyword syntax (:key) and Keyword objects if hasattr(key, 'name'): key = key.name elif hasattr(key, '__class__') and key.__class__.__name__ == 'Keyword': key = str(key).lstrip(':') elif isinstance(key, str) and key.startswith(':'): key = key[1:] value = self.eval(expr[i + 1], env) result[key] = value i += 2 return result def _call(self, fn: Any, args: List[Any], env: Dict[str, Any]) -> Any: """Call a function with arguments.""" if isinstance(fn, Lambda): # Our own Lambda type call_env = dict(fn.closure) for param, arg in zip(fn.params, args): call_env[param] = arg return self.eval(fn.body, call_env) elif hasattr(fn, 'params') and hasattr(fn, 'body'): # Lambda from parser (artdag.sexp.parser.Lambda) call_env = dict(env) if hasattr(fn, 'closure') and fn.closure: call_env.update(fn.closure) # Get param names params = [] for p in fn.params: if hasattr(p, 'name'): params.append(p.name) else: params.append(str(p)) for param, arg in zip(params, args): call_env[param] = arg return self.eval(fn.body, call_env) elif callable(fn): return fn(*args) else: raise TypeError(f"Not callable: {type(fn).__name__}") def eval_slice_on_lambda(lambda_obj, acc: dict, i: int, start: float, end: float, videos: list, interp: SexpInterpreter = None) -> dict: """ Evaluate a SLICE_ON lambda function. Args: lambda_obj: The Lambda object from the compiled recipe acc: Current accumulator state i: Beat index start: Slice start time end: Slice end time videos: List of video inputs interp: Interpreter to use Returns: Dict with 'layers', 'compose', 'acc' keys """ if interp is None: interp = SexpInterpreter() # Set up global 'videos' for (len videos) to work interp.globals['videos'] = videos # Build initial environment with lambda parameters env = dict(lambda_obj.closure) if hasattr(lambda_obj, 'closure') and lambda_obj.closure else {} env['videos'] = videos # Call the lambda result = interp._call(lambda_obj, [acc, i, start, end], env) return result