"""Invoice service — business logic for billing/invoices."""
from __future__ import annotations

import csv
import io
import logging
from datetime import date, datetime
from decimal import Decimal
from typing import AsyncGenerator, List, Optional, Tuple
from uuid import UUID

from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.apps.billing.models.invoice import Invoice
from src.apps.billing.models.invoice_payment import InvoicePayment
from src.core.exceptions import ForbiddenError, NotFoundError

logger = logging.getLogger(__name__)


class InvoiceService:
    """Tenant-scoped business operations on invoices.

    Every method receives the caller's ``AsyncSession`` and at most calls
    ``flush()``; committing (or rolling back) the transaction is left to the
    caller.
    """

    # Automatic reminder schedule used by ``transition_overdue``: the N-th
    # reminder (0-based, tracked by ``Invoice.reminder_count``) becomes due once
    # the invoice is at least ``_REMINDER_THRESHOLDS_DAYS[N]`` days past due.
    _REMINDER_THRESHOLDS_DAYS: Tuple[int, ...] = (7, 14, 30)

    # ── Query helpers ─────────────────────────────────────────────────────────

    @staticmethod
    def _escape_like(value: str, escape: str = "\\") -> str:
        """Escape LIKE/ILIKE wildcards so *value* is matched literally.

        The escape character itself is doubled first, then ``%`` and ``_`` are
        prefixed with it. Pass the same *escape* to ``ilike(..., escape=...)``.
        """
        return (
            value.replace(escape, escape + escape)
            .replace("%", escape + "%")
            .replace("_", escape + "_")
        )

    @classmethod
    def _build_filters(
        cls,
        tenant_id: UUID | str,
        search: Optional[str],
        status: Optional[str],
    ) -> list:
        """Build the WHERE clauses shared by ``get_list`` and ``export_csv``.

        Always scopes to *tenant_id*. *search* (if non-empty) is a
        case-insensitive substring match on purchaser name or invoice number;
        *status* (if non-empty) is an exact match.
        """
        filters = [Invoice.tenant_id == tenant_id]
        if search:
            # Escape user-supplied % and _ so they are not treated as wildcards.
            pattern = f"%{cls._escape_like(search)}%"
            filters.append(
                or_(
                    Invoice.purchaser_name.ilike(pattern, escape="\\"),
                    Invoice.invoice_number.ilike(pattern, escape="\\"),
                )
            )
        if status:
            filters.append(Invoice.status == status)
        return filters

    # ── List ──────────────────────────────────────────────────────────────────

    async def get_list(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        search: Optional[str] = None,
        status: Optional[str] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> Tuple[List[Invoice], int]:
        """Return one page of invoices plus the total number of matches.

        Args:
            db: Active async session.
            tenant_id: Tenant whose invoices are listed.
            search: Optional substring matched (case-insensitively) against the
                purchaser name or invoice number.
            status: Optional exact status filter.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``(items, total)``. *items* are ordered newest first by
            ``created_at``, with ``payments`` eager-loaded. *total* is the
            unpaginated match count.
        """
        filters = self._build_filters(tenant_id, search, status)

        # Total count (ignores pagination)
        count_stmt = select(func.count()).select_from(Invoice).where(*filters)
        total: int = (await db.execute(count_stmt)).scalar_one()

        # Clamp paging input so OFFSET/LIMIT can never go negative, which
        # databases such as PostgreSQL reject.
        page = max(page, 1)
        page_size = max(page_size, 0)

        # Paginated rows, with payments eager-loaded
        stmt = (
            select(Invoice)
            .options(selectinload(Invoice.payments))
            .where(*filters)
            .order_by(Invoice.created_at.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        result = await db.execute(stmt)
        items: List[Invoice] = list(result.scalars().all())

        return items, total

    # ── Single ────────────────────────────────────────────────────────────────

    async def get_by_id(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        invoice_id: UUID,
    ) -> Invoice:
        """Fetch one invoice of *tenant_id*, with ``payments`` eager-loaded.

        Raises:
            NotFoundError: no invoice with *invoice_id* exists for this tenant.
        """
        stmt = (
            select(Invoice)
            .options(selectinload(Invoice.payments))
            .where(
                Invoice.id == invoice_id,
                Invoice.tenant_id == tenant_id,
            )
        )
        invoice = (await db.execute(stmt)).scalar_one_or_none()
        if not invoice:
            raise NotFoundError("Invoice not found")
        return invoice

    # ── Record Payment ────────────────────────────────────────────────────────

    async def record_payment(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        invoice_id: UUID,
        amount: Decimal,
        payment_date: date,
        payment_method: str,
        reference_number: Optional[str] = None,
        user_id: Optional[UUID] = None,
    ) -> Invoice:
        """Record a payment against an invoice and recalculate its balances.

        Overpayment is accepted: ``paid_amount`` may exceed ``total_amount``,
        while ``balance_due`` is clamped at 0.

        Status rules:
        - Fully paid invoices become ``"paid"``.
        - Otherwise they become ``"partial"``, except that an ``"overdue"``
          invoice stays ``"overdue"``.

        Raises:
            ForbiddenError: *amount* is not positive, or the invoice has no
                balance left to pay.
            NotFoundError: the invoice does not exist for this tenant.
        """
        # A zero/negative "payment" would silently reduce paid_amount.
        if amount <= 0:
            raise ForbiddenError("Payment amount must be positive")

        invoice = await self.get_by_id(db, tenant_id, invoice_id)

        if invoice.balance_due <= 0:
            raise ForbiddenError("Invoice already fully paid")

        payment = InvoicePayment(
            invoice_id=invoice.id,
            tenant_id=tenant_id,
            amount=amount,
            method=payment_method,
            received_on=payment_date,
            receipt_number=reference_number,
            recorded_by=user_id,
        )
        db.add(payment)
        # Keep the eager-loaded collection in sync so the returned invoice
        # already includes the new payment without a refresh.
        invoice.payments.append(payment)

        new_paid = (invoice.paid_amount or Decimal("0")) + amount
        new_balance = invoice.total_amount - new_paid

        invoice.paid_amount = new_paid
        invoice.balance_due = max(Decimal("0"), new_balance)
        if new_balance <= 0:
            invoice.status = "paid"
        elif invoice.status != "overdue":
            invoice.status = "partial"
        # if already overdue and still has balance, leave as overdue

        await db.flush()
        return invoice

    # ── Send Reminder ─────────────────────────────────────────────────────────

    async def send_reminder(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        invoice_id: UUID,
        user_id: Optional[UUID] = None,
    ) -> Invoice:
        """Log a manual reminder and update the reminder tracking fields.

        Only logs and bumps ``reminder_count`` / ``last_reminder_at``; email
        delivery (SES) is still TODO. *user_id* is currently unused.

        Raises:
            NotFoundError: the invoice does not exist for this tenant.
        """
        invoice = await self.get_by_id(db, tenant_id, invoice_id)

        # Coalesce before logging: "%d" cannot format None.
        reminder_count = invoice.reminder_count or 0

        logger.info(
            "[billing] Reminder triggered for invoice %s (tenant=%s, count=%d)",
            invoice_id,
            tenant_id,
            reminder_count,
        )

        # NOTE(review): naive UTC timestamp; datetime.utcnow() is deprecated
        # since Python 3.12 — switch to an aware datetime only if the column is
        # timezone-aware.
        invoice.last_reminder_at = datetime.utcnow()
        invoice.reminder_count = reminder_count + 1

        await db.flush()
        return invoice

    # ── CSV Injection Guard ───────────────────────────────────────────────────

    @staticmethod
    def _csv_safe(value: str | None) -> str:
        """Neutralise spreadsheet formula injection in a CSV cell.

        Values starting with ``=``, ``+``, ``-``, ``@``, TAB or CR get a leading
        TAB so spreadsheet apps do not evaluate them. ``None`` and ``""``
        become ``""``.
        """
        if value and value[0] in ('=', '+', '-', '@', '\t', '\r'):
            return '\t' + value
        return value or ''

    @staticmethod
    def _fmt_amount(value: Optional[Decimal], default: str = "") -> str:
        """Render a money value for CSV; ``None`` becomes *default*."""
        return default if value is None else str(value)

    # ── Export CSV ────────────────────────────────────────────────────────────

    async def export_csv(
        self,
        db: AsyncSession,
        tenant_id: UUID | str,
        search: Optional[str] = None,
        status: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield CSV text for all matching invoices: header first, then one
        chunk per invoice, newest first.

        Uses the same filters as ``get_list`` but without pagination. Matching
        rows are loaded into memory before the first chunk is yielded.
        """
        filters = self._build_filters(tenant_id, search, status)
        stmt = (
            select(Invoice)
            .where(*filters)
            .order_by(Invoice.created_at.desc())
        )
        result = await db.execute(stmt)
        invoices: List[Invoice] = list(result.scalars().all())

        # Write through csv.writer (for correct quoting/escaping) into a reusable
        # buffer, and hand out whatever was written since the last drain.
        buf = io.StringIO()
        writer = csv.writer(buf)

        def _drain() -> str:
            chunk = buf.getvalue()
            buf.seek(0)
            buf.truncate(0)
            return chunk

        writer.writerow(
            ["Invoice #", "Purchaser", "Total", "Paid", "Outstanding", "Due", "Status"]
        )
        yield _drain()

        for inv in invoices:
            writer.writerow(
                [
                    self._csv_safe(inv.invoice_number),
                    self._csv_safe(inv.purchaser_name),
                    self._fmt_amount(inv.total_amount),
                    # A NULL paid_amount means nothing paid (see record_payment).
                    self._fmt_amount(inv.paid_amount, "0"),
                    self._fmt_amount(inv.balance_due),
                    inv.due_date.isoformat() if inv.due_date else "",
                    inv.status,
                ]
            )
            yield _drain()

    # ── Transition Overdue ────────────────────────────────────────────────────

    @classmethod
    def _reminder_due(cls, reminder_count: int, days_overdue: int) -> bool:
        """Return True when the next automatic reminder is due.

        *reminder_count* reminders have already been sent; the next one is due
        once *days_overdue* reaches the matching threshold. No reminder is due
        once every threshold has been used.
        """
        thresholds = cls._REMINDER_THRESHOLDS_DAYS
        if not 0 <= reminder_count < len(thresholds):
            return False
        return days_overdue >= thresholds[reminder_count]

    async def transition_overdue(
        self,
        db: AsyncSession,
        tenant_id: Optional[UUID | str] = None,
    ) -> int:
        """
        Mark past-due invoices as 'overdue' and record automatic reminders at
        the 7, 14 and 30-day thresholds (driven by ``reminder_count`` so a
        reminder is never repeated).

        Already-overdue invoices are re-examined on every run until all
        reminders have been recorded. Otherwise an invoice flipped to 'overdue'
        on day 1 would never be selected again and never reach its reminders.

        Reminders are only logged and tracked on the invoice
        (``last_reminder_at``, ``reminder_count``); no email is sent here. An
        invoice that is far past due with no reminders yet catches up by one
        reminder per run.

        Args:
            db: Active async session.
            tenant_id: Restrict to one tenant; ``None`` processes all tenants.

        Returns:
            The number of invoices changed by this run: newly marked overdue
            and/or given a reminder.
        """
        # NOTE(review): date.today() is the server-local date, while reminder
        # timestamps are naive UTC; confirm the intended timezone.
        today = date.today()
        max_reminders = len(self._REMINDER_THRESHOLDS_DAYS)

        filters = [
            Invoice.due_date.isnot(None),
            Invoice.due_date < today,
            or_(
                Invoice.status.in_(["outstanding", "partial"]),
                (Invoice.status == "overdue")
                & (func.coalesce(Invoice.reminder_count, 0) < max_reminders),
            ),
        ]
        if tenant_id is not None:
            filters.append(Invoice.tenant_id == tenant_id)

        result = await db.execute(select(Invoice).where(*filters))
        invoices: List[Invoice] = list(result.scalars().all())

        processed = 0
        for invoice in invoices:
            changed = False

            if invoice.status != "overdue":
                invoice.status = "overdue"
                changed = True

            days_overdue = (today - invoice.due_date).days
            reminder_count = invoice.reminder_count or 0

            if self._reminder_due(reminder_count, days_overdue):
                logger.info(
                    "[scheduler] Queuing reminder for invoice %s (day %d, count %d)",
                    invoice.id,
                    days_overdue,
                    reminder_count,
                )
                invoice.last_reminder_at = datetime.utcnow()
                invoice.reminder_count = reminder_count + 1
                changed = True

            if changed:
                processed += 1

        if processed:
            await db.flush()

        return processed
