"""
Database module for Art DAG L2 Server.

Uses asyncpg for async PostgreSQL access with connection pooling.
"""

import json
import os
from datetime import datetime, timezone
from typing import Optional
from contextlib import asynccontextmanager
from uuid import UUID

import asyncpg


# Connection pool (initialized on startup by init_pool(), cleared by close_pool())
_pool: Optional[asyncpg.Pool] = None

# Configuration from environment
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql://artdag:artdag@localhost:5432/artdag"
)


def _parse_timestamp(ts) -> datetime:
    """Coerce a timestamp value into a timezone-aware datetime.

    Accepts ``None`` (returns the current UTC time), a ``datetime``, or an
    ISO 8601 string (a trailing ``Z`` is accepted as UTC). Any other type
    falls back to the current UTC time.

    Naive values (no tzinfo / no offset in the string) are assumed to be UTC.
    asyncpg encodes TIMESTAMPTZ parameters via ``astimezone()``, which would
    otherwise interpret a naive datetime in the server process's local zone.

    Raises:
        ValueError: if ``ts`` is a string that is not valid ISO 8601.
    """
    if ts is None:
        return datetime.now(timezone.utc)
    if isinstance(ts, str):
        # datetime.fromisoformat() only accepts a "Z" suffix from Python 3.11.
        if ts.endswith('Z'):
            ts = ts[:-1] + '+00:00'
        ts = datetime.fromisoformat(ts)
    if not isinstance(ts, datetime):
        return datetime.now(timezone.utc)
    if ts.tzinfo is None:
        ts = ts.replace(tzinfo=timezone.utc)
    return ts


# Schema for database initialization
SCHEMA = """
-- Users table
CREATE TABLE IF NOT EXISTS users (
    username VARCHAR(255) PRIMARY KEY,
    password_hash VARCHAR(255) NOT NULL,
    email VARCHAR(255),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Assets table
CREATE TABLE IF NOT EXISTS assets (
    name VARCHAR(255) PRIMARY KEY,
    content_hash VARCHAR(128) NOT NULL,
    ipfs_cid VARCHAR(128),
    asset_type VARCHAR(50) NOT NULL,
    tags JSONB DEFAULT '[]'::jsonb,
    metadata JSONB DEFAULT '{}'::jsonb,
    url TEXT,
    provenance JSONB,
    description TEXT,
    origin JSONB,
    owner VARCHAR(255) NOT NULL REFERENCES users(username),
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ
);

-- Activities table
CREATE TABLE IF NOT EXISTS activities (
    activity_id UUID PRIMARY KEY,
    activity_type VARCHAR(50) NOT NULL,
    actor_id TEXT NOT NULL,
    object_data JSONB NOT NULL,
    published TIMESTAMPTZ NOT NULL,
    signature JSONB
);

-- Followers table
CREATE TABLE IF NOT EXISTS followers (
    id SERIAL PRIMARY KEY,
    username VARCHAR(255) NOT NULL REFERENCES users(username),
    acct VARCHAR(255) NOT NULL,
    url TEXT NOT NULL,
    public_key TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    UNIQUE(username, acct)
);

-- Indexes
CREATE INDEX IF NOT EXISTS idx_users_created_at ON users(created_at);
CREATE INDEX IF NOT EXISTS idx_assets_content_hash ON assets(content_hash);
CREATE INDEX IF NOT EXISTS idx_assets_owner ON assets(owner);
CREATE INDEX IF NOT EXISTS idx_assets_created_at ON assets(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_assets_tags ON assets USING GIN(tags);
CREATE INDEX IF NOT EXISTS idx_activities_actor_id ON activities(actor_id);
CREATE INDEX IF NOT EXISTS idx_activities_published ON activities(published DESC);
CREATE INDEX IF NOT EXISTS idx_followers_username ON followers(username);
"""


async def init_pool():
    """Initialize the connection pool and create tables. Call on app startup.

    Calling this again while a pool is already open is a no-op (the original
    pool is kept rather than leaked). If schema creation fails, the newly
    created pool is closed and ``_pool`` stays ``None`` before the error
    propagates, so a later retry starts from a clean state.
    """
    global _pool
    if _pool is not None:
        return
    pool = await asyncpg.create_pool(
        DATABASE_URL,
        min_size=2,
        max_size=10,
        command_timeout=60,
    )
    try:
        # Create tables if they don't exist
        async with pool.acquire() as conn:
            await conn.execute(SCHEMA)
    except BaseException:
        # BaseException so a cancelled startup also releases the pool.
        await pool.close()
        raise
    _pool = pool


async def close_pool():
    """Close the connection pool (if open). Call on app shutdown."""
    global _pool
    if _pool:
        await _pool.close()
        _pool = None


def get_pool() -> asyncpg.Pool:
    """Return the connection pool.

    Raises:
        RuntimeError: if init_pool() has not been called (or close_pool() has).
    """
    if _pool is None:
        raise RuntimeError("Database pool not initialized")
    return _pool


@asynccontextmanager
async def get_connection():
    """Acquire a connection from the pool for the duration of the context."""
    async with get_pool().acquire() as conn:
        yield conn


# ============ Users ============

async def get_user(username: str) -> Optional[dict]:
    """Return the user row for ``username`` as a dict, or None if absent."""
    async with get_connection() as conn:
        row = await conn.fetchrow(
            "SELECT username, password_hash, email, created_at FROM users WHERE username = $1",
            username
        )
        if row:
            return dict(row)
        return None


async def get_all_users() -> dict[str, dict]:
    """Return all users as a dict keyed by username (ordered by username)."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            "SELECT username, password_hash, email, created_at FROM users ORDER BY username"
        )
        return {row["username"]: dict(row) for row in rows}


async def create_user(username: str, password_hash: str, email: Optional[str] = None) -> dict:
    """Insert a new user and return the stored row as a dict.

    Raises:
        asyncpg.UniqueViolationError: if ``username`` already exists.
    """
    async with get_connection() as conn:
        row = await conn.fetchrow(
            """INSERT INTO users (username, password_hash, email)
               VALUES ($1, $2, $3)
               RETURNING username, password_hash, email, created_at""",
            username, password_hash, email
        )
        return dict(row)


async def user_exists(username: str) -> bool:
    """Return True if a user named ``username`` exists."""
    async with get_connection() as conn:
        result = await conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)",
            username
        )
        return result


# ============ Assets ============

# Columns returned by every asset read. Listed in table order so read results
# have the same keys as the ``RETURNING *`` rows of create_asset/update_asset.
_ASSET_COLUMNS = (
    "name, content_hash, ipfs_cid, asset_type, tags, metadata, url, "
    "provenance, description, origin, owner, created_at, updated_at"
)

# Asset columns stored as JSONB; serialized with json.dumps on write.
_JSON_ASSET_FIELDS = ("tags", "metadata", "provenance", "origin")

# Columns update_asset() may modify. Keys of ``updates`` are interpolated into
# the SQL text, so they must come from this allow-list. ``updated_at`` is
# excluded because update_asset() always sets it itself.
_UPDATABLE_ASSET_COLUMNS = frozenset({
    "name", "content_hash", "ipfs_cid", "asset_type", "tags", "metadata",
    "url", "provenance", "description", "origin", "owner", "created_at",
})


async def get_asset(name: str) -> Optional[dict]:
    """Return the asset named ``name`` (see _parse_asset_row), or None."""
    async with get_connection() as conn:
        row = await conn.fetchrow(
            f"SELECT {_ASSET_COLUMNS} FROM assets WHERE name = $1",
            name
        )
        if row:
            return _parse_asset_row(row)
        return None


async def get_asset_by_hash(content_hash: str) -> Optional[dict]:
    """Return an asset with the given content hash, or None.

    ``content_hash`` is not unique in the schema; when several assets share
    it, the oldest one (then lowest name) is returned so the result is
    deterministic.
    """
    async with get_connection() as conn:
        row = await conn.fetchrow(
            f"""SELECT {_ASSET_COLUMNS}
               FROM assets WHERE content_hash = $1
               ORDER BY created_at ASC, name ASC
               LIMIT 1""",
            content_hash
        )
        if row:
            return _parse_asset_row(row)
        return None


async def get_all_assets() -> dict[str, dict]:
    """Return all assets as a dict keyed by name, newest first."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            f"""SELECT {_ASSET_COLUMNS}
               FROM assets ORDER BY created_at DESC, name ASC"""
        )
        return {row["name"]: _parse_asset_row(row) for row in rows}


async def get_assets_paginated(limit: int = 100, offset: int = 0) -> tuple[list[tuple[str, dict]], int]:
    """Return one page of assets (newest first) and the total asset count.

    Returns:
        ``(page, total)`` where ``page`` is a list of ``(name, asset)`` tuples.
    """
    async with get_connection() as conn:
        total = await conn.fetchval("SELECT COUNT(*) FROM assets")
        # The ``name`` tiebreaker makes the ordering total, so rows with equal
        # created_at cannot be repeated or skipped across pages.
        rows = await conn.fetch(
            f"""SELECT {_ASSET_COLUMNS}
               FROM assets ORDER BY created_at DESC, name ASC
               LIMIT $1 OFFSET $2""",
            limit, offset
        )
        return [(row["name"], _parse_asset_row(row)) for row in rows], total


async def get_assets_by_owner(owner: str) -> dict[str, dict]:
    """Return all assets owned by ``owner`` as a dict keyed by name, newest first."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            f"""SELECT {_ASSET_COLUMNS}
               FROM assets WHERE owner = $1
               ORDER BY created_at DESC, name ASC""",
            owner
        )
        return {row["name"]: _parse_asset_row(row) for row in rows}


async def create_asset(asset: dict) -> dict:
    """Insert a new asset and return the stored row (see _parse_asset_row).

    ``asset`` must contain ``name``, ``content_hash``, ``asset_type`` and
    ``owner``; other columns are optional. ``created_at`` may be a datetime,
    an ISO 8601 string or absent (defaults to now, UTC).
    """
    async with get_connection() as conn:
        row = await conn.fetchrow(
            """INSERT INTO assets (name, content_hash, ipfs_cid, asset_type, tags, metadata,
                                   url, provenance, description, origin, owner, created_at)
               VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
               RETURNING *""",
            asset["name"],
            asset["content_hash"],
            asset.get("ipfs_cid"),
            asset["asset_type"],
            json.dumps(asset.get("tags", [])),
            json.dumps(asset.get("metadata", {})),
            asset.get("url"),
            json.dumps(asset.get("provenance")) if asset.get("provenance") else None,
            asset.get("description"),
            json.dumps(asset.get("origin")) if asset.get("origin") else None,
            asset["owner"],
            _parse_timestamp(asset.get("created_at"))
        )
        return _parse_asset_row(row)


async def update_asset(name: str, updates: dict) -> Optional[dict]:
    """Apply ``updates`` to the asset named ``name`` and bump ``updated_at``.

    Args:
        name: Name of the asset to update.
        updates: Mapping of column name to new value. JSONB columns are
            serialized with json.dumps; an ISO 8601 string ``created_at`` is
            parsed like in create_asset.

    Returns:
        The updated asset (see _parse_asset_row), or None if no asset matched.

    Raises:
        ValueError: if ``updates`` contains a key that is not an updatable
            asset column (the keys become part of the SQL text).
    """
    unknown = set(updates) - _UPDATABLE_ASSET_COLUMNS
    if unknown:
        raise ValueError(
            "Cannot update unknown asset field(s): "
            + ", ".join(sorted(str(key) for key in unknown))
        )

    # Build dynamic UPDATE query; placeholder numbers follow len(values).
    set_clauses = []
    values = []
    for key, value in updates.items():
        if key in _JSON_ASSET_FIELDS:
            value = json.dumps(value) if value is not None else None
        elif key == "created_at" and isinstance(value, (str, datetime)):
            value = _parse_timestamp(value)
        values.append(value)
        set_clauses.append(f"{key} = ${len(values)}")

    values.append(datetime.now(timezone.utc))
    set_clauses.append(f"updated_at = ${len(values)}")

    values.append(name)  # WHERE clause
    where_idx = len(values)

    async with get_connection() as conn:
        row = await conn.fetchrow(
            f"""UPDATE assets SET {', '.join(set_clauses)}
               WHERE name = ${where_idx} RETURNING *""",
            *values
        )
        if row:
            return _parse_asset_row(row)
        return None


async def asset_exists(name: str) -> bool:
    """Return True if an asset named ``name`` exists."""
    async with get_connection() as conn:
        return await conn.fetchval(
            "SELECT EXISTS(SELECT 1 FROM assets WHERE name = $1)",
            name
        )


def _parse_asset_row(row) -> dict:
    """Convert an asset row into a plain dict.

    Timestamps become ISO 8601 strings. JSONB fields (tags, metadata,
    provenance, origin) that arrive as JSON text are decoded into Python
    objects; text that fails to decode is left as-is.
    """
    asset = dict(row)
    # Convert datetime to ISO string
    if asset.get("created_at"):
        asset["created_at"] = asset["created_at"].isoformat()
    if asset.get("updated_at"):
        asset["updated_at"] = asset["updated_at"].isoformat()
    # Decode JSONB fields delivered as text
    for field in _JSON_ASSET_FIELDS:
        if isinstance(asset.get(field), str):
            try:
                asset[field] = json.loads(asset[field])
            except (json.JSONDecodeError, TypeError):
                pass
    return asset


# ============ Activities ============

_ACTIVITY_COLUMNS = "activity_id, activity_type, actor_id, object_data, published, signature"


async def get_activity(activity_id: str) -> Optional[dict]:
    """Return the activity with the given ID, or None.

    A value that is not a valid UUID cannot match any row and is therefore
    reported as "not found" (None) instead of raising.
    """
    try:
        key = UUID(str(activity_id))
    except ValueError:
        return None
    async with get_connection() as conn:
        row = await conn.fetchrow(
            f"SELECT {_ACTIVITY_COLUMNS} FROM activities WHERE activity_id = $1",
            key
        )
        if row:
            return _parse_activity_row(row)
        return None


async def get_activity_by_index(index: int) -> Optional[dict]:
    """Return the ``index``-th activity, oldest first (legacy URL scheme).

    Uses the same ordering as get_all_activities(), so position ``i`` in that
    list is the activity returned here. Returns None for a negative or
    out-of-range index.
    """
    if index < 0:
        # PostgreSQL rejects a negative OFFSET; no such activity exists.
        return None
    async with get_connection() as conn:
        # activity_id breaks ties on ``published`` so indices are stable.
        row = await conn.fetchrow(
            f"""SELECT {_ACTIVITY_COLUMNS}
               FROM activities ORDER BY published ASC, activity_id ASC
               LIMIT 1 OFFSET $1""",
            index
        )
        if row:
            return _parse_activity_row(row)
        return None


async def get_all_activities() -> list[dict]:
    """Return all activities, oldest first (index-compatible ordering)."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            f"""SELECT {_ACTIVITY_COLUMNS}
               FROM activities ORDER BY published ASC, activity_id ASC"""
        )
        return [_parse_activity_row(row) for row in rows]


async def get_activities_paginated(limit: int = 100, offset: int = 0) -> tuple[list[dict], int]:
    """Return one page of activities (newest first) and the total count."""
    async with get_connection() as conn:
        total = await conn.fetchval("SELECT COUNT(*) FROM activities")
        # activity_id tiebreaker keeps page boundaries stable for equal timestamps.
        rows = await conn.fetch(
            f"""SELECT {_ACTIVITY_COLUMNS}
               FROM activities ORDER BY published DESC, activity_id DESC
               LIMIT $1 OFFSET $2""",
            limit, offset
        )
        return [_parse_activity_row(row) for row in rows], total


async def get_activities_by_actor(actor_id: str) -> list[dict]:
    """Return all activities by ``actor_id``, newest first."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            f"""SELECT {_ACTIVITY_COLUMNS}
               FROM activities WHERE actor_id = $1
               ORDER BY published DESC, activity_id DESC""",
            actor_id
        )
        return [_parse_activity_row(row) for row in rows]


async def create_activity(activity: dict) -> dict:
    """Insert a new activity and return the stored row (see _parse_activity_row).

    ``activity`` must contain ``activity_id`` (UUID string), ``activity_type``,
    ``actor_id``, ``object_data`` and ``published`` (datetime or ISO 8601
    string); ``signature`` is optional.

    Raises:
        ValueError: if ``activity_id`` is not a valid UUID string.
    """
    async with get_connection() as conn:
        row = await conn.fetchrow(
            """INSERT INTO activities (activity_id, activity_type, actor_id, object_data, published, signature)
               VALUES ($1, $2, $3, $4, $5, $6)
               RETURNING *""",
            UUID(activity["activity_id"]),
            activity["activity_type"],
            activity["actor_id"],
            json.dumps(activity["object_data"]),
            _parse_timestamp(activity["published"]),
            json.dumps(activity.get("signature")) if activity.get("signature") else None
        )
        return _parse_activity_row(row)


async def count_activities() -> int:
    """Return the total number of activities."""
    async with get_connection() as conn:
        return await conn.fetchval("SELECT COUNT(*) FROM activities")


def _parse_activity_row(row) -> dict:
    """Convert an activity row into a plain dict.

    ``activity_id`` becomes a string, ``published`` an ISO 8601 string, and
    JSONB fields (object_data, signature) delivered as JSON text are decoded;
    text that fails to decode is left as-is.
    """
    activity = dict(row)
    # Convert UUID to string
    if activity.get("activity_id"):
        activity["activity_id"] = str(activity["activity_id"])
    # Convert datetime to ISO string
    if activity.get("published"):
        activity["published"] = activity["published"].isoformat()
    # Decode JSONB fields delivered as text
    for field in ("object_data", "signature"):
        if isinstance(activity.get(field), str):
            try:
                activity[field] = json.loads(activity[field])
            except (json.JSONDecodeError, TypeError):
                pass
    return activity


# ============ Followers ============

async def get_followers(username: str) -> list[dict]:
    """Return the follower rows recorded for ``username``."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            """SELECT id, username, acct, url, public_key, created_at
               FROM followers WHERE username = $1""",
            username
        )
        return [dict(row) for row in rows]


async def get_all_followers() -> list[str]:
    """Return the distinct follower URLs across all users (legacy global list)."""
    async with get_connection() as conn:
        rows = await conn.fetch(
            """SELECT DISTINCT url FROM followers"""
        )
        return [row["url"] for row in rows]


async def add_follower(username: str, acct: str, url: str, public_key: Optional[str] = None) -> dict:
    """Add a follower of ``username``, or refresh url/public_key if already present.

    Returns the stored follower row as a dict.
    """
    async with get_connection() as conn:
        row = await conn.fetchrow(
            """INSERT INTO followers (username, acct, url, public_key)
               VALUES ($1, $2, $3, $4)
               ON CONFLICT (username, acct) DO UPDATE SET
                   url = EXCLUDED.url,
                   public_key = EXCLUDED.public_key
               RETURNING *""",
            username, acct, url, public_key
        )
        return dict(row)


async def remove_follower(username: str, acct: str) -> bool:
    """Remove follower ``acct`` from ``username``; return True if a row was deleted."""
    async with get_connection() as conn:
        result = await conn.execute(
            "DELETE FROM followers WHERE username = $1 AND acct = $2",
            username, acct
        )
        # (username, acct) is UNIQUE, so at most one row can be deleted.
        return result == "DELETE 1"


# ============ Stats ============

async def get_stats() -> dict:
    """Return row counts of assets, activities and users for the dashboard."""
    async with get_connection() as conn:
        assets = await conn.fetchval("SELECT COUNT(*) FROM assets")
        activities = await conn.fetchval("SELECT COUNT(*) FROM activities")
        users = await conn.fetchval("SELECT COUNT(*) FROM users")
        return {"assets": assets, "activities": activities, "users": users}