# FILE: src/apps/sales/services/contract_service.py
from __future__ import annotations

from datetime import datetime, timezone
from decimal import Decimal
from typing import Optional
from uuid import UUID

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.core.exceptions import NotFoundError
from src.apps.sales.models.contract import Contract
from src.apps.sales.models.contract_line_item import ContractLineItem
from src.apps.billing.models.invoice import Invoice


class ContractService:
    """Async data-access helpers for tenant-scoped ``Contract`` records.

    Every method works on the caller's ``AsyncSession`` and at most flushes;
    nothing here commits, so transaction boundaries belong to the caller.
    """

    # Harmonized Sales Tax applied to the line-item subtotal on create
    # (equivalent to the previous hard-coded ``* Decimal("1.13")``).
    HST_RATE = Decimal("0.13")

    # Escape character for LIKE/ILIKE patterns. "/" (rather than backslash)
    # avoids dialect-specific backslash quoting in the rendered ESCAPE clause.
    _LIKE_ESCAPE = "/"

    @staticmethod
    async def list(
        db: AsyncSession,
        tenant_id: str,
        page: int = 1,
        page_size: int = 20,
        status: Optional[str] = None,
        search: Optional[str] = None,
    ) -> tuple[list[Contract], int]:
        """Return one page of a tenant's non-deleted contracts and the total count.

        Args:
            db: Active async session.
            tenant_id: Tenant whose contracts are listed.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; values below 1 are treated as 1.
            status: When truthy, exact-match filter on ``Contract.status``.
            search: When truthy, case-insensitive substring match on
                ``purchaser_name``. ``%`` and ``_`` in the term match literally.

        Returns:
            ``(contracts, total)``. ``contracts`` is ordered newest-first by
            ``created_at`` with ``line_items`` eager-loaded. ``total`` counts all
            matching rows, ignoring pagination.
        """
        filters = [
            Contract.tenant_id == tenant_id,
            Contract.deleted_at.is_(None),
        ]
        if status:
            filters.append(Contract.status == status)
        if search:
            # Escape user-supplied wildcards so the term is a literal substring.
            pattern = f"%{ContractService._escape_like(search)}%"
            filters.append(
                Contract.purchaser_name.ilike(
                    pattern, escape=ContractService._LIKE_ESCAPE
                )
            )

        total = (
            await db.execute(
                select(func.count()).select_from(Contract).where(*filters)
            )
        ).scalar_one()

        offset, limit = ContractService._normalize_paging(page, page_size)
        result = await db.execute(
            select(Contract)
            .options(selectinload(Contract.line_items))
            .where(*filters)
            .order_by(Contract.created_at.desc())
            .offset(offset)
            .limit(limit)
        )
        # Inside a method body ``list`` resolves to the builtin, not this method.
        return list(result.scalars().all()), total

    @staticmethod
    async def get_by_id(
        db: AsyncSession,
        tenant_id: str,
        contract_id: UUID,
    ) -> Contract:
        """Load a non-deleted contract of *tenant_id* with line items and invoices.

        Raises:
            NotFoundError: no matching, non-soft-deleted contract exists.
        """
        result = await db.execute(
            select(Contract)
            .options(
                selectinload(Contract.line_items),
                selectinload(Contract.invoices),
            )
            .where(
                Contract.id == contract_id,
                Contract.tenant_id == tenant_id,
                Contract.deleted_at.is_(None),
            )
        )
        contract = result.scalar_one_or_none()
        if contract is None:
            raise NotFoundError("Contract not found")
        return contract

    @staticmethod
    async def create(
        db: AsyncSession,
        tenant_id: str,
        data: dict,
        current_user,
    ) -> Contract:
        """Create a draft contract and its line items, then compute the total.

        Args:
            db: Active async session (flushed, not committed).
            tenant_id: Owning tenant.
            data: Contract column values. An optional ``"line_items"`` entry
                holds dicts with ``description`` plus optional ``fee_type``,
                ``quantity`` (default 1) and ``unit_price`` (default 0). All
                other keys are passed to ``Contract(...)``. The caller's dict is
                not modified.
            current_user: Object whose ``id`` is stored as ``created_by``.

        Returns:
            The new contract, re-fetched with ``line_items`` eager-loaded.
            ``total_amount`` is the subtotal plus HST, rounded to 0.01.
        """
        contract_number = await ContractService._next_document_number(
            db, Contract, tenant_id, "CNT"
        )

        # Shallow copy: popping "line_items" must not mutate the caller's dict.
        fields = dict(data)
        line_items_data = fields.pop("line_items", None) or []

        contract = Contract(
            tenant_id=tenant_id,
            contract_number=contract_number,
            created_by=current_user.id,
            status="draft",
            **fields,
        )
        db.add(contract)
        # Flush so contract.id is available for the line items' foreign key.
        await db.flush()

        subtotal = Decimal("0")
        for li_data in line_items_data:
            quantity = int(li_data.get("quantity", 1))
            # Go through str() so float inputs don't carry binary artifacts.
            unit_price = Decimal(str(li_data.get("unit_price", 0)))
            line_total = Decimal(str(quantity)) * unit_price
            subtotal += line_total

            db.add(
                ContractLineItem(
                    contract_id=contract.id,
                    tenant_id=tenant_id,
                    description=li_data["description"],
                    fee_type=li_data.get("fee_type"),
                    quantity=quantity,
                    unit_price=unit_price,
                    line_total=line_total,
                )
            )

        contract.total_amount = ContractService._apply_hst(subtotal)

        await db.flush()

        # Re-fetch the contract with line_items eager-loaded so the caller can
        # serialize it without hitting a MissingGreenlet / lazy-load error.
        result = await db.execute(
            select(Contract)
            .options(selectinload(Contract.line_items))
            .where(Contract.id == contract.id)
        )
        return result.scalar_one()

    @staticmethod
    async def sign(
        db: AsyncSession,
        tenant_id: str,
        contract_id: UUID,
        signature_b64: str,
        witness_name: Optional[str],
        current_user,
        arq_redis=None,
    ) -> Contract:
        """Mark a contract as signed and create its invoice.

        The contract row is locked with ``SELECT ... FOR UPDATE``. A contract
        already in status ``"signed"`` is returned unchanged, so no second
        invoice is created and the signature is not overwritten. When
        *arq_redis* is given, the contract's ``qr_codes`` row is upserted and
        the ``generate_contract_pdf`` / ``generate_qr_code`` jobs are enqueued.

        Raises:
            NotFoundError: no matching, non-soft-deleted contract exists.

        Returns:
            The contract with ``line_items`` and ``invoices`` eager-loaded.
        """
        # SELECT ... FOR UPDATE acquires a row-level lock before the status check,
        # preventing concurrent sign requests from both passing the idempotency
        # guard and creating duplicate invoices.
        locked = await db.execute(
            select(Contract)
            .where(
                Contract.id == contract_id,
                Contract.tenant_id == tenant_id,
                Contract.deleted_at.is_(None),
            )
            .with_for_update()
        )
        contract = locked.scalar_one_or_none()
        if contract is None:
            raise NotFoundError("Contract not found")

        # Idempotency guard — already signed contracts must not be re-signed
        # (would create duplicate invoices and overwrite the original signature).
        if contract.status == "signed":
            return await ContractService.get_by_id(db, tenant_id, contract_id)

        contract.status = "signed"
        contract.signed_at = datetime.now(timezone.utc)
        contract.purchaser_signature_b64 = signature_b64
        contract.witness_name = witness_name

        # Flushes the contract changes together with the new invoice.
        await ContractService._auto_create_invoice(db, tenant_id, contract)

        if arq_redis:
            # Issue every DB write before enqueueing anything, so a failing
            # QR upsert cannot leave a PDF job queued for a rolled-back signing.
            qr_id = await ContractService._upsert_contract_qr_code(
                db, tenant_id, contract
            )
            # NOTE(review): jobs are enqueued before the caller commits; a
            # worker that starts immediately may not yet see the signed row.
            # Consider enqueueing after commit.
            await arq_redis.enqueue_job(
                "generate_contract_pdf",
                str(tenant_id),
                str(contract_id),
            )
            await arq_redis.enqueue_job("generate_qr_code", str(qr_id))

        await db.flush()

        # Re-fetch with eager-loaded relationships so the router can serialize
        # without a MissingGreenlet / lazy-load error.
        result = await db.execute(
            select(Contract)
            .options(
                selectinload(Contract.line_items),
                selectinload(Contract.invoices),
            )
            .where(Contract.id == contract.id)
        )
        return result.scalar_one()

    @staticmethod
    async def soft_delete(
        db: AsyncSession,
        tenant_id: str,
        contract_id: UUID,
        current_user,
    ) -> None:
        """Call ``soft_delete()`` on a tenant's contract and flush.

        ``current_user`` is accepted for interface symmetry but is not used here.

        Raises:
            NotFoundError: no matching, non-soft-deleted contract exists.
        """
        contract = await ContractService.get_by_id(db, tenant_id, contract_id)
        contract.soft_delete()
        await db.flush()

    @staticmethod
    async def _auto_create_invoice(
        db: AsyncSession,
        tenant_id: str,
        contract: Contract,
    ) -> None:
        """Create an ``"outstanding"`` invoice for the full contract total and flush.

        The invoice is linked via ``contract_id``. ``total_amount`` and
        ``balance_due`` are copied from the contract, and ``paid_amount`` is 0.
        """
        invoice_number = await ContractService._next_document_number(
            db, Invoice, tenant_id, "INV"
        )

        invoice = Invoice(
            tenant_id=tenant_id,
            contract_id=contract.id,
            invoice_number=invoice_number,
            status="outstanding",
            total_amount=contract.total_amount,
            balance_due=contract.total_amount,
            paid_amount=Decimal("0"),
        )
        db.add(invoice)
        await db.flush()

    @staticmethod
    async def _upsert_contract_qr_code(
        db: AsyncSession,
        tenant_id: str,
        contract: Contract,
    ):
        """Insert or refresh the ``qr_codes`` row for *contract*; return its id.

        On conflict with ``uq_qr_codes_tenant_type_ref``, only
        ``display_label`` is updated. An existing ``content_url`` and
        ``is_active`` keep their stored values.
        """
        # Imported lazily, as in the original code (presumably to avoid
        # import cycles between the sales, settings and tenants apps).
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        from src.apps.settings.models.qr_code import QRCode
        from src.apps.settings.services.qr_code_service import QRCodeService
        from src.apps.tenants.models.account import Account

        account = (
            await db.execute(select(Account).where(Account.id == tenant_id))
        ).scalar_one_or_none()
        # Fall back to the raw tenant id when no account row is found.
        subdomain = account.subdomain if account is not None else str(tenant_id)
        content_url = QRCodeService.build_content_url(
            subdomain, "contract", contract.contract_number
        )
        label = f"Contract {contract.contract_number}"

        upsert = (
            pg_insert(QRCode)
            .values(
                tenant_id=tenant_id,
                qr_type="contract",
                reference_id=contract.contract_number,
                display_label=label,
                content_url=content_url,
                is_active=True,
            )
            .on_conflict_do_update(
                constraint="uq_qr_codes_tenant_type_ref",
                set_={"display_label": label},
            )
            .returning(QRCode.id)
        )
        return (await db.execute(upsert)).scalar_one()

    @staticmethod
    async def _next_document_number(db: AsyncSession, model, tenant_id: str, prefix: str) -> str:
        """Return the next ``PREFIX-YYYY-NNN`` number for *model* rows of a tenant.

        The sequence is the tenant's total row count for *model* plus one.
        Soft-deleted rows are counted too (no ``deleted_at`` filter), so a soft
        delete never lowers the count.
        NOTE(review): two concurrent callers can compute the same number;
        confirm a DB unique constraint / retry exists.
        """
        year = datetime.now(timezone.utc).year
        count = (
            await db.execute(
                select(func.count())
                .select_from(model)
                .where(model.tenant_id == tenant_id)
            )
        ).scalar_one()
        return ContractService._format_document_number(prefix, year, count + 1)

    @staticmethod
    def _format_document_number(prefix: str, year: int, sequence: int) -> str:
        """Format e.g. ``("CNT", 2024, 7)`` as ``"CNT-2024-007"`` (min. 3 digits)."""
        return f"{prefix}-{year}-{sequence:03d}"

    @staticmethod
    def _apply_hst(subtotal: Decimal) -> Decimal:
        """Return *subtotal* plus HST, quantized to cents with Decimal's default rounding."""
        return (subtotal * (Decimal("1") + ContractService.HST_RATE)).quantize(
            Decimal("0.01")
        )

    @staticmethod
    def _escape_like(value: str) -> str:
        """Escape LIKE wildcards (and the escape char) so *value* matches literally."""
        esc = ContractService._LIKE_ESCAPE
        return (
            value.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    def _normalize_paging(page: int, page_size: int) -> tuple[int, int]:
        """Turn a 1-based ``page``/``page_size`` into a non-negative ``(offset, limit)``."""
        page = max(page, 1)
        page_size = max(page_size, 1)
        return (page - 1) * page_size, page_size