# -*- coding: utf-8 -*-
"""
Batch processing module for AutoQM application.
Handles batching of passages for premium members (정회원).
"""

import logging
from typing import List, Dict, Any, Optional, Callable
from dataclasses import dataclass


# Module-level logger, named after this module, used by the batch classes below.
logger = logging.getLogger(__name__)


@dataclass
class PassageData:
    """One passage awaiting processing, together with its per-passage settings.

    ``유형`` is the question type and ``난이도`` the difficulty (defaults to
    "Normal").
    """
    title: str
    passage: str
    유형: str
    난이도: str = "Normal"
    option: Optional[Dict[str, Any]] = None
    # Transformation settings (paraphrase, difficulty, length)
    transformation_data: Optional[Dict[str, Any]] = None
    # Excel row number for updating
    row_num: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Build the per-passage dictionary used in an API request.

        ``row_num`` is deliberately not part of the payload. An unset (or
        empty) ``option`` / ``transformation_data`` is emitted as an empty
        dict; the transformation settings appear under the key
        ``'transformation'``.
        """
        option = self.option if self.option else {}
        transformation = self.transformation_data if self.transformation_data else {}

        # Pairs are listed in the order the keys must appear in the payload.
        request_fields = (
            ('title', self.title),
            ('passage', self.passage),
            ('유형', self.유형),
            ('난이도', self.난이도),
            ('option', option),
            ('transformation', transformation),
        )
        return dict(request_fields)


@dataclass
class BatchResult:
    """Result from processing a single passage in batch.

    Instances are built by ``BatchProcessor.parse_batch_response`` from one
    entry of the response's ``results`` list. Field order is part of the
    positional constructor signature generated by ``@dataclass``.
    """
    title: str  # Title taken from the response entry
    result: str  # 'result' value from the response entry
    credits: int  # 'credits' value from the response entry -- presumably credits charged; confirm against the API
    retries: int  # 'retries' value from the response entry
    status: str  # 'success' or 'error'
    row_num: Optional[int] = None  # Excel row number copied from the matching PassageData, when one is found
    premadedata_flag: bool = False  # True if EBS premade data
    paraphrased_flag: bool = False  # 'paraphrased_flag' value from the response entry
    paraphrased_passage: str = ""  # Filled from 'paraphrased_passage', falling back to 'passage', in the response entry


class BatchProcessor:
    """
    Handles batching of passages for efficient processing.
    Only available for premium members (정회원).

    Passages are collected in an internal "current batch". The batch can be
    turned into an API request payload with ``create_batch_request`` and the
    API response can be mapped back onto the passages with
    ``parse_batch_response``. Both methods also accept an explicit ``batch``
    so they keep working for a batch that was already removed with
    ``pop_batch``.
    """

    def __init__(
        self,
        batch_size: int = 5,
        is_premium_member: bool = False
    ):
        """
        Initialize batch processor.

        Args:
            batch_size: Maximum number of passages per batch (default: 5)
            is_premium_member: Whether user is a premium member
        """
        self.batch_size = batch_size
        self.is_premium_member = is_premium_member
        self._current_batch: List[PassageData] = []

    def can_use_batch_processing(self) -> bool:
        """Check if batch processing is available (premium members only)."""
        return self.is_premium_member

    def add_passage(self, passage_data: PassageData) -> bool:
        """
        Add a passage to the current batch.

        For non-premium members the passage is NOT added; a warning is logged
        and False is returned.

        Args:
            passage_data: Passage to add

        Returns:
            True if batch is full and ready to process
        """
        if not self.is_premium_member:
            logger.warning("Batch processing not available for non-premium members")
            return False

        self._current_batch.append(passage_data)
        # Lazy %-style arguments: the message is only formatted when DEBUG is enabled.
        logger.debug(
            "Added passage to batch: %s (%s/%s)",
            passage_data.title,
            len(self._current_batch),
            self.batch_size,
        )

        return self.is_batch_full()

    def get_current_batch(self) -> List[PassageData]:
        """Get a shallow copy of the current batch without clearing it."""
        return self._current_batch.copy()

    def clear_batch(self):
        """Clear the current batch."""
        logger.debug("Clearing batch with %s items", len(self._current_batch))
        self._current_batch.clear()

    def pop_batch(self) -> List[PassageData]:
        """
        Get current batch and clear it.

        Returns:
            List of passages in batch (a new list; the internal batch is
            emptied afterwards)
        """
        batch = self._current_batch.copy()
        self.clear_batch()
        return batch

    def is_batch_full(self) -> bool:
        """Check if current batch holds at least ``batch_size`` passages."""
        return len(self._current_batch) >= self.batch_size

    def is_batch_empty(self) -> bool:
        """Check if current batch is empty."""
        return not self._current_batch

    def get_batch_size(self) -> int:
        """Get the number of passages currently in the batch.

        Note this is the current fill level, not the configured maximum
        (``self.batch_size``).
        """
        return len(self._current_batch)

    def should_process_batch(self, force: bool = False) -> bool:
        """
        Determine if batch should be processed now.

        Args:
            force: Force processing even if batch not full

        Returns:
            True if should process now. Always False for non-premium members
            and for an empty batch, even when ``force`` is set.
        """
        if not self.is_premium_member:
            return False

        if self.is_batch_empty():
            return False

        # Process if full or forced
        return self.is_batch_full() or force

    def create_batch_request(
        self,
        username: str,
        user_role: Optional[List[str]] = None,
        os_name: str = "macOS",
        app_version: str = "v2.4.4",
        batch: Optional[List[PassageData]] = None
    ) -> Dict[str, Any]:
        """
        Create API request payload for batch processing.

        Args:
            username: User's username
            user_role: User's roles (e.g., ['um_custom_role_1']); None is sent
                as an empty list
            os_name: Operating system
            app_version: Application version
            batch: Passages to serialize. Defaults to the current batch; pass
                a list obtained from ``pop_batch`` to build a request for a
                batch that is no longer held by this processor.

        Returns:
            Dictionary ready for API request
        """
        source_batch = self._current_batch if batch is None else batch

        return {
            'passages': [p.to_dict() for p in source_batch],
            'username': username,
            'user_role': user_role or [],
            'os': os_name,
            'app_version': app_version
        }

    def parse_batch_response(
        self,
        response_data: Dict[str, Any],
        batch: Optional[List[PassageData]] = None
    ) -> List[BatchResult]:
        """
        Parse API response from batch processing.

        Results are matched to passages by position: the i-th result gets the
        ``row_num`` of the i-th passage. Results without a matching passage
        get ``row_num=None``.

        Args:
            response_data: JSON response from /process_batch endpoint
            batch: Passages the response belongs to. Defaults to the current
                batch; pass the list that was sent if the current batch has
                been cleared or popped in the meantime.

        Returns:
            List of BatchResult objects (empty when the response has no
            results or ``results`` is null)
        """
        source_batch = self._current_batch if batch is None else batch

        # `.get('results', [])` would still hand back None for a JSON null,
        # so a missing key and an explicit null are both normalized here.
        results_list = response_data.get('results') or []

        if len(results_list) != len(source_batch):
            # Positional matching cannot be trusted when the counts differ.
            logger.warning(
                "Batch response has %s results for %s passages; "
                "row numbers are matched by position",
                len(results_list),
                len(source_batch),
            )

        results = []
        for i, result_data in enumerate(results_list):
            # Find corresponding passage data
            passage_data = source_batch[i] if i < len(source_batch) else None

            results.append(BatchResult(
                title=result_data.get('title', f'passage_{i+1}'),
                result=result_data.get('result', 'Error'),
                credits=result_data.get('credits', 0),
                retries=result_data.get('retries', 1),
                status=result_data.get('status', 'error'),
                row_num=passage_data.row_num if passage_data else None,
                premadedata_flag=result_data.get('premadedata_flag', False),
                paraphrased_flag=result_data.get('paraphrased_flag', False),
                # Fall back to the plain 'passage' field when no paraphrase is returned.
                paraphrased_passage=result_data.get('paraphrased_passage', result_data.get('passage', ""))
            ))

        return results


class BatchQueue:
    """
    First-in-first-out queue of batches built on top of a BatchProcessor.

    Premium passages accumulate in the processor until a batch fills up, at
    which point the batch is moved to the pending list. Passages from
    non-premium users each become their own single-item pending batch.
    """

    def __init__(
        self,
        batch_size: int = 5,
        is_premium_member: bool = False,
        auto_process_callback: Optional[Callable[[List[PassageData]], None]] = None
    ):
        """
        Set up an empty queue.

        Args:
            batch_size: Size of each batch
            is_premium_member: Whether user is premium member
            auto_process_callback: Called with the batch each time a batch
                becomes full
        """
        self._pending_batches: List[List[PassageData]] = []
        self._processor = BatchProcessor(batch_size, is_premium_member)

        self.auto_process_callback = auto_process_callback
        self.is_premium_member = is_premium_member
        self.batch_size = batch_size

    def add_passage(self, passage_data: PassageData):
        """
        Queue a passage, closing the current batch once it is full.

        Args:
            passage_data: Passage to add
        """
        if not self.is_premium_member:
            # Non-premium users: one passage per pending "batch", processed individually.
            self._pending_batches.append([passage_data])
            return

        if not self._processor.add_passage(passage_data):
            # Batch still has room; nothing more to do yet.
            return

        # The batch just filled up: hand it over to the pending list.
        full_batch = self._processor.pop_batch()
        self._pending_batches.append(full_batch)

        logger.info(f"Batch full, added to pending queue. Total pending batches: {len(self._pending_batches)}")

        # The batch stays in the pending list; the callback is only notified.
        callback = self.auto_process_callback
        if callback:
            callback(full_batch)

    def flush_current_batch(self):
        """Move the partially filled current batch, if any, to the pending list."""
        if self._processor.is_batch_empty():
            return

        batch = self._processor.pop_batch()
        self._pending_batches.append(batch)
        logger.info(f"Flushed incomplete batch with {len(batch)} items")

    def get_next_batch(self) -> Optional[List[PassageData]]:
        """Remove and return the oldest pending batch, or None when none is pending."""
        return self._pending_batches.pop(0) if self._pending_batches else None

    def get_pending_count(self) -> int:
        """Return how many batches are pending (the current batch is not counted)."""
        return len(self._pending_batches)

    def get_total_pending_passages(self) -> int:
        """Return the number of passages in all pending batches plus the current batch."""
        queued = sum(map(len, self._pending_batches))
        return queued + self._processor.get_batch_size()

    def clear_all(self):
        """Drop every pending batch as well as the current batch."""
        self._pending_batches.clear()
        self._processor.clear_batch()
        logger.info("Cleared all pending batches")


# Helper functions for integration with existing code

def create_passage_data_from_excel(
    row_num: int,
    title: str,
    passage: str,
    유형: str,
    난이도: str = "Normal",
    option: Optional[Dict] = None,
    transformation_data: Optional[Dict] = None
) -> PassageData:
    """
    Build a PassageData for one Excel row.

    Args:
        row_num: Excel row number
        title: Passage title/number
        passage: Passage text
        유형: Question type
        난이도: Difficulty (Normal/Hard)
        option: Additional options
        transformation_data: Transformation settings (paraphrase_enabled, difficulty_level, length_level)

    Returns:
        PassageData object carrying the given values unchanged
    """
    # The three required fields lead the PassageData field list, so they are
    # passed positionally; everything with a default is passed by keyword.
    return PassageData(
        title,
        passage,
        유형,
        난이도=난이도,
        option=option,
        transformation_data=transformation_data,
        row_num=row_num,
    )
