from __future__ import annotations from functools import wraps from typing import Any, Dict, Iterable, Optional import inspect from quart import g, abort, redirect, request, current_app from shared.urls import login_url def require_rights(*rights: str, any_of: bool = True): """ Decorator for routes that require certain user rights. """ if not rights: raise ValueError("require_rights needs at least one right name") required_set = frozenset(rights) def decorator(view_func): @wraps(view_func) async def wrapper(*args: Any, **kwargs: Any): # Not logged in → go to login, with ?next= user = g.get("user") if not user: return redirect(login_url(request.url)) rights_dict = g.get("rights") or {} if any_of: allowed = any(rights_dict.get(name) for name in required_set) else: allowed = all(rights_dict.get(name) for name in required_set) if not allowed: abort(403) result = view_func(*args, **kwargs) if inspect.isawaitable(result): return await result return result # ---- expose access requirements on the wrapper ---- wrapper.__access_requires__ = { "rights": required_set, "any_of": any_of, } return wrapper return decorator def require_login(view_func): """ Decorator for routes that require any logged-in user. """ @wraps(view_func) async def wrapper(*args: Any, **kwargs: Any): user = g.get("user") if not user: return redirect(login_url(request.url)) result = view_func(*args, **kwargs) if inspect.isawaitable(result): return await result return result return wrapper def require_admin(view_func=None): """ Shortcut for routes that require the 'admin' right. """ if view_func is None: return require_rights("admin") return require_rights("admin")(view_func) def require_post_author(view_func): """Allow admin or post owner.""" @wraps(view_func) async def wrapper(*args, **kwargs): user = g.get("user") if not user: return redirect(login_url(request.url)) is_admin = bool((g.get("rights") or {}).get("admin")) if is_admin: result = view_func(*args, **kwargs) if inspect.isawaitable(result): return await result return result post = getattr(g, "post_data", {}).get("original_post") if post and post.user_id == user.id: result = view_func(*args, **kwargs) if inspect.isawaitable(result): return await result return result abort(403) return wrapper def _get_access_meta(view_func) -> Optional[Dict[str, Any]]: """ Walk the wrapper chain looking for __access_requires__ metadata. """ func = view_func seen: set[int] = set() while func is not None and id(func) not in seen: seen.add(id(func)) meta = getattr(func, "__access_requires__", None) if meta is not None: return meta func = getattr(func, "__wrapped__", None) return None def has_access(endpoint: str) -> bool: """ Return True if the current user has access to the given endpoint. Example: has_access("settings.home") has_access("settings.clear_cache_view") """ view = current_app.view_functions.get(endpoint) if view is None: # Unknown endpoint: be conservative return False meta = _get_access_meta(view) # If the route has no rights metadata, treat it as public: if meta is None: return True required: Iterable[str] = meta["rights"] any_of: bool = meta["any_of"] # Must be in a request context; if no user, they don't have access user = g.get("user") if not user: return False rights_dict = g.get("rights") or {} if any_of: return any(rights_dict.get(name) for name in required) else: return all(rights_dict.get(name) for name in required)