from typing import List, Optional, Tuple
from sqlalchemy import select, func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import HTTPException, status

from app.models.product import Product
from app.models.violation import Violation
from app.models.scraping_result import ScrapingResult
from app.schemas.product import ProductCreate, ProductUpdate
from app.services.pricing_service import calculate_pack_prices


class ProductService:
    """CRUD and bulk-import operations for ``Product`` rows.

    Violation counts are not stored by this service; they are computed on read
    by counting ``Violation`` rows whose ``product_name`` equals the product's
    name, and attached to the instance as a ``violation_count`` attribute.
    """

    # Pack sizes whose prices are derived from the MSP. Each size N maps to the
    # ``price_N_pack`` column and must be a key of calculate_pack_prices()'s result.
    _PACK_SIZES = (1, 2, 3, 4, 5, 6, 12)

    # Escape character for LIKE patterns. "/" (as used by SQLAlchemy's
    # ``autoescape``) avoids dialect differences in backslash handling.
    _LIKE_ESCAPE = "/"

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _pack_price_fields(msp) -> dict:
        """Return ``{"price_N_pack": price}`` for every pack size, derived from *msp*."""
        pack_prices = calculate_pack_prices(float(msp))
        return {
            f"price_{size}_pack": pack_prices[size]
            for size in ProductService._PACK_SIZES
        }

    @staticmethod
    def _build_product(product_in: ProductCreate) -> Product:
        """Construct (without adding to a session) a Product with pack prices filled in."""
        return Product(
            reference_id=product_in.reference_id,
            product_name=product_in.product_name,
            barcode=product_in.barcode,
            msp=product_in.msp,
            status=product_in.status,
            **ProductService._pack_price_fields(product_in.msp),
        )

    @staticmethod
    def _escape_like(term: str) -> str:
        """Escape LIKE wildcards in *term* so it is matched literally."""
        esc = ProductService._LIKE_ESCAPE
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    async def _count_violations_for_name(db: AsyncSession, product_name) -> int:
        """Count Violation rows whose ``product_name`` equals *product_name*.

        A ``None`` name is rendered by SQLAlchemy as ``IS NULL``.
        """
        count_result = await db.execute(
            select(func.count(Violation.id)).where(
                Violation.product_name == product_name
            )
        )
        return count_result.scalar() or 0

    @staticmethod
    async def _count_violations_by_names(
        db: AsyncSession, product_names: List[Optional[str]]
    ) -> dict:
        """Return ``{product_name: violation_count}`` for *product_names* in one query.

        Names with no violations are absent from the result. ``None`` names are
        ignored here because ``IN`` never matches NULL; callers handle them.
        """
        names = list({name for name in product_names if name is not None})
        if not names:
            return {}
        result = await db.execute(
            select(Violation.product_name, func.count(Violation.id))
            .where(Violation.product_name.in_(names))
            .group_by(Violation.product_name)
        )
        return {name: count for name, count in result.all()}

    @staticmethod
    async def _get_product_or_404(db: AsyncSession, product_id: int) -> Product:
        """Load a product by id, or raise ``HTTPException(404)`` if it does not exist."""
        result = await db.execute(select(Product).where(Product.id == product_id))
        db_product = result.scalars().first()
        if not db_product:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Product not found",
            )
        return db_product

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    async def get_violation_count(db: AsyncSession, product_id: int) -> int:
        """Count violations matching the name of the product with *product_id*.

        Returns 0 when no such product exists.
        """
        product_result = await db.execute(select(Product).where(Product.id == product_id))
        product = product_result.scalars().first()
        if not product:
            return 0
        return await ProductService._count_violations_for_name(db, product.product_name)

    @staticmethod
    async def enrich_product_with_violations(db: AsyncSession, product: Product) -> Product:
        """Set ``violation_count`` on *product* (mutating it) and return the same object."""
        count = await ProductService._count_violations_for_name(db, product.product_name)
        product.violation_count = count  # type: ignore
        return product

    @staticmethod
    async def get_products(
        db: AsyncSession,
        page: int = 1,
        limit: int = 10,
        sort_by: str = "product_name",
        search: Optional[str] = None,
    ) -> Tuple[List[Product], int]:
        """Return one page of products and the total number of matching products.

        Args:
            db: Async session used for all queries.
            page: 1-based page number; values below 1 are treated as 1.
            limit: Page size.
            sort_by: ``"msp"``, ``"product_name"`` or ``"last_scraped_date"``;
                any other value falls back to ``"product_name"``.
            search: Optional case-insensitive substring filter on the product
                name. ``%`` and ``_`` in it are matched literally.

        Returns:
            ``(products, total)``; every product carries ``violation_count``.
        """
        # A negative OFFSET is an error on some databases (e.g. PostgreSQL).
        page = max(page, 1)

        query = select(Product)
        if search:
            pattern = f"%{ProductService._escape_like(search)}%"
            query = query.where(
                Product.product_name.ilike(pattern, escape=ProductService._LIKE_ESCAPE)
            )

        # Count the filtered query before ordering: ORDER BY is wasted work here.
        total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0

        sort_columns = {
            "msp": Product.msp,
            "product_name": Product.product_name,
            # Placeholder: updated_at stands in for last_scraped_date for now.
            "last_scraped_date": Product.updated_at.desc(),
        }
        # Product.id breaks ties so rows with equal sort keys don't move between pages.
        query = query.order_by(sort_columns.get(sort_by, Product.product_name), Product.id)
        query = query.offset((page - 1) * limit).limit(limit)

        result = await db.execute(query)
        products = list(result.scalars().all())

        # One grouped query for the whole page instead of one COUNT per product.
        counts = await ProductService._count_violations_by_names(
            db, [product.product_name for product in products]
        )
        for product in products:
            if product.product_name is None:
                # IN cannot match NULL; keep the per-product IS NULL semantics.
                await ProductService.enrich_product_with_violations(db, product)
            else:
                product.violation_count = counts.get(product.product_name, 0)  # type: ignore

        return products, total

    @staticmethod
    async def create_product(db: AsyncSession, product_in: ProductCreate) -> Product:
        """Create and commit a product with pack prices auto-calculated from its MSP.

        The returned product also carries ``violation_count``: violations are
        matched by name, so some may already exist for a newly created product.
        """
        db_product = ProductService._build_product(product_in)
        db.add(db_product)
        await db.commit()
        await db.refresh(db_product)
        return await ProductService.enrich_product_with_violations(db, db_product)

    @staticmethod
    async def update_product(
        db: AsyncSession, product_id: int, product_in: ProductUpdate
    ) -> Product:
        """Apply the fields explicitly set on *product_in*; recalculate pack prices if MSP changes.

        Raises:
            HTTPException: 404 if the product does not exist.
        """
        product = await ProductService._get_product_or_404(db, product_id)

        update_data = product_in.model_dump(exclude_unset=True)

        # Pack prices are derived from MSP, so they must follow any MSP change.
        if "msp" in update_data:
            update_data.update(ProductService._pack_price_fields(update_data["msp"]))

        for field, value in update_data.items():
            setattr(product, field, value)

        await db.commit()
        await db.refresh(product)
        # Count after the update: a product_name change changes which violations match.
        return await ProductService.enrich_product_with_violations(db, product)

    @staticmethod
    async def delete_product(db: AsyncSession, product_id: int) -> None:
        """Delete the product with *product_id* and commit.

        Raises:
            HTTPException: 404 if the product does not exist.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        await db.delete(db_product)
        await db.commit()

    @staticmethod
    async def get_product_by_id(db: AsyncSession, product_id: int) -> Product:
        """Return the product with *product_id*, with ``violation_count`` attached.

        Raises:
            HTTPException: 404 if the product does not exist.
        """
        db_product = await ProductService._get_product_or_404(db, product_id)
        return await ProductService.enrich_product_with_violations(db, db_product)

    @staticmethod
    async def bulk_create_products(
        db: AsyncSession, products_list: List[ProductCreate]
    ) -> dict:
        """
        Bulk create multiple products. Applies the same logic as create_product:
        - Auto-calculates pack prices based on MSP
        - Skips rows whose barcode already exists in the database or earlier
          in the same input (rows without a barcode are never duplicates)
        - Commits all created rows in a single transaction

        Returns a summary dict with:
        - total_processed: Number of products in input list
        - successful: Number of products successfully created
        - failed: Number of products that failed (all pending rows count as
          failed if the final commit fails)
        - skipped_duplicates: Number of duplicate barcodes skipped
        - errors: List of errors encountered
        - message: Human-readable summary
        """
        total_processed = len(products_list)
        successful = 0
        failed = 0
        skipped_duplicates = 0
        errors = []

        existing_barcodes = set()
        if products_list:
            # Fetch all stored barcodes in one query. NULLs are dropped so that
            # products without a barcode are never reported as duplicates.
            existing_barcodes_result = await db.execute(select(Product.barcode))
            existing_barcodes = {
                barcode
                for barcode in existing_barcodes_result.scalars().all()
                if barcode is not None
            }

        for idx, product_in in enumerate(products_list, 1):
            barcode = product_in.barcode
            if barcode is not None and barcode in existing_barcodes:
                skipped_duplicates += 1
                errors.append({
                    "row": idx,
                    "product": product_in.product_name,
                    "error": f"Duplicate barcode: {product_in.barcode} already exists"
                })
                continue

            try:
                db_product = ProductService._build_product(product_in)
                db.add(db_product)
            except ValueError as e:
                failed += 1
                errors.append({
                    "row": idx,
                    "product": product_in.product_name,
                    "error": f"Validation error: {str(e)}"
                })
            except Exception as e:
                failed += 1
                errors.append({
                    "row": idx,
                    "product": product_in.product_name,
                    "error": f"Database error: {str(e)}"
                })
            else:
                # Track barcodes added in this batch so later rows are deduplicated too.
                if barcode is not None:
                    existing_barcodes.add(barcode)
                successful += 1

        # Commit all successful inserts at once; on failure nothing is persisted.
        if successful > 0:
            try:
                await db.commit()
            except Exception as e:
                await db.rollback()
                failed += successful
                successful = 0
                errors.append({
                    "error": f"Failed to commit all products: {str(e)}"
                })

        return {
            "total_processed": total_processed,
            "successful": successful,
            "failed": failed,
            "skipped_duplicates": skipped_duplicates,
            "errors": errors,
            "message": f"Bulk upload completed: {successful} created, {failed} failed, {skipped_duplicates} duplicates skipped"
        }
