import asyncio
import json
import os
import sys
import tempfile
import traceback
from contextlib import closing
from typing import List

import bcrypt
import mysql.connector
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from mysql.connector import Error
from pydantic import BaseModel, EmailStr

# ---------------------------------------------------------------------------
# Database configuration
# ---------------------------------------------------------------------------
# Connection settings can be overridden through environment variables so real
# credentials need not live in source control. The fallbacks are the original
# development values.
# NOTE(review): the default password is a development credential; set
# OMR_DB_PASSWORD (and a non-root OMR_DB_USER) in any shared deployment.
DB_CONFIG = {
    "host": os.environ.get("OMR_DB_HOST", "localhost"),
    "user": os.environ.get("OMR_DB_USER", "root"),
    "password": os.environ.get("OMR_DB_PASSWORD", "root@123"),
    # mysql.connector expects an integer port.
    "port": int(os.environ.get("OMR_DB_PORT", "3306")),
}
DB_NAME = os.environ.get("OMR_DB_NAME", "omr_grader_db")


def hash_password(plain: str) -> str:
    """Hash a plaintext password with bcrypt using a freshly generated salt.

    Returns the bcrypt hash decoded to ``str`` so it can be stored in a text
    column.
    """
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(plain.encode(), salt)
    return hashed_bytes.decode()


def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the stored bcrypt hash *hashed*."""
    candidate = plain.encode()
    stored = hashed.encode()
    return bcrypt.checkpw(candidate, stored)


def get_db_connection():
    """Open and return a new MySQL connection to the ``DB_NAME`` database.

    Every call creates a fresh connection; the caller must close it.
    """
    return mysql.connector.connect(database=DB_NAME, **DB_CONFIG)


def init_db():
    """Create the database and the ``users`` table if they don't exist.

    This is best-effort: a MySQL ``Error`` is printed to stderr and swallowed,
    so the server can still start when the database is unreachable. The
    connection and cursor are always closed, including when a statement fails.
    """
    try:
        # closing() guarantees close() on both objects even if execute() raises;
        # the cursor is closed before the connection.
        with closing(mysql.connector.connect(**DB_CONFIG)) as conn, closing(
            conn.cursor()
        ) as cursor:
            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{DB_NAME}`")
            cursor.execute(f"USE `{DB_NAME}`")
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id         INT AUTO_INCREMENT PRIMARY KEY,
                    firstname  VARCHAR(100)  NOT NULL,
                    lastname   VARCHAR(100)  NOT NULL,
                    email      VARCHAR(255)  NOT NULL UNIQUE,
                    password   VARCHAR(255)  NOT NULL,
                    username   VARCHAR(100)  NOT NULL UNIQUE,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.commit()
    except Error as e:
        print(f"Database init error: {e}", file=sys.stderr)


# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
# JSON body accepted by POST /addUser. Pydantic validates field types, and
# EmailStr requires `email` to be a syntactically valid address.
class AddUserRequest(BaseModel):
    firstname: str
    lastname: str
    email: EmailStr
    password: str  # plaintext as sent by the client; add_user hashes it before storing
    username: str


# JSON body accepted by POST /login.
class LoginRequest(BaseModel):
    username: str
    password: str  # plaintext; login checks it against the stored bcrypt hash


# omr_grader lives next to this file. Add this file's directory to the module
# search path (at the end) so the import below can find it when the server is
# started from another working directory.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_THIS_DIR)
import omr_grader

_APP_METADATA = {
    "title": "OMR Grader API",
    "description": "Backend API for scanning and grading OMR sheets against an answer key.",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)

# Cross-origin access is fully open so a Flutter web build (or any other
# browser client on a different origin) can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site issue credentialed requests; narrow allow_origins before deploying.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Create the database/table when the module is imported; init_db logs MySQL
# errors instead of raising, so the app is still created if the DB is down.
init_db()


async def save_upload_file(upload_file: UploadFile, destination_path: str):
    """Copy an uploaded file to *destination_path* in 1 MiB pieces.

    Reading from the upload is awaited; writing to the local file uses
    ordinary (blocking) file I/O. The destination is created or truncated.
    """
    chunk_size = 1024 * 1024
    with open(destination_path, "wb") as out_file:
        piece = await upload_file.read(chunk_size)
        while piece:
            out_file.write(piece)
            piece = await upload_file.read(chunk_size)


def _safe_upload_name(filename, default):
    """Return a client-supplied filename reduced to a bare name safe to join onto a directory.

    Upload names may contain directory parts ("../x.pdf", "/etc/x",
    "C:\\x.pdf"). ``os.path.join`` honours those and would write outside the
    temporary directory, so only the last component is kept. The extension is
    preserved because the grader detects the file type from it. Falls back to
    *default* when nothing usable remains.
    """
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return default
    return name


def _failed_sheet_result(sheet_name, exc):
    """Build the NDJSON row reported for a sheet that could not be saved or graded."""
    return {
        "sheet_name": sheet_name,
        "score": 0,
        "total": 0,
        "percent": 0.0,
        "blank": 0,
        "wrong_list": [],
        "answers": [],
        # Included with empty values so failed rows carry the same keys as
        # successful ones.
        "upn": "",
        "pupil_name": "",
        "name_confidence": 0.0,
        "error": f"Failed to grade sheet: {str(exc)}",
    }


def _grade_sheet_file(sheet_path, sheet_name, keys):
    """Scan, grade and OCR one saved sheet; return its result row as a dict.

    Synchronous: it is meant to run in a worker thread. Errors from scanning,
    key selection or grading propagate to the caller. Failures of the optional
    UPN/name extraction are logged to stderr and replaced with empty values.
    """
    # scan_sheet uses Path(file_path).suffix to detect if PDF or image
    answers, _pages = omr_grader.scan_sheet(sheet_path)

    # Pick the appropriate key matching this sheet, then grade against it
    key_name, key = omr_grader.pick_key(sheet_path, keys, None)
    res = omr_grader.grade_sheet(sheet_name, answers, key, key_name)

    # Extract handwritten details if model/tesseract is available
    try:
        res.upn = omr_grader.extract_upn(sheet_path)
    except Exception as upn_err:
        print(
            f"UPN extraction failed for {sheet_name}: {upn_err}",
            file=sys.stderr,
        )
        res.upn = ""

    try:
        res.name, res.name_conf = omr_grader.extract_name(sheet_path)
    except Exception as name_err:
        print(
            f"Name extraction failed for {sheet_name}: {name_err}",
            file=sys.stderr,
        )
        res.name, res.name_conf = "", 0.0

    # Build stats from graded result so total reflects the key size
    # when a key is present, or the detected count otherwise.
    score = 0
    blank = 0
    wrong_list = []
    answer_rows = []
    for g in res.graded:
        q = g.question
        marked = g.marked or ""
        correct = g.correct or ""

        if not marked:
            blank += 1
        elif marked == correct:
            score += 1
        elif correct:
            # Only count as wrong when there is a correct answer to compare to.
            wrong_list.append(q)

        answer_rows.append({"question": q, "marked": marked, "correct": correct})

    total = len(res.graded) if res.graded else len(answers)
    percent = round(100.0 * score / total, 1) if total > 0 else 0.0
    # Guard against an extractor that reports no confidence, which would
    # otherwise make round() fail and discard an already graded sheet.
    name_conf = round(res.name_conf, 1) if res.name_conf is not None else 0.0

    return {
        "sheet_name": sheet_name,
        "score": score,
        "total": total,
        "percent": percent,
        "blank": blank,
        "wrong_list": wrong_list,
        "answers": answer_rows,
        "upn": res.upn,
        "pupil_name": res.name,
        "name_confidence": name_conf,
    }


async def grade_sheets_generator(
    answer_key: UploadFile, answer_sheets: List[UploadFile]
):
    """Grade uploaded OMR sheets against an answer key, yielding NDJSON lines.

    Uploads are written to a temporary directory that is removed when the
    generator finishes. If the answer key cannot be saved or loaded, a single
    ``{"error": ...}`` line is yielded and grading stops. Otherwise one line
    is yielded per sheet, in upload order. A sheet that fails to save or grade
    yields a zero-score row with an ``"error"`` message rather than aborting
    the stream.

    The synchronous omr_grader work runs in the event loop's default executor
    so other requests are not blocked while sheets are processed. Sheets
    within one request are still processed one at a time.
    NOTE(review): concurrent requests may now run omr_grader in parallel
    threads; confirm omr_grader tolerates that.
    """
    loop = asyncio.get_running_loop()
    with tempfile.TemporaryDirectory() as tmpdir:
        # 1. Save and load the answer key. The on-disk name keeps the original
        #    extension, which fitz/pdf_to_images relies on.
        key_path = os.path.join(
            tmpdir, _safe_upload_name(answer_key.filename, "Answers_keys.pdf")
        )
        try:
            await save_upload_file(answer_key, key_path)
            keys = await loop.run_in_executor(
                None, omr_grader.load_answer_keys, key_path
            )
        except Exception as e:
            err_msg = f"Failed to load answer keys: {str(e)}"
            print(err_msg, file=sys.stderr)
            yield json.dumps({"error": err_msg}) + "\n"
            return

        # 2. Process each sheet sequentially
        for sheet_file in answer_sheets:
            # Report the name the client sent, but only ever write to a
            # sanitised path inside tmpdir.
            sheet_filename = sheet_file.filename or "sheet.pdf"
            sheet_path = os.path.join(
                tmpdir, _safe_upload_name(sheet_file.filename, "sheet.pdf")
            )
            try:
                await save_upload_file(sheet_file, sheet_path)
                result_data = await loop.run_in_executor(
                    None, _grade_sheet_file, sheet_path, sheet_filename, keys
                )
            except Exception as e:
                traceback.print_exc()
                result_data = _failed_sheet_result(sheet_filename, e)

            # Yield NDJSON row (JSON string followed by newline)
            yield json.dumps(result_data) + "\n"


@app.post("/fileUpload")
async def file_upload(
    answer_key: UploadFile = File(...), answer_sheets: List[UploadFile] = File(...)
):
    """Accept a multipart upload of one answer key and several answer sheets.

    Responds with an ``application/x-ndjson`` stream of per-sheet grading
    results, or HTTP 400 when the sheet list is empty.
    """
    if answer_sheets:
        result_stream = grade_sheets_generator(answer_key, answer_sheets)
        return StreamingResponse(result_stream, media_type="application/x-ndjson")
    raise HTTPException(status_code=400, detail="No answer sheets uploaded.")


@app.post("/addUser", status_code=201)
async def add_user(body: AddUserRequest):
    """Register a new user and return the new row id.

    The password is stored as a bcrypt hash. Responses:

    - 409 when the email or username is already taken (MySQL error 1062).
    - 500 for any other database error.

    The connection and cursor are closed on every path.
    """
    hashed_pw = hash_password(body.password)
    try:
        # closing() releases the cursor, then the connection, even if the
        # INSERT or COMMIT raises.
        with closing(get_db_connection()) as conn, closing(conn.cursor()) as cursor:
            cursor.execute(
                """INSERT INTO users (firstname, lastname, email, password, username)
                   VALUES (%s, %s, %s, %s, %s)""",
                (body.firstname, body.lastname, body.email, hashed_pw, body.username),
            )
            conn.commit()
            new_id = cursor.lastrowid
    except Error as e:
        if e.errno == 1062:  # Duplicate entry
            # MySQL reports "Duplicate entry '<value>' for key '<index>'".
            # Inspect only the index part so a value that happens to contain
            # "email" (e.g. username "myemail") is not misreported.
            key_part = str(e.msg).rsplit("for key", 1)[-1]
            field = "email" if "email" in key_part else "username"
            raise HTTPException(
                status_code=409, detail=f"A user with this {field} already exists."
            ) from e
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    return {"message": "User created successfully.", "user_id": new_id}


@app.post("/login")
async def login(body: LoginRequest):
    """Authenticate with username + password and return basic user info.

    Responses:

    - 401 when the username is unknown or the password does not match.
    - 500 on database errors.

    The stored password hash is never included in the response. The
    connection and cursor are closed on every path.
    """
    try:
        # closing() releases the cursor, then the connection, even if the
        # query or fetch raises.
        with closing(get_db_connection()) as conn, closing(
            conn.cursor(dictionary=True)
        ) as cursor:
            cursor.execute(
                "SELECT id, firstname, lastname, email, username, password FROM users WHERE username = %s",
                (body.username,),
            )
            user = cursor.fetchone()
    except Error as e:
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e

    if not user or not verify_password(body.password, user["password"]):
        raise HTTPException(status_code=401, detail="Invalid username or password.")

    return {
        "message": "Login successful.",
        "user": {
            "id": user["id"],
            "firstname": user["firstname"],
            "lastname": user["lastname"],
            "email": user["email"],
            "username": user["username"],
        },
    }


if __name__ == "__main__":
    import uvicorn

    # reload=True requires the app as a "module:attribute" import string, not
    # the app object. Derive the module name from this file so the entry
    # point keeps working if the file is renamed.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    # Bind to 0.0.0.0 so emulator/other local network clients can connect
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=5000, reload=True)
