diff --git a/contracts/protocols.py b/contracts/protocols.py index 0f6b7c7..4e57c36 100644 --- a/contracts/protocols.py +++ b/contracts/protocols.py @@ -130,7 +130,8 @@ class CalendarService(Protocol): ) -> set[int]: ... async def upcoming_entries_for_container( - self, session: AsyncSession, container_type: str, container_id: int, + self, session: AsyncSession, + container_type: str | None = None, container_id: int | None = None, *, page: int = 1, per_page: int = 20, ) -> tuple[list[CalendarEntryDTO], bool]: ... diff --git a/services/calendar_impl.py b/services/calendar_impl.py index b6159b8..26a3f0e 100644 --- a/services/calendar_impl.py +++ b/services/calendar_impl.py @@ -240,25 +240,35 @@ class SqlCalendarService: return [_entry_to_dto(e) for e in merged] async def upcoming_entries_for_container( - self, session: AsyncSession, container_type: str, container_id: int, + self, session: AsyncSession, + container_type: str | None = None, container_id: int | None = None, *, page: int = 1, per_page: int = 20, ) -> tuple[list[CalendarEntryDTO], bool]: - """Upcoming confirmed entries across all calendars for a container.""" - cal_ids = select(Calendar.id).where( - Calendar.container_type == container_type, - Calendar.container_id == container_id, - Calendar.deleted_at.is_(None), - ).scalar_subquery() + """Upcoming confirmed entries. Optionally scoped to a container.""" + filters = [ + CalendarEntry.state == "confirmed", + CalendarEntry.deleted_at.is_(None), + CalendarEntry.start_at >= func.now(), + ] + + if container_type is not None and container_id is not None: + cal_ids = select(Calendar.id).where( + Calendar.container_type == container_type, + Calendar.container_id == container_id, + Calendar.deleted_at.is_(None), + ).scalar_subquery() + filters.append(CalendarEntry.calendar_id.in_(cal_ids)) + else: + # Still exclude entries from deleted calendars + cal_ids = select(Calendar.id).where( + Calendar.deleted_at.is_(None), + ).scalar_subquery() + filters.append(CalendarEntry.calendar_id.in_(cal_ids)) offset = (page - 1) * per_page result = await session.execute( select(CalendarEntry) - .where( - CalendarEntry.calendar_id.in_(cal_ids), - CalendarEntry.state == "confirmed", - CalendarEntry.deleted_at.is_(None), - CalendarEntry.start_at >= func.now(), - ) + .where(*filters) .order_by(CalendarEntry.start_at.asc()) .limit(per_page) .offset(offset)