from datetime import date, datetime, timezone
from typing import AsyncGenerator, Optional
from uuid import UUID

from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.apps.site_admin.models.platform_payment import PlatformPayment
from src.apps.tenants.models.account import Account
from src.apps.tenants.models.subscription import Subscription
from src.core.exceptions import NotFoundError

# Per-plan price summed into the MRR figure in PaymentService.get_kpis; keys are
# lowercase plan identifiers (lookups lower-case the stored plan first).
# NOTE(review): presumably monthly amounts in CAD - confirm.
PLAN_PRICES: dict[str, int] = {"starter": 149, "professional": 349, "enterprise": 749}
# Lowercase plan identifier -> display label used in payment listings, the CSV
# export and invoice line items.
PLAN_LABELS: dict[str, str] = {"starter": "Starter", "professional": "Professional", "enterprise": "Enterprise"}


def _pdf_safe(text: str) -> str:
    """Replace characters outside latin-1 so fpdf2 built-in fonts don't crash."""
    return text.encode("latin-1", errors="replace").decode("latin-1")


def _safe_csv_cell(value: str) -> str:
    """Prefix formula-injection characters and wrap in double quotes."""
    s = str(value).replace("\n", " ")
    if s and s[0] in ("=", "+", "-", "@", "\t", "\r"):
        s = f"'{s}"
    return f'"{s.replace(chr(34), chr(34) + chr(34))}"'


class PaymentService:
    """Queries, KPIs, CSV export and invoice rendering for ``PlatformPayment`` rows.

    All database access goes through the ``AsyncSession`` given to the
    constructor. ``record_payment`` only flushes; it never commits.
    """

    # Payment statuses that do NOT represent money received. Every other value,
    # including NULL, counts as paid - the same reading the invoice stamp in
    # ``generate_invoice_pdf`` applies to unknown/empty statuses.
    _UNSETTLED_STATUSES: tuple[str, ...] = ("pending", "failed")

    # Rows fetched per query while streaming a CSV export.
    _EXPORT_BATCH_SIZE: int = 1_000

    def __init__(self, db: AsyncSession):
        self.db = db

    # ------------------------------------------------------------------ KPIs

    @classmethod
    def _settled_condition(cls):
        """SQL condition selecting payments that count as money received."""
        return or_(
            PlatformPayment.status.is_(None),
            func.lower(PlatformPayment.status).not_in(cls._UNSETTLED_STATUSES),
        )

    async def get_kpis(self) -> dict:
        """Return dashboard KPIs for platform payments.

        Returns:
            A dict with:
            - ``total_received``: sum of ``amount_cad`` over settled payments.
            - ``this_month``: the same sum restricted to ``payment_date`` on or
              after the 1st of the current UTC month.
            - ``this_month_label``: the current month, e.g. ``"March 2025"``.
            - ``mrr``: sum of ``PLAN_PRICES`` over accounts whose most recent
              subscription has status ``"active"``.
            - ``payment_count``: number of payment rows, of any status.
        """
        db = self.db
        now = datetime.now(timezone.utc)
        first_of_month = date(now.year, now.month, 1)
        settled = self._settled_condition()
        amount_sum = func.coalesce(func.sum(PlatformPayment.amount_cad), 0)

        total_received_result = await db.execute(select(amount_sum).where(settled))
        total_received = float(total_received_result.scalar_one())

        this_month_result = await db.execute(
            select(amount_sum)
            .where(settled)
            .where(PlatformPayment.payment_date >= first_of_month)
        )
        this_month = float(this_month_result.scalar_one())

        payment_count_result = await db.execute(select(func.count(PlatformPayment.id)))
        payment_count = payment_count_result.scalar_one()

        # MRR: sum plan prices for all accounts whose latest subscription is active
        latest = (
            select(
                Subscription.account_id,
                func.max(Subscription.created_at).label("max_created"),
            )
            .group_by(Subscription.account_id)
            .subquery()
        )
        active_result = await db.execute(
            select(Account.id, Account.plan)
            .join(latest, Account.id == latest.c.account_id)
            .join(
                Subscription,
                (Subscription.account_id == latest.c.account_id)
                & (Subscription.created_at == latest.c.max_created),
            )
            .where(Subscription.status == "active")
            # Two subscriptions sharing the same max created_at would otherwise
            # count the same account twice.
            .distinct()
        )
        mrr = float(
            sum(
                PLAN_PRICES.get((plan or "").lower(), 0)
                for _account_id, plan in active_result.all()
            )
        )

        return {
            "total_received": total_received,
            "this_month": this_month,
            "this_month_label": now.strftime("%B %Y"),
            "mrr": mrr,
            "payment_count": payment_count,
        }

    # -------------------------------------------------------------- listing

    @staticmethod
    def _like_pattern(term: str) -> str:
        """Build a ``%term%`` LIKE pattern with ``\\``, ``%`` and ``_`` escaped.

        Use the result with ``escape="\\\\"`` so user-typed wildcards match literally.
        """
        escaped = term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    def _build_filter(
        self, plan: Optional[str], source: Optional[str], q: Optional[str]
    ) -> list:
        """Translate the optional list/export filters into SQL conditions.

        ``plan`` and ``source`` match case-insensitively. ``q`` is a
        case-insensitive substring search over organization name, contact email
        and invoice number. Requires ``Account`` to be joined in the query.
        """
        conditions = []
        if plan:
            # Stored values are lower-cased before use elsewhere in this module,
            # so they may not be stored lowercase; compare case-insensitively.
            conditions.append(func.lower(PlatformPayment.plan) == plan.lower())
        if source:
            conditions.append(func.lower(PlatformPayment.source) == source.lower())
        term = (q or "").strip()
        if term:
            pattern = self._like_pattern(term.lower())
            conditions.append(
                or_(
                    func.lower(Account.organization_name).like(pattern, escape="\\"),
                    func.lower(Account.contact_email).like(pattern, escape="\\"),
                    func.lower(PlatformPayment.invoice_no).like(pattern, escape="\\"),
                )
            )
        return conditions

    async def _fetch_payments(
        self, conditions: list, *, offset: int, limit: int
    ) -> list[PlatformPayment]:
        """Fetch one page of filtered payments, newest first, with accounts loaded."""
        stmt = select(PlatformPayment).join(Account, PlatformPayment.account_id == Account.id)
        if conditions:
            stmt = stmt.where(*conditions)
        stmt = (
            stmt.options(selectinload(PlatformPayment.account))
            # ``id`` breaks ties so offset pagination is deterministic.
            .order_by(
                PlatformPayment.payment_date.desc(),
                PlatformPayment.created_at.desc(),
                PlatformPayment.id.desc(),
            )
            .offset(offset)
            .limit(limit)
        )
        result = await self.db.execute(stmt)
        return list(result.scalars().all())

    async def list_payments(
        self,
        *,
        plan: Optional[str] = None,
        source: Optional[str] = None,
        q: Optional[str] = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[dict], int]:
        """Return one page of serialized payments plus the total matching count.

        ``page`` is 1-based; values below 1 are treated as 1, and a negative
        ``page_size`` is treated as 0 (avoids negative OFFSET/LIMIT errors).
        """
        page = max(page, 1)
        page_size = max(page_size, 0)
        conditions = self._build_filter(plan, source, q)

        count_stmt = select(func.count(PlatformPayment.id)).join(
            Account, PlatformPayment.account_id == Account.id
        )
        if conditions:
            count_stmt = count_stmt.where(*conditions)
        count_result = await self.db.execute(count_stmt)
        total = count_result.scalar_one()

        payments = await self._fetch_payments(
            conditions, offset=(page - 1) * page_size, limit=page_size
        )
        return [self._serialize(p) for p in payments], total

    # --------------------------------------------------------------- export

    @staticmethod
    def _csv_row(item: dict) -> str:
        """Render one serialized payment as a newline-terminated CSV row."""
        cells = [
            _safe_csv_cell(item["organization"] or ""),
            _safe_csv_cell(item["email"] or ""),
            _safe_csv_cell(item["plan"]),
            # A formatted number cannot start a formula. Quoting it without the
            # injection prefix keeps negative amounts numeric (no "'-10.00").
            f'"{item["amount"]:.2f}"',
            _safe_csv_cell(item["method"] or ""),
            _safe_csv_cell(item["invoice_no"] or ""),
            _safe_csv_cell(item["source"] or ""),
            _safe_csv_cell(str(item["payment_date"])),
        ]
        return ",".join(cells) + "\n"

    async def export_csv(
        self,
        *,
        plan: Optional[str] = None,
        source: Optional[str] = None,
        q: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Yield a CSV export (header first) of all payments matching the filters.

        Rows are fetched in batches of ``_EXPORT_BATCH_SIZE`` rather than in one
        capped query, so large exports are neither truncated nor fully buffered.
        Offset batching means concurrent inserts may shift rows between batches.
        """
        conditions = self._build_filter(plan, source, q)
        batch_size = self._EXPORT_BATCH_SIZE
        offset = 0
        # Run the first query before yielding, so a DB error surfaces before any output.
        batch = await self._fetch_payments(conditions, offset=offset, limit=batch_size)
        yield "Organization,Email,Plan,Amount (CAD),Method,Invoice,Source,Date" + "\n"
        while batch:
            for payment in batch:
                yield self._csv_row(self._serialize(payment))
            if len(batch) < batch_size:
                break
            offset += batch_size
            batch = await self._fetch_payments(conditions, offset=offset, limit=batch_size)

    # ---------------------------------------------------------- single rows

    async def record_payment(self, *, payload: dict) -> PlatformPayment:
        """Insert a payment built from *payload* and return it refreshed.

        Flushes so DB-generated values are loaded, but does not commit.
        """
        db = self.db
        payment = PlatformPayment(**payload)
        db.add(payment)
        await db.flush()
        await db.refresh(payment)
        return payment

    async def get_payment_with_account(self, payment_id: UUID) -> tuple[PlatformPayment, Account]:
        """Load a payment and its account.

        Raises:
            NotFoundError: if no payment has ``payment_id``.
        """
        stmt = (
            select(PlatformPayment)
            .options(selectinload(PlatformPayment.account))
            .where(PlatformPayment.id == payment_id)
        )
        result = await self.db.execute(stmt)
        payment = result.scalar_one_or_none()
        if payment is None:
            raise NotFoundError("Payment not found")
        return payment, payment.account

    # -------------------------------------------------------------- invoice

    def generate_invoice_pdf(
        self,
        payment: PlatformPayment,
        account: Account,
        trial_ends_at: Optional[date] = None,
    ) -> bytes:
        """Render a one-page A4 invoice for *payment* and return the PDF bytes.

        Args:
            payment: the payment being invoiced. ``status`` selects the stamp:
                PAID, PENDING or FAILED, with unknown or empty falling back to PAID.
            account: the billed account (organization name and contact email).
            trial_ends_at: when given, also show a 14-day free-trial notice with
                this date as the first charge.

        All dynamic text goes through ``_pdf_safe`` because fpdf2's built-in
        Helvetica font only supports latin-1.
        """
        from fpdf import FPDF

        def fmt_date(d) -> str:
            # "March 5, 2025": d.day avoids the zero padding of %d.
            return f"{d.strftime('%B')} {d.day}, {d.year}"

        pdf = FPDF(orientation="P", unit="mm", format="A4")
        pdf.add_page()
        pdf.set_auto_page_break(auto=True, margin=15)

        # Header
        pdf.set_font("Helvetica", "B", 22)
        pdf.cell(0, 10, "INDELIS", ln=True)
        pdf.set_font("Helvetica", size=10)
        pdf.set_text_color(100, 100, 100)
        pdf.cell(0, 5, "Cemetery Management Software", ln=True)
        pdf.cell(0, 5, "hello@indelis.com   Ottawa, ON, Canada", ln=True)
        pdf.set_text_color(0, 0, 0)
        pdf.ln(4)
        pdf.set_draw_color(220, 220, 220)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(5)
        body_top = pdf.get_y()

        # Status stamp: right-aligned large text drawn back up beside the header.
        status = (payment.status or "paid").lower()
        stamp_text = {"paid": "PAID", "pending": "PENDING", "failed": "FAILED"}.get(status, "PAID")
        stamp_colors = {
            "PAID": (34, 197, 94),
            "PENDING": (156, 163, 175),
            "FAILED": (239, 68, 68),
        }
        r, g, b = stamp_colors.get(stamp_text, (34, 197, 94))
        pdf.set_text_color(r, g, b)
        pdf.set_font("Helvetica", "B", 28)
        pdf.set_y(20)
        pdf.cell(0, 10, stamp_text, align="R", ln=True)
        pdf.set_text_color(0, 0, 0)
        # Resume below the header divider. Continuing just under the stamp
        # would draw the invoice details across the divider line.
        pdf.set_y(max(body_top, pdf.get_y() + 2))

        # Invoice header
        pdf.set_font("Helvetica", "B", 14)
        pdf.cell(0, 8, "INVOICE", ln=True)
        pdf.set_font("Helvetica", size=10)
        pdf.ln(2)

        payment_date = payment.payment_date
        payment_date_str = _pdf_safe(
            fmt_date(payment_date) if hasattr(payment_date, "strftime") else str(payment_date)
        )
        trial_str = _pdf_safe(fmt_date(trial_ends_at)) if trial_ends_at else None

        def label_value(label: str, value: str):
            pdf.set_font("Helvetica", "B", 10)
            pdf.cell(50, 6, label, ln=False)
            pdf.set_font("Helvetica", size=10)
            pdf.cell(0, 6, value, ln=True)

        invoice_no = _pdf_safe(payment.invoice_no or str(payment.id)[:8].upper())
        label_value("Invoice No:", invoice_no)
        # NOTE(review): "Date" and "Payment Date" both show payment_date; confirm
        # whether the invoice issue date should come from another field.
        label_value("Date:", payment_date_str)
        label_value("Payment Date:", payment_date_str)

        if trial_str is not None:
            label_value("Trial Period:", "14 days free")
            label_value("First Charge:", trial_str)

        pdf.ln(4)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(5)

        # Trial notice banner
        if trial_str is not None:
            pdf.set_fill_color(236, 253, 245)
            pdf.set_text_color(5, 122, 85)
            pdf.set_font("Helvetica", "B", 10)
            pdf.cell(0, 8, f"  14-Day Free Trial  -  No charge until {trial_str}", fill=True, ln=True)
            pdf.set_font("Helvetica", size=9)
            pdf.cell(0, 6, "  Your card has been saved. You will be charged automatically after the trial ends.", ln=True)
            pdf.set_text_color(0, 0, 0)
            pdf.ln(3)

        # Bill To
        pdf.set_font("Helvetica", "B", 11)
        pdf.cell(0, 6, "BILL TO", ln=True)
        pdf.set_font("Helvetica", size=10)
        pdf.cell(0, 6, _pdf_safe(account.organization_name or "-"), ln=True)
        pdf.cell(0, 6, _pdf_safe(account.contact_email or "-"), ln=True)
        pdf.ln(4)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(5)

        # Line item table header
        pdf.set_fill_color(245, 245, 245)
        pdf.set_font("Helvetica", "B", 10)
        pdf.cell(140, 7, "Description", border=0, fill=True)
        pdf.cell(0, 7, "Amount (CAD)", border=0, fill=True, align="R", ln=True)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(1)

        # Line item. The fallback is an ASCII "-": the previous em dash is not
        # latin-1 and made the built-in font raise.
        plan_label = _pdf_safe(PLAN_LABELS.get((payment.plan or "").lower(), payment.plan or "-"))
        amount_str = f"${float(payment.amount_cad):,.2f}"
        pdf.set_font("Helvetica", size=10)
        pdf.cell(140, 7, f"{plan_label} Plan - Monthly Subscription")
        pdf.set_font("Helvetica", "B", 10)
        pdf.cell(0, 7, amount_str, align="R", ln=True)
        pdf.ln(1)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(3)

        # Total
        pdf.set_font("Helvetica", "B", 11)
        pdf.cell(140, 7, "Total")
        pdf.cell(0, 7, amount_str, align="R", ln=True)
        pdf.set_font("Helvetica", size=10)
        pdf.cell(0, 6, _pdf_safe(f"Payment Method: {payment.method or '-'}"), ln=True)
        pdf.ln(6)
        pdf.line(10, pdf.get_y(), 200, pdf.get_y())
        pdf.ln(5)

        # Footer
        pdf.set_font("Helvetica", size=9)
        pdf.set_text_color(120, 120, 120)
        pdf.cell(0, 5, "Questions? Contact hello@indelis.com  +1 800-555-0100", ln=True, align="C")

        return bytes(pdf.output())

    # -------------------------------------------------------- serialization

    def _serialize(self, p: PlatformPayment) -> dict:
        """Convert a payment (with its account loaded) into an API/CSV-friendly dict.

        A missing account yields ``"—"`` for organization and email. A missing
        plan yields ``"—"``, matching the invoice. Previously it was mislabelled
        "Starter", which made the ``"—"`` fallback unreachable.
        """
        acc = p.account
        plan_key = (p.plan or "").lower()
        return {
            "id": str(p.id),
            "account_id": str(p.account_id),
            "organization": acc.organization_name if acc else "—",
            "email": acc.contact_email if acc else "—",
            "plan": PLAN_LABELS.get(plan_key, p.plan.capitalize() if p.plan else "—"),
            "amount": float(p.amount_cad),
            "method": p.method,
            "invoice_no": p.invoice_no,
            "source": p.source,
            "status": p.status,
            "payment_date": p.payment_date.isoformat(),
            "created_at": p.created_at.isoformat(),
        }
