from dataclasses import dataclass
from typing import List, Optional, Tuple
from uuid import UUID

from sqlalchemy import and_, outerjoin, select
from sqlalchemy.ext.asyncio import AsyncSession

from src.apps.records.models.burial_info import BurialInfo
from src.apps.records.models.family_contact import FamilyContact
from src.apps.site_admin.models.audit_log import AuditLog


@dataclass
class AuditLogWithUser:
    log: AuditLog
    user_first_name: Optional[str]
    user_last_name: Optional[str]


class AuditService:
    @staticmethod
    async def log(
        db: AsyncSession,
        entity_type: str,
        entity_id: UUID,
        action: str,
        user,
        request,
        old_value: Optional[dict] = None,
        new_value: Optional[dict] = None,
    ) -> None:
        try:
            # Extract IP: prefer X-Forwarded-For, fall back to direct client host
            ip_address = None
            if request is not None:
                forwarded_for = request.headers.get("x-forwarded-for")
                if forwarded_for:
                    ip_address = forwarded_for.split(",")[0].strip()
                elif request.client:
                    ip_address = str(request.client.host)

            entry = AuditLog(
                tenant_id=getattr(user, "tenant_id", None),
                user_id=user.id,
                entity_type=entity_type,
                entity_id=entity_id,
                action=action,
                old_value=old_value,
                new_value=new_value,
                ip_address=ip_address,
                user_agent=request.headers.get("user-agent") if request is not None else None,
            )
            db.add(entry)
            await db.flush()
        except Exception:
            pass

    @staticmethod
    async def get_record_audit(
        db: AsyncSession,
        tenant_id: UUID,
        record_id: UUID,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[AuditLogWithUser], int]:
        from sqlalchemy import func
        from src.apps.auth.models.user import User

        # Collect all related entity IDs: the record itself + burial_info + family_contacts
        related_ids: List[UUID] = [record_id]

        burial_result = await db.execute(
            select(BurialInfo.id).where(BurialInfo.record_id == record_id)
        )
        burial_ids = burial_result.scalars().all()
        related_ids.extend(burial_ids)

        contact_result = await db.execute(
            select(FamilyContact.id).where(FamilyContact.record_id == record_id)
        )
        contact_ids = contact_result.scalars().all()
        related_ids.extend(contact_ids)

        base_conditions = and_(
            AuditLog.tenant_id == tenant_id,
            AuditLog.entity_id.in_(related_ids),
            AuditLog.entity_type.in_(["Record", "BurialInfo", "FamilyContact"]),
        )

        count_result = await db.execute(
            select(func.count(AuditLog.id)).where(base_conditions)
        )
        total = count_result.scalar_one()

        offset = (page - 1) * page_size
        data_result = await db.execute(
            select(AuditLog, User.first_name, User.last_name)
            .select_from(
                outerjoin(AuditLog, User, AuditLog.user_id == User.id)
            )
            .where(base_conditions)
            .order_by(AuditLog.created_at.desc())
            .offset(offset)
            .limit(page_size)
        )
        rows = data_result.all()

        results = [
            AuditLogWithUser(
                log=row[0],
                user_first_name=row[1],
                user_last_name=row[2],
            )
            for row in rows
        ]

        return results, total
