"""Service scheduling — business logic."""
from __future__ import annotations

import logging
from datetime import date, datetime, timedelta, timezone
from typing import List, Optional, Tuple
from uuid import UUID

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.apps.scheduling.models.service_event import ServiceEvent
from src.apps.scheduling.models.staff_assignment import ServiceStaffAssignment
from src.apps.scheduling.schemas.requests import ServiceCreate, ServiceUpdate
from src.core.exceptions import NotFoundError

# Module-level logger; ServiceService uses it to record create / update /
# soft-delete operations.
logger = logging.getLogger(__name__)


class ServiceService:
    """Business operations for service events and their crew assignments.

    Reads of ``ServiceEvent`` are scoped to a tenant and exclude soft-deleted
    rows (``deleted_at IS NULL``). Write methods only ``flush()`` the session;
    nothing here commits, so transaction control stays with the caller.
    """

    # ── Shared helpers ────────────────────────────────────────────────────────

    @staticmethod
    def _week_filters(
        tenant_id: UUID | str,
        week_start: date,
        event_type: Optional[str] = None,
    ) -> list:
        """Build WHERE clauses for live services in ``[week_start, week_start + 7d)``.

        An ``event_type`` of ``None``, ``""`` or ``"all"`` applies no type filter.
        """
        week_end = week_start + timedelta(days=7)
        filters = [
            ServiceEvent.tenant_id == tenant_id,
            ServiceEvent.deleted_at.is_(None),
            ServiceEvent.scheduled_date >= week_start,
            ServiceEvent.scheduled_date < week_end,
        ]
        if event_type and event_type != "all":
            filters.append(ServiceEvent.service_type == event_type)
        return filters

    @staticmethod
    def _add_crew_assignments(
        db: AsyncSession,
        tenant_id: UUID | str,
        service_id: UUID,
        crew_ids,
    ) -> None:
        """Stage one ``role="crew"`` assignment per distinct user id (not flushed)."""
        # dict.fromkeys drops duplicate ids while keeping the caller's order,
        # so a repeated id never produces two identical assignment rows.
        for uid in dict.fromkeys(crew_ids):
            db.add(
                ServiceStaffAssignment(
                    tenant_id=tenant_id,
                    service_id=service_id,
                    user_id=uid,
                    role="crew",
                )
            )

    # ── Week calendar view ────────────────────────────────────────────────────

    async def get_week(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        week_start: date,
        event_type: Optional[str] = None,
    ) -> List[ServiceEvent]:
        """Return all non-deleted services in a 7-day window from week_start.

        Args:
            db: Async session used to run the query.
            tenant_id: Tenant whose services are returned.
            week_start: First day of the window (inclusive); the window ends
                seven days later (exclusive).
            event_type: Optional ``service_type`` filter; ``None``/``"all"``
                returns every type.

        Returns:
            Services ordered by date, time, then id (for a deterministic order).
        """
        stmt = (
            select(ServiceEvent)
            .where(*self._week_filters(tenant_id, week_start, event_type))
            .order_by(
                ServiceEvent.scheduled_date,
                ServiceEvent.scheduled_time,
                ServiceEvent.id,
            )
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    # ── Paginated summary table ───────────────────────────────────────────────

    async def get_summary_table(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        week_start: date,
        page: int = 1,
        page_size: int = 20,
        event_type: Optional[str] = None,
    ) -> Tuple[List[ServiceEvent], int]:
        """Paginated service list for the given week.

        Args:
            db: Async session used to run the queries.
            tenant_id: Tenant whose services are listed.
            week_start: First day of the 7-day window (inclusive).
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.
            event_type: Optional ``service_type`` filter, same semantics as
                in ``get_week`` (default: no filter).

        Returns:
            ``(items_on_page, total_matching_rows)``.
        """
        # A negative OFFSET/LIMIT is rejected by PostgreSQL; clamp instead.
        page = max(page, 1)
        page_size = max(page_size, 0)

        filters = self._week_filters(tenant_id, week_start, event_type)

        count_stmt = (
            select(func.count())
            .select_from(ServiceEvent)
            .where(*filters)
        )
        total: int = (await db.execute(count_stmt)).scalar_one()

        stmt = (
            select(ServiceEvent)
            .where(*filters)
            # The id tie-breaker makes the ordering total. Without it, rows
            # that share a date and time can move between pages under
            # OFFSET/LIMIT, so they get duplicated or skipped.
            .order_by(
                ServiceEvent.scheduled_date,
                ServiceEvent.scheduled_time,
                ServiceEvent.id,
            )
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        result = await db.execute(stmt)
        return list(result.scalars().all()), total

    # ── Create ────────────────────────────────────────────────────────────────

    async def create(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        data: ServiceCreate,
        current_user,
    ) -> ServiceEvent:
        """Create a service event and persist any crew assignments.

        Args:
            db: Async session the new rows are added to (flushed, not committed).
            tenant_id: Tenant that owns the new service.
            data: Validated create payload; ``assigned_crew_ids`` becomes
                assignment rows, and every other field goes onto the event.
            current_user: Object whose ``id`` is stored as ``created_by``.

        Returns:
            The flushed ``ServiceEvent`` (its ``id`` is populated).
        """
        crew_ids = data.assigned_crew_ids or []
        service_data = data.model_dump(exclude={"assigned_crew_ids"})

        event = ServiceEvent(
            tenant_id=tenant_id,
            created_by=current_user.id,
            **service_data,
        )
        db.add(event)
        await db.flush()  # generate event.id for the assignment foreign keys

        self._add_crew_assignments(db, tenant_id, event.id, crew_ids)

        await db.flush()
        logger.info("[scheduling] Created service %s for tenant %s", event.id, tenant_id)
        return event

    # ── Get by ID ─────────────────────────────────────────────────────────────

    async def get_by_id(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        service_id: UUID,
    ) -> ServiceEvent:
        """Fetch a single non-deleted service event, raise NotFoundError if missing.

        Raises:
            NotFoundError: No live service with this id exists for the tenant.
        """
        stmt = select(ServiceEvent).where(
            ServiceEvent.id == service_id,
            ServiceEvent.tenant_id == tenant_id,
            ServiceEvent.deleted_at.is_(None),
        )
        event = (await db.execute(stmt)).scalar_one_or_none()
        if event is None:
            raise NotFoundError("Service not found")
        return event

    # ── Get staff assignments ─────────────────────────────────────────────────

    async def get_assignments(
        self,
        db: AsyncSession,
        service_id: UUID,
        tenant_id: UUID | str | None = None,
    ) -> List[ServiceStaffAssignment]:
        """Return all crew assignments for a service.

        Args:
            db: Async session used to run the query.
            service_id: Service whose assignments are returned.
            tenant_id: When given, also restrict rows to this tenant. This is
                defence-in-depth for callers that have not already verified
                ownership through ``get_by_id``.
        """
        stmt = select(ServiceStaffAssignment).where(
            ServiceStaffAssignment.service_id == service_id
        )
        if tenant_id is not None:
            stmt = stmt.where(ServiceStaffAssignment.tenant_id == tenant_id)
        result = await db.execute(stmt)
        return list(result.scalars().all())

    # ── Update ────────────────────────────────────────────────────────────────

    async def update(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        service_id: UUID,
        data: ServiceUpdate,
        current_user,
    ) -> ServiceEvent:
        """Update a service event and optionally replace crew assignments.

        Only fields explicitly set on ``data`` are applied. When
        ``assigned_crew_ids`` is not ``None``, every existing assignment for
        the service is deleted and replaced with the given list (an empty
        list clears the crew).

        Args:
            current_user: Accepted for interface parity; not used here.

        Raises:
            NotFoundError: Propagated from ``get_by_id``.
        """
        event = await self.get_by_id(db, tenant_id, service_id)

        update_data = data.model_dump(exclude_unset=True, exclude={"assigned_crew_ids"})
        for key, value in update_data.items():
            setattr(event, key, value)

        if data.assigned_crew_ids is not None:
            # Full replace: the event was tenant-checked by get_by_id above,
            # so deleting by service_id alone stays within the tenant.
            await db.execute(
                delete(ServiceStaffAssignment).where(
                    ServiceStaffAssignment.service_id == event.id
                )
            )
            self._add_crew_assignments(db, tenant_id, event.id, data.assigned_crew_ids)

        await db.flush()
        logger.info("[scheduling] Updated service %s", service_id)
        return event

    # ── Soft delete ───────────────────────────────────────────────────────────

    async def soft_delete(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        service_id: UUID,
        current_user,
    ) -> None:
        """Soft-delete a service event by stamping ``deleted_at`` with UTC now.

        Crew assignment rows are left untouched.

        Args:
            current_user: Accepted for interface parity; not used here.

        Raises:
            NotFoundError: Propagated from ``get_by_id``.
        """
        event = await self.get_by_id(db, tenant_id, service_id)
        event.deleted_at = datetime.now(timezone.utc)
        await db.flush()
        logger.info("[scheduling] Soft-deleted service %s", service_id)
