#!/usr/bin/env python3
# =============================================================================
# OMR ANSWER-SHEET BATCH GRADER  (Python + OpenCV)  --  GL Assessment style
# =============================================================================
#
# Reads an ANSWER KEY (PDF or text, in the GL "Familiarisation" layout), then
# scans and GRADES one or many filled answer sheets against it, all at once.
#
# The OMR detection engine (contour-based cell detection + ink scoring) is the
# original scanner; this version adds:
#   * answer-key parsing  (supports several keys in one file, e.g. Fam 1/2/3)
#   * continuous question numbering 1..N across all pages of a sheet
#   * batch processing of many sheets (files and/or folders)
#   * grading: score, %, blanks, ambiguous, per-question wrong list
#   * key auto-selection per sheet (by filename) or forced with --key-name
#   * graded overlay images (green = correct, red = wrong, amber = blank/multi)
#   * a combined results CSV
#
# ---------------------------------------------------------------------------
# INSTALL
#   pip install opencv-python numpy pymupdf
#       (pymupdf is needed for PDF sheets AND for reading a PDF answer key;
#        pdf2image+poppler also works for sheet rendering as a fallback)
#   pip install onnxruntime      (optional, strongly recommended)
#       enables the MNIST CNN that reads the handwritten Unique Pupil Number.
#       The ~26 KB model auto-downloads to ./models/mnist-12.onnx on first use.
#       Without it, UPN falls back to Tesseract (poor on handwriting).
#
# USAGE
#   # Grade every sheet in a folder against a key file:
#   python omr_grader.py sheets/ --key Answers_keys.pdf
#
#   # Grade specific files, forcing which key to use for all of them:
#   python omr_grader.py s1.pdf s2.pdf --key keys.pdf --key-name 1
#
#   # Just scan (no grading) -- omit --key:
#   python omr_grader.py sheet.pdf
#
# KEY SELECTION (when the key file holds more than one test)
#   1. --key-name X         use that key for every sheet (X = number or name)
#   2. otherwise, guess from each sheet's filename
#                           (e.g. "NVR1_*", "fam2", "paper3" -> key 1/2/3)
#   3. if the file holds exactly one key, it is used for everything.
#
# MAIN FLAGS
#   --key FILE        answer-key PDF or .txt   (omit to only scan)
#   --key-name X      force a key for all sheets
#   --out-dir DIR     where graded overlays go        (default: omr_out)
#   --csv FILE        combined per-question results    (default: omr_out/results.csv)
#   --no-overlay      skip writing overlay images (faster)
#   --options N       force options-per-question       (default: auto)
#   --order rows|cols question numbering order         (default: rows)
#   --dpi N           DPI for PDF rendering             (default: 200)
#   --questions N     expected questions per sheet (for a sanity warning)
# =============================================================================

from __future__ import annotations

import argparse
import csv
import glob
import os
import re
import statistics
import string
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import cv2
import numpy as np

# Option labels in column order; ten letters (A-J) are available.
LETTERS = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]

# ---------------------------------------------------------------------------
# Legacy cell-detection filters  (kept for detect_grid compatibility)
# ---------------------------------------------------------------------------
# NOTE(review): the consumers of these limits are outside this view.  The
# magnitudes suggest width/height ratio bounds and page-relative cell sizes --
# confirm against detect_grid before changing them.
ASPECT_MIN, ASPECT_MAX = 1.1, 3.5
CELL_W_MIN, CELL_W_MAX = 0.010, 0.065
CELL_H_MIN, CELL_H_MAX = 0.005, 0.025

# NOTE(review): presumably the minimum share of cells a grid column must hold
# to be kept -- verify against the code that reads it.
MIN_COL_FRACTION = 0.25

# NOTE(review): presumably thresholds on the ink "excess" score stored in
# OmrAnswer.excess -- verify against the scoring code.
EXCESS_MIN    = 0.07
EXCESS_MARGIN = 0.05

# Lower-case raster-image file suffixes (leading dot included).
IMG_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

# ---------------------------------------------------------------------------
# New detection constants
# ---------------------------------------------------------------------------
LEFT_IGNORE_RATIO   = 0.30   # ignore left 30% (question numbers / labels)
TOP_IGNORE_RATIO    = 0.08   # ignore top 8%  (page header)
BOTTOM_IGNORE_RATIO = 0.99   # ignore below 99% (footer)
MIN_OPTIONS         = 2      # minimum options per question card
MAX_OPTIONS         = 8      # maximum options per question card

# When the median strength of detected marks falls below this, the page is a
# low-contrast photo and is illumination-normalised before scoring.  Clean
# scans (median ~0.6-0.73) stay untouched to avoid amplifying faint artifacts.
LOW_CONTRAST_MEDIAN = 0.55

# Unique Pupil Number (UPN) header field geometry.
PUPIL_NUMBER_Y_MAX_RATIO = 0.22   # UPN cells are in the top ~22% of the page
PUPIL_NUMBER_MIN_CELLS   = 5      # minimum cells to count as a valid UPN row
PUPIL_NUMBER_CELLS       = 13     # GL Assessment UPN field has exactly 13 boxes


# ===========================================================================
#  Data classes
# ===========================================================================
@dataclass
class OmrAnswer:
    """The mark detected for a single question on a scanned sheet."""

    # 1-based question number.
    question: int
    # Selected option letter, or None when nothing was read as marked.
    choice: Optional[str]
    # Extra per-option scoring data; populated by code outside this class.
    excess: list = field(default_factory=list)
    # True when more than one option appears to be marked.
    multi_mark: bool = False

    def __str__(self) -> str:
        """Render as ``Q<n>: <letter>``, using '-' for a missing choice."""
        shown = self.choice if self.choice else "-"
        rendered = f"Q{self.question:>2}: {shown}"
        if self.multi_mark:
            rendered += "  *** MULTIPLE MARKS ***"
        return rendered


@dataclass
class GradedQuestion:
    """Outcome of comparing one marked answer with the answer key."""

    # Question number this result belongs to.
    question: int
    # Letter read from the sheet (None when no single choice was read).
    marked: Optional[str]
    # Letter given by the answer key (None when the key has no entry).
    correct: Optional[str]
    # One of: "correct" | "wrong" | "blank" | "multi" | "no-key".
    status: str

    @property
    def is_correct(self) -> bool:
        """True only when ``status`` is exactly ``"correct"``."""
        outcome = self.status
        return outcome == "correct"


@dataclass
class SheetResult:
    """Grading outcome for one answer sheet, with derived summary figures."""

    # Identifier of the graded sheet (as supplied by the caller).
    sheet: str
    # Name of the answer key used, or None when no key was applied.
    key_name: Optional[str]
    # Per-question outcomes.
    graded: List[GradedQuestion]
    # Unique Pupil Number read from the sheet ("" when unknown).
    upn: str = ""
    # Pupil name read from the sheet ("" when unknown) and its confidence.
    name: str = ""
    name_conf: float = 0.0

    @property
    def total(self) -> int:
        """Number of questions that have a key answer to grade against."""
        return len([g for g in self.graded if g.correct is not None])

    @property
    def score(self) -> int:
        """Number of questions answered correctly."""
        return len([g for g in self.graded if g.is_correct])

    @property
    def percent(self) -> float:
        """Score as a percentage of ``total``; 0.0 when nothing is gradable."""
        gradable = self.total
        if not gradable:
            return 0.0
        return 100.0 * self.score / gradable

    @property
    def blanks(self) -> int:
        """Number of questions whose status is ``"blank"``."""
        return len([g for g in self.graded if g.status == "blank"])

    @property
    def multis(self) -> int:
        """Number of questions whose status is ``"multi"``."""
        return len([g for g in self.graded if g.status == "multi"])


@dataclass(frozen=True)
class Box:
    """Immutable axis-aligned rectangle: top-left corner plus width/height."""

    x: int
    y: int
    w: int
    h: int

    @property
    def cx(self) -> float:
        """Horizontal centre of the rectangle."""
        half_width = self.w / 2
        return self.x + half_width

    @property
    def cy(self) -> float:
        """Vertical centre of the rectangle."""
        half_height = self.h / 2
        return self.y + half_height

    @property
    def area(self) -> int:
        """Width multiplied by height."""
        width, height = self.w, self.h
        return width * height


# ===========================================================================
#  ANSWER-KEY PARSING
# ===========================================================================
# Key heading, e.g. "Familiarisation 2" / "familiarization2"; group 1 is the
# test number.
HEADER_RE = re.compile(r"Familiari[sz]ation\s*([0-9]+)", re.IGNORECASE)
# One answer entry: a 1-3 digit question number, "." or ")", then a single
# option letter A-J (either case) that is not the start of a longer word.
ANSWER_RE = re.compile(r"(?<!\d)(\d{1,3})\s*[.)]\s*([A-Ja-j])(?![A-Za-z])")


def _parse_key_block(text: str) -> Dict[int, str]:
    """Extract ``{question_number: LETTER}`` from the text of a single key.

    Letters are upper-cased.  When a question number occurs more than once,
    the last occurrence wins.
    """
    return {int(number): option.upper()
            for number, option in ANSWER_RE.findall(text)}


def parse_keys_from_text(text: str) -> "Dict[str, Dict[int, str]]":
    """
    Split a text blob into one or more answer keys.

    Each key is introduced by a "...Familiarisation N..." header and runs up
    to the next header (or the end of the text).  Text before the first
    header is ignored.  If no header is present the whole text becomes a
    single key named 'default'.

    When the same header number appears more than once (for example a key
    that continues after a repeated page heading), the blocks are merged
    into one key instead of the later block replacing the earlier one; this
    matches how ``load_answer_keys`` combines same-named PDF pages.  On a
    clash of question numbers the later block wins.

    Returns {key_name: {q: letter}}; sections that contain no answers are
    omitted, so the result may be empty.
    """
    headers = list(HEADER_RE.finditer(text))
    keys: Dict[str, Dict[int, str]] = {}

    if not headers:
        block = _parse_key_block(text)
        if block:
            keys["default"] = block
        return keys

    # A section ends where the next header starts; the last one runs to EOF.
    section_ends = [h.start() for h in headers[1:]] + [len(text)]
    for header, end in zip(headers, section_ends):
        block = _parse_key_block(text[header.end():end])
        if block:
            # Merge rather than overwrite so repeated headers keep all answers.
            keys.setdefault(header.group(1), {}).update(block)
    return keys


# Answer keys are printed as N section-columns of "question. letter" rows
# (GL Assessment: 4 sections x 20 questions = 80).  These constants describe
# that grid for the image-OCR fallback below.
KEY_SECTIONS     = 4    # number of side-by-side section columns on a key page
KEY_PER_SECTION  = 20   # answer rows (questions) in each section column


def _key_binarize(img: np.ndarray) -> np.ndarray:
    """Binarise a key page: flat-field, median-filter, then Otsu threshold.

    The image is divided by a heavily blurred copy of itself to even out
    illumination, median-filtered (kernel 5) and thresholded with Otsu.  The
    median step matters for phone photos / screenshots, whose moire and paper
    texture would otherwise survive Otsu and break single-character OCR of
    the answer letters.

    A 3-dimensional input is converted from BGR to grayscale first; any other
    input is used as-is.
    """
    if img.ndim == 3:
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    else:
        gray = img

    # Odd blur kernel of roughly 1/8 of the shorter side, never below 31.
    shortest_side = min(gray.shape)
    kernel = max(31, (shortest_side // 8) | 1)
    background = cv2.GaussianBlur(gray, (kernel, kernel), 0)

    flattened = cv2.divide(gray, background, scale=255)
    denoised = cv2.medianBlur(flattened, 5)
    _, binary = cv2.threshold(denoised, 0, 255,
                              cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary


def _key_token_centres(bw: np.ndarray) -> List[float]:
    """Return the sorted x-centres of number and single-letter OCR tokens.

    Used for column finding.  OCR runs on a 0.4x downscaled copy (column
    positions do not need full resolution and the smaller image is much
    faster to OCR); the centres are scaled back to full-resolution pixels.

    A token counts when it is a question number within
    1..KEY_SECTIONS * KEY_PER_SECTION, or a single letter A-E, each
    optionally followed by "." or ")".
    """
    import pytesseract

    scale = 0.4
    small = cv2.resize(bw, None, fx=scale, fy=scale,
                       interpolation=cv2.INTER_AREA)
    data = pytesseract.image_to_data(small, config="--psm 6",
                                     output_type=pytesseract.Output.DICT)

    highest_question = KEY_SECTIONS * KEY_PER_SECTION
    centres: List[float] = []
    for idx, raw in enumerate(data["text"]):
        # "£" is mapped to "E" before matching.
        token = raw.strip().upper().replace("£", "E")
        width = data["width"][idx]
        if not token or data["height"][idx] == 0:
            continue
        if re.fullmatch(r"\d{1,3}[.)]?", token):
            number = int(re.sub(r"\D", "", token))
            wanted = 1 <= number <= highest_question
        else:
            wanted = re.fullmatch(r"[A-E][.)]?", token) is not None
        if wanted:
            centres.append((data["left"][idx] + width / 2) / scale)
    centres.sort()
    return centres


def _key_column_ranges(bw: np.ndarray) -> List[Tuple[int, int]]:
    """Return up to KEY_SECTIONS (x_start, x_end) column ranges for the page.

    Token x-centres are split at the KEY_SECTIONS - 1 widest gaps between
    neighbouring centres.  Each group's extent is padded by 3% of the page
    width on the left and 6% on the right, clamped to the image.  Returns an
    empty list when fewer than KEY_SECTIONS tokens were found.
    """
    centres = _key_token_centres(bw)
    count = len(centres)
    if count < KEY_SECTIONS:
        return []

    # Indices i ranked by the size of the gap between centres[i-1] and
    # centres[i]; the widest gaps mark the column boundaries.
    by_gap = sorted(range(1, count),
                    key=lambda i: centres[i] - centres[i - 1],
                    reverse=True)
    cut_points = sorted(by_gap[:KEY_SECTIONS - 1])
    edges = [0, *cut_points, count]

    page_w = bw.shape[1]
    ranges: List[Tuple[int, int]] = []
    for lo, hi in zip(edges, edges[1:]):
        group = centres[lo:hi]
        if group:
            x_start = max(0, int(min(group) - page_w * 0.03))
            x_end = min(page_w, int(max(group) + page_w * 0.06))
            ranges.append((x_start, x_end))
    return ranges


def _key_row_bands(col: np.ndarray) -> List[Tuple[int, int]]:
    """Find text rows in one isolated column by horizontal ink projection.

    A pixel row is "active" when its count of dark pixels (< 128) exceeds
    max(3, 4% of the busiest row).  Consecutive active rows form a band
    (y_start, y_end) with y_end exclusive.  Bands whose height lies outside
    0.5x..2.5x the median band height are discarded.
    """
    ink = (col < 128).sum(axis=1).astype(float)
    active = ink > max(3.0, ink.max() * 0.04)

    # Pad with an inactive row at both ends so every run has a rising and a
    # falling edge, then read the run boundaries off the first difference.
    padded = np.concatenate(([False], active, [False])).astype(np.int8)
    steps = np.diff(padded)
    run_starts = np.flatnonzero(steps == 1)
    run_stops = np.flatnonzero(steps == -1)
    bands = [(int(top), int(bottom))
             for top, bottom in zip(run_starts, run_stops)]
    if not bands:
        return []

    typical = float(np.median([bottom - top for top, bottom in bands]))
    return [(top, bottom) for top, bottom in bands
            if 0.5 * typical <= (bottom - top) <= 2.5 * typical]


def _ocr_key_image(img: np.ndarray) -> Dict[int, str]:
    """Read a photographed / scanned printed answer-key table by OCR.

    This is the fallback for image-based key PDFs (e.g. a phone-camera photo)
    whose text layer is empty.  The key is a grid of section-columns, each a
    list of "<question>. <letter>" rows.  Strategy:

      1. binarise (flat-field + median + Otsu),
      2. split into the section columns at the widest inter-token gaps,
      3. OCR each whole column in ONE pass (PSM 6) and read its rows in order
         — column c, row r maps to question c*PER_SECTION + r + 1; the printed
         number is used when legible and consistent, else row order decides,
      4. for any row still unread, OCR just that one line (PSM 7) as a fallback.

    OCR is the slow part, so the per-column pass (one call per section) is used
    instead of one call per row; the per-line fallback fires only for the rare
    row the column pass misses.  Returns {question: letter}, or {} if OCR is
    unavailable.

    Args:
        img: page image; it is passed to ``_key_binarize`` unchanged.

    Returns:
        {question: letter} with upper-case letters A-E.  Also {} when no
        column ranges could be found on the page.
    """
    if not _ensure_tesseract():
        return {}
    import pytesseract

    bw = _key_binarize(img)
    # Only h is used below (to clamp row crops); w is not read again.
    h, w = bw.shape
    ranges = _key_column_ranges(bw)
    if not ranges:
        return {}

    key: Dict[int, str] = {}
    for col_idx, (x1, x2) in enumerate(ranges):
        # Question numbers of this column are base+1 .. base+KEY_PER_SECTION.
        base = col_idx * KEY_PER_SECTION
        col = bw[:, x1:x2]
        bands = _key_row_bands(col)
        # The answer rows are the bottom-most KEY_PER_SECTION bands; any rows
        # above them (the "Section N" header) are dropped.
        rows = bands[-KEY_PER_SECTION:] if len(bands) >= KEY_PER_SECTION else bands

        # 1) One OCR pass over the whole (header-trimmed) column.
        sub = col
        if rows:
            # Crop to the detected answer rows with a 6 px safety margin.
            sub = col[max(0, rows[0][0] - 6):min(h, rows[-1][1] + 6), :]
        # White border added around the crop before OCR.
        sub = cv2.copyMakeBorder(sub, 15, 15, 25, 25,
                                 cv2.BORDER_CONSTANT, value=255)
        # "£" in the OCR output is mapped to "E".
        text = pytesseract.image_to_string(sub, config="--psm 6").replace("£", "E")
        # Keep only the lines that contain at least one A-E letter.
        lines = [ln for ln in text.splitlines() if re.search(r"[A-Ea-e]", ln)]
        for i, ln in enumerate(lines[:KEY_PER_SECTION]):
            # Preferred form: "<number><up to 4 separator chars><letter>".
            m = re.search(r"(?<!\d)(\d{1,3})\s*[^A-Za-z0-9]{0,4}\s*([A-Ea-e])(?![A-Za-z])",
                          ln)
            if m:
                q = int(m.group(1))
                letter = m.group(2).upper()
                # A number outside this column's range is treated as misread:
                # fall back to the line's position within the column.
                if not (base + 1 <= q <= base + KEY_PER_SECTION):
                    q = base + i + 1
            else:
                # No usable number: take the last stand-alone A-E letter on
                # the line and number it by position.
                found = re.findall(r"[A-Ea-e](?![A-Za-z])", ln)
                if not found:
                    continue
                q = base + i + 1
                letter = found[-1].upper()
            # First reading of a question wins; later duplicates are ignored.
            key.setdefault(q, letter)

        # 2) Per-line fallback for any row the column pass left unread.
        # Only attempted when exactly KEY_PER_SECTION rows were detected, so
        # that row index maps reliably to a question number.
        if len(rows) == KEY_PER_SECTION:
            for row_idx, (y1, y2) in enumerate(rows):
                q = base + row_idx + 1
                if q in key:
                    continue
                line = col[max(0, y1 - 4):min(h, y2 + 4), :]
                line = cv2.copyMakeBorder(line, 12, 12, 20, 20,
                                          cv2.BORDER_CONSTANT, value=255)
                # PSM 7 = treat the crop as a single text line.
                t = pytesseract.image_to_string(
                    line, config="--psm 7").replace("£", "E")
                ml = re.search(r"([A-Ea-e])(?![A-Za-z])", t)
                if ml:
                    key[q] = ml.group(1).upper()

    return key


def _ocr_key_header_number(img: np.ndarray) -> Optional[str]:
    """Recover the 'Familiarisation N' number by OCR of the page header.

    Only the top 18% of the image is read.  Returns the number as a string,
    or None when Tesseract is unavailable, OCR fails, or no header is found.
    """
    if not _ensure_tesseract():
        return None
    try:
        import pytesseract
        header_rows = int(img.shape[0] * 0.18)
        header_text = pytesseract.image_to_string(img[:header_rows, :],
                                                  config="--psm 6")
        found = HEADER_RE.search(header_text)
    except Exception:
        return None
    if not found:
        return None
    return found.group(1)


def load_answer_keys(path: str, dpi: int = 200) -> "Dict[str, Dict[int, str]]":
    """
    Load answer keys from a PDF or a plain-text file.

    For PDFs each PAGE is parsed independently (so per-page question numbers
    that restart at 1 do not collide), then named by the Familiarisation
    number found on that page (falling back to the page index).  Pages that
    resolve to the same name are merged into one key.

    The PDF text layer is tried first (exact and fast for born-digital keys).
    If a page yields no answers — which happens when the key is a phone-camera
    photo or scan with no text layer — the page is rendered to an image and
    read with OCR (see ``_ocr_key_image``).  A page that cannot be rendered
    is skipped.

    Any path without a ``.pdf`` suffix is read as UTF-8 text (undecodable
    bytes ignored) and handed to ``parse_keys_from_text``.

    Args:
        path: answer-key file.
        dpi:  resolution used when a PDF page has to be rendered for OCR.

    Returns:
        {key_name: {question: letter}}.

    Raises:
        RuntimeError: when no answers could be parsed from the file.
    """
    ext = Path(path).suffix.lower()
    keys: Dict[str, Dict[int, str]] = {}

    if ext == ".pdf":
        import fitz  # pymupdf
        doc = fitz.open(path)
        try:
            for pidx in range(doc.page_count):
                text = doc[pidx].get_text()
                m = HEADER_RE.search(text)
                name = m.group(1) if m else f"page{pidx + 1}"
                block = _parse_key_block(_strip_header_numbers(text, m))
                if not block:
                    # No text layer — render the page and OCR the printed key.
                    rendered = pdf_to_images(path, dpi=dpi, page_num=pidx + 1)
                    if rendered:
                        # Flatten a genuine angled phone photo (a bright page
                        # sitting on a darker background).  _deskew_sheet is
                        # conservative: it only warps when it finds a clearly
                        # inset quad, so flat scans / screenshots are left
                        # untouched.
                        img = _deskew_sheet(rendered[0])
                        block = _ocr_key_image(img)
                        if block and not m:
                            hdr = _ocr_key_header_number(img)
                            if hdr:
                                name = hdr
                if block:
                    keys.setdefault(name, {}).update(block)
        finally:
            # Release the document even if a page fails to parse or render.
            doc.close()
    else:
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read()
        keys = parse_keys_from_text(text)

    if not keys:
        raise RuntimeError(f"No answers could be parsed from key file: {path}")
    return keys


def _strip_header_numbers(text: str, header_match) -> str:
    """Remove numbers that belong to headings, not to answers.

    The digits after "Familiarisation", "Page" and "Section" are dropped
    (the words themselves are kept) so the answer regex cannot mistake them
    for question numbers.  ``header_match`` is accepted for the caller's
    convenience but is not used.
    """
    without_title = HEADER_RE.sub("Familiarisation", text)
    without_pages = re.sub(r"\bPage\s+\d+", "Page", without_title,
                           flags=re.IGNORECASE)
    return re.sub(r"\bSection\s+\d+", "Section", without_pages,
                  flags=re.IGNORECASE)


# ===========================================================================
#  UNIQUE PUPIL NUMBER extraction (sheet header)
#
#  Three strategies, tried in order:
#    1. PDF text layer  -- exact and fast for "born-digital" sheets.
#    2. CV digit detection -- contour-based cell finder + feature classifier,
#       no external dependency required.
#    3. OCR fallback     -- Tesseract via pytesseract; if unavailable returns "".
# ===========================================================================
def _upn_from_text(file_path: str) -> str:
    """Read the UPN from the text layer of a PDF's first page using PyMuPDF.

    The word "UNIQUE" anchors the field.  Digit-only words that start 4-60
    units below the label's top, from 6 units left of the label up to 8 units
    before a "SCHOOL" word on the same line (or the page's right edge when
    there is none), are joined left to right.

    Returns "" when PyMuPDF is not installed, the file cannot be read, the
    label is missing, or anything else goes wrong (best-effort by design).
    """
    try:
        import fitz
    except ImportError:
        return ""
    try:
        doc = fitz.open(file_path)
        try:
            page = doc[0]
            # Each word is a tuple (x0, y0, x1, y1, text, ...).
            words = page.get_text("words")
            page_w = page.rect.width
        finally:
            # Always release the document, even if page access fails.
            doc.close()

        upn_lab = next((w for w in words
                        if w[4].strip().upper() == "UNIQUE"), None)
        if not upn_lab:
            return ""
        top = upn_lab[1]
        # A "SCHOOL" label on the same line bounds the field on the right.
        sch = [w for w in words
               if w[4].strip().upper() == "SCHOOL" and abs(w[1] - top) < 6]
        right = min((w[0] for w in sch), default=page_w)
        left = upn_lab[0] - 6
        digits = [w for w in words
                  if w[4].strip().isdigit()
                  and (top + 4) <= w[1] <= (top + 60)
                  and left <= w[0] < (right - 8)]
        digits.sort(key=lambda w: w[0])
        return "".join(w[4].strip() for w in digits)
    except Exception:
        return ""


# --- Tesseract discovery -----------------------------------------------------
# Process-wide memo of the Tesseract discovery: whether it has run, whether
# it succeeded, and whether the install hint was already printed.
_TESS_STATE = {"checked": False, "ok": False, "warned": False}


def _ensure_tesseract() -> bool:
    """Return True when both pytesseract and the Tesseract binary are usable.

    Discovery runs once per process; later calls return the cached result.
    The binary is looked up on PATH first, then in common Windows, macOS and
    Linux install locations; when found off PATH, pytesseract is pointed at
    it.  On failure an install hint is emitted via ``_warn_no_tesseract``.
    """
    if _TESS_STATE["checked"]:
        return _TESS_STATE["ok"]
    _TESS_STATE["checked"] = True

    try:
        import pytesseract
    except ImportError:
        _warn_no_tesseract("the 'pytesseract' package is not installed "
                           "(pip install pytesseract)")
        return False

    import shutil
    if shutil.which("tesseract"):
        _TESS_STATE["ok"] = True
        return True

    # Not on PATH: probe the usual install locations in order.
    env_roots = [os.environ.get(var)
                 for var in ("ProgramFiles", "ProgramFiles(x86)",
                             "ProgramW6432", "LOCALAPPDATA")]
    candidates = [os.path.join(root, "Tesseract-OCR", "tesseract.exe")
                  for root in env_roots if root]
    candidates.extend(glob.glob(r"C:\Program Files*\Tesseract-OCR\tesseract.exe"))
    candidates.extend(["/opt/homebrew/bin/tesseract", "/usr/local/bin/tesseract",
                       "/usr/bin/tesseract", "/opt/local/bin/tesseract"])

    located = next((c for c in candidates if c and os.path.isfile(c)), None)
    if located is None:
        _warn_no_tesseract("the Tesseract program was not found")
        return False

    pytesseract.pytesseract.tesseract_cmd = located
    _TESS_STATE["ok"] = True
    return True


def _warn_no_tesseract(reason: str) -> None:
    """Print the Tesseract install hint to stderr, at most once per process.

    ``reason`` is inserted into the first line of the message.
    """
    already_warned = _TESS_STATE["warned"]
    _TESS_STATE["warned"] = True
    if already_warned:
        return
    message = (
        f"  ! Pupil number needs OCR for scanned sheets, but {reason}.\n"
        f"    Install it:  Windows -> https://github.com/UB-Mannheim/tesseract/wiki\n"
        f"                 macOS   -> brew install tesseract\n"
        f"                 Linux   -> sudo apt install tesseract-ocr\n"
        f"    Then:        pip install pytesseract\n"
        f"    (Digital PDFs still read the pupil number without Tesseract.)"
    )
    print(message, file=sys.stderr)


def _upn_via_ocr(file_path: str, dpi: int = 200) -> str:
    """Read the Unique Pupil Number from a rendered sheet with Tesseract.

    The word "PUPIL" is located by OCR of page 1; the strip below that label
    is cropped, the box grid lines are removed by morphology, and the
    remaining ink is OCR'd as a single line of digits.

    Args:
        file_path: sheet file; page 1 is obtained via ``_get_sheet_page1``.
        dpi: requested render DPI; the larger of this and NAME_RENDER_DPI
            is used.

    Returns:
        The digits read, or "" when Tesseract is unavailable, the label is
        not found, the crop is too narrow, or any step raises.
    """
    if not _ensure_tesseract():
        return ""
    try:
        import pytesseract
    except ImportError:
        return ""
    try:
        img = _get_sheet_page1(file_path, max(dpi, NAME_RENDER_DPI))
        if img is None:
            return ""
        H, W = img.shape[:2]
        # Word-level OCR of the whole page, used only to locate the labels.
        data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
        n = len(data["text"])
        tok = lambda i: data["text"][i].strip().upper()
        ip = next((i for i in range(n) if tok(i) == "PUPIL"), None)
        if ip is None:
            return ""
        # ly / lh: top and height of the "PUPIL" word; lh is the unit for
        # every offset below.
        ly, lh = data["top"][ip], data["height"][ip]
        # "Same line" = top edge within 1.2 label heights of the PUPIL word.
        same = lambda i: abs(data["top"][i] - ly) < lh * 1.2
        iu = next((i for i in range(n) if tok(i) == "UNIQUE" and same(i)), None)
        isch = next((i for i in range(n) if tok(i) == "SCHOOL" and same(i)), None)
        # Left edge: 4 label heights left of "UNIQUE" (or of "PUPIL" when
        # "UNIQUE" was not recognised), clamped to the image.
        left = max(0, (data["left"][iu] - int(4 * lh)) if iu is not None
                   else data["left"][ip] - int(4 * lh))
        # Right edge: 1.3 label heights before "SCHOOL", else 52% of the width.
        right = min(W, (data["left"][isch] - int(1.3 * lh)) if isch is not None
                    else int(W * 0.52))
        if right - left < 10:
            return ""
        # The strip from 1.05 to 3.35 label heights below the label's top.
        crop = img[int(ly + lh * 1.05):int(ly + lh * 3.35), left:right]
        sc = 3
        # Channel index 2 of the crop, upscaled 3x.
        # NOTE(review): this is the red channel if img is BGR as OpenCV
        # loads it -- presumably chosen to fade red printing; confirm.
        red = cv2.resize(crop[:, :, 2], None, fx=sc, fy=sc,
                         interpolation=cv2.INTER_CUBIC)
        h, w = red.shape
        # Pixels darker than 150 become foreground (255).
        dark = cv2.threshold(red, 150, 255, cv2.THRESH_BINARY_INV)[1]
        # Opening with a 1-wide, 0.7h-tall kernel keeps only tall vertical
        # strokes.
        vk = cv2.getStructuringElement(cv2.MORPH_RECT, (1, int(h * 0.7)))
        vmask = cv2.morphologyEx(dark, cv2.MORPH_OPEN, vk)
        nlab, lab, stats, _ = cv2.connectedComponentsWithStats(vmask, 8)
        sep = np.zeros_like(vmask)
        for k in range(1, nlab):
            # Only thin tall components go into the separator mask; wider
            # ones are left in place.
            if stats[k, cv2.CC_STAT_WIDTH] <= max(3, int(0.10 * lh * sc)):
                sep[lab == k] = 255
        # Widen the separators slightly so their edges are removed as well.
        sep = cv2.dilate(sep, cv2.getStructuringElement(cv2.MORPH_RECT, (3, 1)))
        # Opening with a long horizontal kernel keeps only long horizontal
        # lines.
        hk = cv2.getStructuringElement(cv2.MORPH_RECT, (max(25, w // 5), 1))
        hmask = cv2.morphologyEx(dark, cv2.MORPH_OPEN, hk)
        # Remove both line masks from the ink, then despeckle.
        clean = cv2.subtract(dark, cv2.bitwise_or(sep, hmask))
        clean = cv2.medianBlur(clean, 3)
        # Back to dark-on-white with a white border for Tesseract.
        out = cv2.copyMakeBorder(cv2.bitwise_not(clean), 18, 18, 18, 18,
                                 cv2.BORDER_CONSTANT, value=255)
        # PSM 7 = single text line; the whitelist restricts output to digits.
        txt = pytesseract.image_to_string(
            out, config="--psm 7 -c tessedit_char_whitelist=0123456789")
        return re.sub(r"\D", "", txt)
    except Exception:
        # Best-effort: any failure means "no UPN read".
        return ""


# ---------------------------------------------------------------------------
#  Perspective correction for phone photos of a sheet
# ---------------------------------------------------------------------------
SHEET_DESKEW_MIN_INSET = 0.08   # warp only when a corner is inset > 8% (a photo)


def _deskew_sheet(img: np.ndarray) -> np.ndarray:
    """Warp a phone photo of an answer sheet to a head-on rectangle.

    The largest bright region is located with Otsu thresholding; when it
    covers at least 30% of the frame, approximates to a quadrilateral and has
    a corner inset from the matching image corner by at least
    SHEET_DESKEW_MIN_INSET (relative to the longer image side), the quad is
    perspective-warped upright.  In every other case — including any error —
    the input image is returned unchanged, so flat scans and digital pages
    pass through untouched.
    """
    try:
        height, width = img.shape[:2]
        if img.ndim == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            gray = img

        # Bright-region mask, closed so printed content does not punch holes.
        smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
        _, mask = cv2.threshold(smoothed, 0, 255,
                                cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE,
                                np.ones((15, 15), np.uint8))
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return img

        sheet = max(contours, key=cv2.contourArea)
        if cv2.contourArea(sheet) < 0.30 * height * width:   # must dominate frame
            return img
        quad = cv2.approxPolyDP(sheet, 0.02 * cv2.arcLength(sheet, True), True)
        if len(quad) != 4:
            return img

        # Order the corners: x+y is smallest at top-left and largest at
        # bottom-right; y-x is smallest at top-right and largest at bottom-left.
        corners = quad.reshape(4, 2).astype(np.float32)
        coord_sum = corners.sum(1)
        coord_diff = np.diff(corners, axis=1).reshape(-1)
        tl = corners[np.argmin(coord_sum)]
        br = corners[np.argmax(coord_sum)]
        tr = corners[np.argmin(coord_diff)]
        bl = corners[np.argmax(coord_diff)]
        src = np.array([tl, tr, br, bl], np.float32)

        frame = np.array([[0, 0], [width, 0], [width, height], [0, height]],
                         np.float32)
        worst_inset = max(float(np.linalg.norm(corner - target))
                          for corner, target in zip(src, frame))
        if worst_inset / max(width, height) < SHEET_DESKEW_MIN_INSET:
            return img                                       # already head-on

        out_w = int(max(np.linalg.norm(br - bl), np.linalg.norm(tr - tl)))
        out_h = int(max(np.linalg.norm(tr - br), np.linalg.norm(tl - bl)))
        if out_w < 100 or out_h < 100:
            return img

        dst = np.array([[0, 0], [out_w - 1, 0],
                        [out_w - 1, out_h - 1], [0, out_h - 1]], np.float32)
        transform = cv2.getPerspectiveTransform(src, dst)
        return cv2.warpPerspective(img, transform, (out_w, out_h))
    except Exception:
        return img


# ---------------------------------------------------------------------------
#  Handwritten-digit recognition for the UPN (MNIST CNN via onnxruntime)
# ---------------------------------------------------------------------------
_MNIST_MODEL_PATH = Path(__file__).resolve().parent / "models" / "mnist-12.onnx"
_MNIST_MODEL_URL  = ("https://github.com/onnx/models/raw/main/validated/"
                     "vision/classification/mnist/model/mnist-12.onnx")
# Lazy-load memo: the session, its input name, and whether loading was tried.
_mnist_session = None
_mnist_input   = None
_mnist_tried   = False


def _download_mnist_model(dest: Path, timeout: float = 30.0) -> None:
    """Fetch the MNIST model to *dest* without ever leaving a partial file.

    The download goes to a sibling ``.part`` file and is renamed into place
    only after it completed, so an interrupted transfer cannot leave a
    truncated model that would make every later run fail to load.
    """
    import shutil
    import urllib.request

    dest.parent.mkdir(parents=True, exist_ok=True)
    partial = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(_MNIST_MODEL_URL, timeout=timeout) as resp, \
                open(partial, "wb") as fh:
            shutil.copyfileobj(resp, fh)
        os.replace(partial, dest)
    finally:
        # After a successful rename the .part file no longer exists.
        if partial.exists():
            partial.unlink()


def _load_mnist():
    """Lazily load the MNIST ONNX classifier, downloading it once if missing.

    Tesseract is unreliable on handwritten digits; a small MNIST CNN reads the
    UPN row far better.  Loading is attempted only once per process and the
    outcome (success or failure) is remembered.

    Returns:
        (session, input_name), or (None, None) when onnxruntime is absent or
        the model cannot be downloaded or loaded.
    """
    global _mnist_session, _mnist_input, _mnist_tried
    if _mnist_tried:
        return _mnist_session, _mnist_input
    _mnist_tried = True
    try:
        import onnxruntime as ort
    except ImportError:
        return None, None
    try:
        if not _MNIST_MODEL_PATH.exists():
            _download_mnist_model(_MNIST_MODEL_PATH)
        sess = ort.InferenceSession(str(_MNIST_MODEL_PATH),
                                    providers=["CPUExecutionProvider"])
        _mnist_session = sess
        _mnist_input   = sess.get_inputs()[0].name
    except Exception:
        # Best-effort: callers fall back to another reader on (None, None).
        _mnist_session, _mnist_input = None, None
    return _mnist_session, _mnist_input


def _mnist_prep(cell_gray: np.ndarray) -> Optional[np.ndarray]:
    """Turn a grayscale digit-cell crop into a 28x28 MNIST-style image.

    The result has white ink on a black background, with the glyph scaled so
    its longer side is 20 px and shifted so its centre of mass sits at the
    middle of the canvas.  Returns None when the cell holds no ink.
    """
    if cell_gray.size == 0:
        return None

    # Absolute-darkness gate.  Otsu always splits a cell into "ink" and
    # "paper", even for blank noise, so first confirm the cell really holds
    # dark ink: in an empty cell the darkest pixels are about as bright as
    # the paper, while a real digit is markedly darker.  This drops empty
    # leading UPN cells that would otherwise be read as a spurious prefix.
    ink_level = float(np.percentile(cell_gray, 5))
    paper_level = float(np.percentile(cell_gray, 90))
    if paper_level <= 1 or ink_level > 0.70 * paper_level:
        return None

    _, binary = cv2.threshold(cell_gray, 0, 255,
                              cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    rows, cols = np.where(binary > 0)
    if len(cols) < 3:
        return None

    # Crop to the ink's bounding box and scale its longer side to 20 px.
    glyph = binary[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
    src_h, src_w = glyph.shape
    factor = 20.0 / max(src_h, src_w)
    new_h = max(1, int(round(src_h * factor)))
    new_w = max(1, int(round(src_w * factor)))
    glyph = cv2.resize(glyph, (new_w, new_h), interpolation=cv2.INTER_AREA)

    ink_rows, ink_cols = np.where(glyph > 0)
    if len(ink_cols) == 0:
        return None

    # Offset that moves the glyph's centre of mass to (14, 14); whatever
    # would fall outside the 28x28 canvas is clipped.
    shift_y = int(round(14 - float(ink_rows.mean())))
    shift_x = int(round(14 - float(ink_cols.mean())))
    dst_y, dst_x = max(0, shift_y), max(0, shift_x)
    src_y, src_x = max(0, -shift_y), max(0, -shift_x)
    copy_h = min(new_h - src_y, 28 - dst_y)
    copy_w = min(new_w - src_x, 28 - dst_x)
    if copy_h <= 0 or copy_w <= 0:
        return None

    canvas = np.zeros((28, 28), np.uint8)
    canvas[dst_y:dst_y + copy_h, dst_x:dst_x + copy_w] = \
        glyph[src_y:src_y + copy_h, src_x:src_x + copy_w]
    return canvas


def _upn_via_mnist(cells: List[Box], gray: np.ndarray) -> str:
    """Read the UPN by classifying each detected cell with the MNIST CNN.

    Cells that ``_mnist_prep`` reports as empty contribute no digit.
    Returns '' when the model is unavailable so the caller can fall back.
    """
    session, input_name = _load_mnist()
    if session is None:
        return ""

    recognised: List[str] = []
    for cell in cells:
        # Trim 15% of the shorter side (at least 3 px) to drop the cell border.
        border = max(3, int(min(cell.w, cell.h) * 0.15))
        top, bottom = cell.y + border, cell.y + cell.h - border
        left, right = cell.x + border, cell.x + cell.w - border
        canvas = _mnist_prep(gray[top:bottom, left:right])
        if canvas is None:
            continue
        batch = canvas.astype(np.float32).reshape(1, 1, 28, 28)
        scores = session.run(None, {input_name: batch})[0][0]
        recognised.append(str(int(np.argmax(scores))))
    return "".join(recognised)


# ---------------------------------------------------------------------------
# Shared first-page image cache — avoids rendering the PDF twice when both
# UPN and name extraction need page 1 (saves one pdf_to_images + one deskew).
# ---------------------------------------------------------------------------
_page1_cache: Dict[Tuple[str, int], Optional[np.ndarray]] = {}
# Upper bound on cached pages.  The cache only has to serve the few
# back-to-back look-ups made for one sheet; keeping every page of a large
# batch would hold a full-resolution image per sheet for the whole run.
_PAGE1_CACHE_MAX = 4


def _get_sheet_page1(file_path: str, dpi: int) -> Optional[np.ndarray]:
    """Render and deskew page 1 of a sheet, caching the result.

    A second call with the same (absolute path, dpi) returns the cached array
    without re-rendering.  The cache is module-level but bounded: once it
    holds ``_PAGE1_CACHE_MAX`` entries the oldest one is evicted, so memory
    use stays flat across a large batch.  An evicted page is simply
    re-rendered if it is requested again.

    Returns None (also cached) when the page cannot be rendered or read.
    """
    cache_key = (os.path.abspath(file_path), dpi)
    if cache_key in _page1_cache:
        return _page1_cache[cache_key]

    img: Optional[np.ndarray] = None
    try:
        if Path(file_path).suffix.lower() == ".pdf":
            imgs = pdf_to_images(file_path, dpi=dpi, page_num=1)
            img = imgs[0] if imgs else None
        else:
            img = cv2.imread(file_path, cv2.IMREAD_COLOR)
        if img is not None:
            img = _deskew_sheet(img)
    except Exception:
        img = None

    # dicts preserve insertion order, so the first key is the oldest entry.
    while len(_page1_cache) >= _PAGE1_CACHE_MAX:
        _page1_cache.pop(next(iter(_page1_cache)))
    _page1_cache[cache_key] = img
    return img


def _upn_via_cv(file_path: str, dpi: int = 200) -> str:
    """Read the UPN by locating the digit-cell row and classifying each cell.

    Preferred path: the MNIST CNN (accurate on handwritten digits).  If the
    model is unavailable, fall back to a Tesseract pass over a cleaned canvas
    of the digit strip.  Returns '' when nothing can be read so the caller
    falls through to ``_upn_via_ocr``.

    The page is rendered at ``max(dpi, NAME_RENDER_DPI)``.  Any exception
    raised along the way is swallowed and reported as '' (best-effort read).
    """
    try:
        img = _get_sheet_page1(file_path, max(dpi, NAME_RENDER_DPI))
        if img is None:
            return ""
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Try grayscale → blue channel → green channel.  Red-ink cell borders
        # show strongly in the blue channel; green splits them from black ink
        # and helps on sheets where grey smudging confuses the grayscale path.
        cells = detect_digit_cells(gray)
        if not cells and img.ndim == 3:
            cells = detect_digit_cells(img[:, :, 0])   # BGR[0] = blue
        if not cells and img.ndim == 3:
            cells = detect_digit_cells(img[:, :, 1])   # BGR[1] = green

        if not cells:
            return ""

        # Preferred: MNIST CNN per cell.  The classifier always reads from the
        # grayscale image, whichever channel located the cells.
        mnist_val = _upn_via_mnist(cells, gray)
        if mnist_val:
            return mnist_val

        # Fallback: Tesseract over a cleaned canvas of the whole strip.
        if not _ensure_tesseract():
            return ""
        import pytesseract

        # Bounding strip around all detected cells.
        # Use a tiny top margin so the "UNIQUE PUPIL NUMBER" label printed
        # above the cell row is excluded — it confuses Tesseract PSM 7.
        # Extend the left edge by one cell-spacing to capture the first cell
        # when detect_digit_cells misses it (its x-boundary is slightly off).
        # NOTE(review): only the interiors of *detected* cells are copied onto
        # the canvas below, so the extra left area stays blank white — confirm
        # whether the extension still serves a purpose.
        margin_h = max(6, int(cells[0].h * 0.20))  # horizontal margin
        margin_v = 4                                 # vertical: just enough for border

        # Median centre-to-centre distance; a lone cell falls back to its width.
        if len(cells) >= 2:
            spacings = [cells[i + 1].cx - cells[i].cx for i in range(len(cells) - 1)]
            med_spacing = float(np.median(spacings))
        else:
            med_spacing = float(cells[0].w)

        x1 = max(0, int(min(c.x for c in cells) - med_spacing))   # one cell left
        x2 = min(img.shape[1], max(c.x + c.w for c in cells) + margin_h)
        y1 = max(0, min(c.y for c in cells) - margin_v)            # no header
        y2 = min(img.shape[0], max(c.y + c.h for c in cells) + margin_v)
        strip = img[y1:y2, x1:x2]
        if strip.size == 0:
            return ""

        # Scale up 3× for better OCR accuracy.
        sc = 3
        strip_gray = cv2.cvtColor(strip, cv2.COLOR_BGR2GRAY) if strip.ndim == 3 else strip
        strip_lg = cv2.resize(strip_gray, None, fx=sc, fy=sc,
                              interpolation=cv2.INTER_CUBIC)
        h_s, w_s = strip_lg.shape[:2]

        # OTSU threshold: ink (dark) → white (255), background → black (0).
        _, binary = cv2.threshold(
            strip_lg, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)

        # Canvas approach: copy only each cell's *inner* content onto a white
        # canvas, skipping the cell-border pixels entirely.  This avoids the
        # thick printed grid lines that confuse Tesseract layout analysis.
        # bpx is the border inset in scaled pixels (12 % of the first cell's
        # shorter side, at least 4 px, times the 3× scale).
        bpx = max(4, int(min(cells[0].w, cells[0].h) * 0.12)) * sc
        canvas = np.ones((h_s, w_s), np.uint8) * 255   # white background
        for cell in cells:
            # Cell rectangle in scaled strip coordinates.
            rx = int((cell.x - x1) * sc)
            ry = int((cell.y - y1) * sc)
            cw = int(cell.w * sc)
            ch = int(cell.h * sc)
            # Inner rectangle, clamped to the strip.
            ix1 = max(0, rx + bpx)
            iy1 = max(0, ry + bpx)
            ix2 = min(w_s, rx + cw - bpx)
            iy2 = min(h_s, ry + ch - bpx)
            if ix2 > ix1 and iy2 > iy1:
                # binary: ink=255, bg=0 → invert so ink=0 (black) on canvas
                canvas[iy1:iy2, ix1:ix2] = cv2.bitwise_not(
                    binary[iy1:iy2, ix1:ix2])

        # Slightly thicken strokes so thin handwritten digits (e.g. "1")
        # are recognised more reliably.
        ink = cv2.bitwise_not(canvas)                       # ink = white
        ink = cv2.dilate(ink, np.ones((2, 2), np.uint8))
        # Back to black-on-white, with a 15 px white frame around the strip.
        out = cv2.copyMakeBorder(cv2.bitwise_not(ink),
                                 15, 15, 15, 15,
                                 cv2.BORDER_CONSTANT, value=255)

        # PSM 7 = treat the image as a single text line; digits only.
        txt = pytesseract.image_to_string(
            out, config="--psm 7 -c tessedit_char_whitelist=0123456789")
        return re.sub(r"\D", "", txt)
    except Exception:
        return ""


def extract_upn(file_path: str, dpi: int = 200) -> str:
    """Read UPN from a sheet: PDF text layer → CV digits → Tesseract OCR.

    Resolution order:
      1. For PDFs, the text-layer value is returned at once when it looks
         complete (>= 8 digits).
      2. Otherwise the CV reader runs, and any non-empty CV result is
         returned — even when a partial text-layer value is longer.
      3. A partial text-layer value is used only when CV read nothing.
      4. If both are empty, the whole-sheet Tesseract OCR pass is the last
         resort.
    """
    is_pdf = Path(file_path).suffix.lower() == ".pdf"
    layer_digits = _upn_from_text(file_path) if is_pdf else ""
    if len(layer_digits) >= 8:
        return layer_digits

    cv_digits = _upn_via_cv(file_path, dpi)
    if cv_digits:
        return cv_digits
    if layer_digits:
        return layer_digits
    return _upn_via_ocr(file_path, dpi)


# ===========================================================================
#  PUPIL NAME extraction (handwritten "Pupil's Name" field)
#
#  The name is hand-printed, which Tesseract reads only moderately well, so
#  this is best-effort: it returns the recognised text plus a 0-100 confidence
#  so a caller can flag low-confidence reads for manual review.  The field is
#  located relative to the (robustly detected) Unique Pupil Number cell row,
#  which works for both sheet layouts (name printed below the label, or beside
#  it).  The printed red labels are dropped by colour before OCR.
# ===========================================================================
# Minimum render resolution (DPI) for page 1 when reading the name; the UPN
# CV reader uses the same floor.
NAME_RENDER_DPI = 300          # name OCR needs more resolution than grading
# Upper-cased, letters-only OCR tokens that are printed form labels rather
# than handwriting; tokens equal to one of these are dropped from name reads.
_NAME_LABEL_WORDS = {"PUPIL", "PUPILS", "PUPILSNAME", "NAME",
                     "SCHOOL", "SCHOOLNAME", "DATE", "TEST", "BIRTH"}
# Tesseract config fragment restricting output to capital letters and space.
_NAME_WHITELIST = ("-c tessedit_char_whitelist="
                   "ABCDEFGHIJKLMNOPQRSTUVWXYZ ")
# Characters / patterns extremely unlikely in genuine Indian names (in
# standard English transliteration).  A match flags the OCR result as a
# probable merged-stroke artifact and triggers the adaptive-crop fallback.
_NAME_SUSPECT_RE = re.compile(r"Q|UU|VV")


def _name_band(img: np.ndarray) -> Optional[np.ndarray]:
    """Crop the pupil-name box, anchored to the detected UPN cell row.

    The UPN digit cells are located on the grayscale image (falling back to
    the blue channel for colour input).  With an anchor, the band spans from
    4.8 to 2.5 median cell heights above the top of the cell row; its left
    edge is a fixed 22 % of the page width and its right edge is 13 median
    cell widths past the leftmost cell.  The left edge is deliberately not
    tied to the cell x-position because that position varies between scans
    and would clip the first letters of the name on some sheets.

    Without an anchor, a fixed page-fraction crop (rows 4-14 %, columns
    20-82 %) is returned instead.  Returns None when the crop is empty.
    """
    is_colour = img.ndim == 3
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if is_colour else img

    cells = detect_digit_cells(gray)
    if not cells and is_colour:
        cells = detect_digit_cells(img[:, :, 0])   # blue channel fallback

    if not cells:
        # No UPN anchor: use the typical location of the name box.
        page_h, page_w = img.shape[:2]
        top, bottom = int(page_h * 0.04), int(page_h * 0.14)
        left, right = int(page_w * 0.20), int(page_w * 0.82)
        if bottom > top and right > left:
            return img[top:bottom, left:right]
        return None

    page_h, page_w = gray.shape
    row_top = min(c.y for c in cells)
    first_cell_x = min(c.x for c in cells)
    cell_h = int(np.median([c.h for c in cells]))
    cell_w = int(np.median([c.w for c in cells]))

    left = max(0, int(page_w * 0.22))
    right = min(page_w, first_cell_x + int(13 * cell_w))
    top = max(0, row_top - int(4.8 * cell_h))
    bottom = max(0, row_top - int(2.5 * cell_h))
    if bottom <= top or right <= left:
        return None
    return img[top:bottom, left:right]


def _name_handwriting_crop(
    band: np.ndarray,
) -> Optional[Tuple[np.ndarray, Optional[np.ndarray]]]:
    """Isolate the handwritten name line and return it ready for OCR.

    *band* must be a 3-channel BGR image (the channels are indexed directly).

    Returns (primary_crop, adaptive_crop_or_None), or None when no usable
    handwriting is found (almost no ink, no row band, an empty strip, no
    column cluster, or a horizontal extent under 5 px).
    primary_crop  – flat-field + Otsu binarised, 1.5× upscaled.
    adaptive_crop – raw-gray adaptive-threshold binarised with the same
                    spatial bounds; None when generation fails.  Used as a
                    fallback in _ocr_name when the primary pipeline merges
                    adjacent letter strokes (e.g. produces 'Q' from 'RTH').

    Pipeline:
      1. Erase red-ink label pixels (the printed "Pupil's Name:" text).
      2. Flat-field illumination correction + Otsu binarise.
      3. Remove printed box rules (long horizontal elements).
      4. Horizontal projection → pick the handwriting row: the first band
         that is ≥ 25 rows tall with mean density ≥ 35 % of the densest
         band's, unless the densest band is more than 1.5× taller, in which
         case the densest band is used.
      5. Column analysis → find the left/right extent of the handwriting,
         skipping form borders and sparse label/noise columns.
      6. Upscale 1.5× bicubic + unsharp mask for crisper letter edges.
    """
    # Red mask: red channel exceeds both green and blue by more than 18.
    b = band[:, :, 0].astype(int)
    g = band[:, :, 1].astype(int)
    r = band[:, :, 2].astype(int)
    red = (r - g > 18) & (r - b > 18)
    gray = cv2.cvtColor(band, cv2.COLOR_BGR2GRAY).copy()
    gray[red] = 255

    hh, ww = gray.shape
    # Odd blur kernel of about a quarter of the shorter side (at least 15).
    k = max(15, (min(gray.shape) // 4) | 1)
    # Flat-field: divide by the blurred background to even out illumination.
    flat = cv2.divide(gray, cv2.GaussianBlur(gray, (k, k), 0), scale=255)
    bw = cv2.threshold(flat, 0, 255,
                       cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
    ink = cv2.bitwise_not(bw)   # ink = 255, background = 0

    # Remove printed horizontal rules (box lines).
    hk = cv2.getStructuringElement(cv2.MORPH_RECT, (max(20, ww // 3), 1))
    ink = cv2.subtract(ink, cv2.morphologyEx(ink, cv2.MORPH_OPEN, hk))

    # Per-row ink pixel counts; rows above 18 % of the peak are "active" and
    # consecutive active rows form candidate bands (start, end-exclusive).
    proj = (ink > 0).sum(axis=1).astype(float)
    if proj.max() < 5:
        return None
    active = proj > proj.max() * 0.18
    bands: List[Tuple[int, int]] = []
    start: Optional[int] = None
    for y in range(hh):
        if active[y] and start is None:
            start = y
        elif not active[y] and start is not None:
            bands.append((start, y))
            start = None
    if start is not None:
        bands.append((start, hh))
    if not bands:
        return None
    # Pick the topmost band that has a plausible density rather than the
    # densest one.  When the form includes both a "Pupil's Name" and a
    # "School Name" row in the crop, the school name is usually longer and
    # would win the density contest even though the pupil name sits above it.
    # Using a relative threshold (35 % of the densest band's average) ignores
    # sparse noise at the top of scanned sheets while still accepting a pupil-
    # name line that is somewhat shorter than the school-name line below it.
    _densest = max(bands, key=lambda b: proj[b[0]:b[1]].sum())
    _densest_h = _densest[1] - _densest[0]
    _min_avg = proj[_densest[0]:_densest[1]].mean() * 0.35
    # A genuine handwriting strip spans at least 25 rows at 300 DPI.
    # Shorter clusters are letter-top fragments or noise — skip them.
    # Falls back to the densest band when no band qualifies.
    _first_qualifying = next(
        (b for b in bands
         if (b[1] - b[0]) >= 25 and proj[b[0]:b[1]].mean() >= _min_avg),
        _densest,
    )
    _first_h = _first_qualifying[1] - _first_qualifying[0]
    # If the densest band is substantially taller (>1.5×) than the first
    # qualifying band it is the main content row (e.g. a scanned sheet where
    # a form element precedes the handwriting).  Otherwise take the first
    # qualifying band so that a pupil name (shorter) beats a school name
    # (longer/denser) that sits below it.
    if _densest_h > _first_h * 1.5:
        y1, y2 = _densest
    else:
        y1, y2 = _first_qualifying

    # Two strips with different margins:
    #  • strip_broad (12 px each side): used for density-based detection.
    #    With ≥ 12 px padding the strip height is large enough that genuine
    #    handwriting letter strokes (density 0.55–0.80) stay clearly below
    #    the 0.90 form-border threshold.
    #  • strip_tight (4 px top / 2 px bottom): used for the OCR crop.
    #    The tighter margins keep the printed "Pupil's Name" label rows
    #    (which sit above the handwriting band, separated by white space)
    #    out of the image fed to Tesseract.
    strip_broad = ink[max(0, y1 - 12):min(hh, y2 + 12), :]
    strip_tight = ink[max(0, y1 -  4):min(hh, y2 +  2), :].copy()

    # Per-column ink pixel counts for each strip.
    col_proj       = (strip_broad > 0).sum(axis=0)
    col_proj_tight = (strip_tight > 0).sum(axis=0)
    strip_h        = strip_broad.shape[0]
    strip_h_tight  = strip_tight.shape[0]
    if strip_h == 0 or strip_h_tight == 0:
        return None

    # Form-field borders are vertical lines spanning the full strip height
    # (density ≈ 1.0).  Letter strokes stay at 0.55–0.80 when strip_h ≥ 77.
    # Only the left 40 % of the band is searched for a border; scanning
    # starts just right of the rightmost border column found there.
    col_density = col_proj.astype(float) / strip_h
    search_end  = int(ww * 0.40)
    borders     = np.where(col_density[:search_end] > 0.90)[0]
    scan_from   = int(borders[-1]) + 1 if len(borders) > 0 else 0

    # Extend scan_from past any residual form-border columns whose density
    # exceeds 0.70 in the tight strip (catches thick border right-edges that
    # fall just below the 0.90 broad-strip threshold).
    while (scan_from < search_end
           and col_proj_tight[scan_from] > strip_h_tight * 0.70):
        scan_from += 1

    # Cluster analysis on the broad strip: runs of consecutive columns that
    # contain any ink.  cl_starts are inclusive, cl_ends exclusive.
    ink_mask = (col_proj > 0).astype(int)
    trans    = np.diff(np.concatenate([[0], ink_mask, [0]]))
    cl_starts = np.where(trans == 1)[0]
    cl_ends   = np.where(trans == -1)[0]
    if len(cl_starts) == 0:
        return None

    # Skip low-density clusters (printed "Pupil's Name" label characters,
    # density 0.10–0.30) and land on the first genuine handwriting cluster.
    # If none reaches 35 % column density, x1c stays at the first cluster.
    x1c = int(cl_starts[0])
    for idx in range(len(cl_starts)):
        s, e = int(cl_starts[idx]), int(cl_ends[idx])
        if e <= scan_from:
            continue
        s_eff = max(s, scan_from)
        if s_eff >= e:
            continue
        if col_proj[s_eff:e].max() >= strip_h * 0.35:
            x1c = s_eff
            break

    # Refine x1c with the tight strip: when there is only one broad cluster
    # (form border + noise + handwriting fused), the cluster analysis lands
    # x1c at scan_from (just past the form border), but the actual first ink
    # of handwriting may be many columns to the right.  Walk forward (at most
    # 250 columns) until we find a column with density ≥ 28 % that starts a
    # run of ≥ 7 consecutive columns staying at ≥ 10 % — that is reliably a
    # letter stroke rather than sparse noise.
    _tight_thr = strip_h_tight * 0.28
    _run_min   = 7
    for col in range(x1c, min(ww, x1c + 250)):
        if col_proj_tight[col] >= _tight_thr:
            run_end = col + 1
            while (run_end < ww
                   and col_proj_tight[run_end] >= strip_h_tight * 0.10):
                run_end += 1
            if run_end - col >= _run_min:
                x1c = col
                break

    # Mirror on the right using the tight strip, which is unaffected by the
    # small noise bands that sometimes sit just outside the handwriting zone
    # in the broad strip and push x2c all the way to the image edge.
    # The rightmost column with ≥ 5 % density becomes the right bound.
    x2c = x1c
    for col in range(ww - 1, x1c, -1):
        if col_proj_tight[col] >= strip_h_tight * 0.05:
            x2c = col
            break

    if x1c >= ww or x2c - x1c < 5:
        return None
    # crop_x0 ≥ scan_from so the form border stays outside the crop,
    # with up to 4 px of slack to avoid clipping the first letter edge.
    crop_x0 = max(scan_from, x1c - 4)
    # Invert back to dark ink on a white background for Tesseract.
    crop = cv2.bitwise_not(strip_tight[:, crop_x0:min(ww, x2c + 4)])

    # 1.5× upscale: large enough to sharpen letter detail without producing
    # images so wide that Tesseract's layout analyser mis-estimates resolution.
    upscaled = cv2.resize(crop, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)

    # Unsharp mask: sharpens letter edges so Tesseract discriminates
    # similar shapes (e.g. 'A'/'H', 'P'/'R') more reliably.
    blur = cv2.GaussianBlur(upscaled, (0, 0), 1.0)
    sharpened = cv2.addWeighted(upscaled, 1.5, blur, -0.5, 0)
    primary = np.clip(sharpened, 0, 255).astype(np.uint8)

    # Adaptive-threshold secondary crop (same spatial bounds).
    # Raw-gray adaptive binarisation preserves stroke separation in sheets
    # where flat-field + Otsu merges adjacent letter strokes into blobs that
    # Tesseract misreads (e.g. 'RTH' → 'Q').  A 2-pixel-tall removal kernel
    # catches thicker printed form lines that survive the 1-pixel kernel.
    # Any failure here is non-fatal: the primary crop is still returned.
    try:
        adapt_bw = cv2.adaptiveThreshold(
            gray, 255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 51, 10,
        )
        adapt_ink = cv2.bitwise_not(adapt_bw)
        hk2 = cv2.getStructuringElement(cv2.MORPH_RECT, (max(20, ww // 3), 2))
        adapt_ink = cv2.subtract(
            adapt_ink, cv2.morphologyEx(adapt_ink, cv2.MORPH_OPEN, hk2)
        )
        adapt_strip = adapt_ink[max(0, y1 - 4):min(hh, y2 + 2), :]
        adapt_raw = cv2.bitwise_not(adapt_strip[:, crop_x0:min(ww, x2c + 4)])
        adapt_up = cv2.resize(
            adapt_raw, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC
        )
        adapt_blur = cv2.GaussianBlur(adapt_up, (0, 0), 1.0)
        adaptive = np.clip(
            cv2.addWeighted(adapt_up, 1.5, adapt_blur, -0.5, 0), 0, 255
        ).astype(np.uint8)
    except Exception:
        adaptive = None

    return primary, adaptive


def _ocr_single_crop(crop: np.ndarray) -> Tuple[str, float, float, str, int]:
    """OCR *crop* under several Tesseract page-segmentation modes.

    The crop is padded with a 30 px white frame and read with PSM 7, 6, 11,
    8 and 13 using the capital-letter whitelist, then once more with PSM 7
    and no whitelist.  Every read is reduced to upper-case letter tokens
    (printed label words and tokens without a valid confidence are dropped)
    joined left to right.

    Returns (text, score, conf, wl_text, wl_chars):
      text / score / conf – the highest-scoring read, its score and its mean
                            token confidence ("" / -1.0 / 0.0 if no read).
      wl_text / wl_chars  – the longest whitelist read and its letter count.
    """
    import pytesseract

    bordered = cv2.copyMakeBorder(crop, 30, 30, 30, 30,
                                  cv2.BORDER_CONSTANT, value=255)
    configs = [f"--psm {mode} {_NAME_WHITELIST}" for mode in (7, 6, 11, 8, 13)]
    configs.append("--psm 7")

    top_text, top_score, top_conf = "", -1.0, 0.0
    longest_wl, longest_wl_len = "", 0

    for config in configs:
        result = pytesseract.image_to_data(
            bordered, config=config, output_type=pytesseract.Output.DICT)

        words: List[Tuple[int, str]] = []
        word_confs: List[float] = []
        for idx, raw in enumerate(result["text"]):
            letters = re.sub(r"[^A-Za-z]", "", raw).upper()
            try:
                confidence = float(result["conf"][idx])
            except ValueError:
                confidence = -1.0
            if letters and confidence >= 0 and letters not in _NAME_LABEL_WORDS:
                words.append((result["left"][idx], letters))
                word_confs.append(confidence)
        if not words:
            continue

        # Reading order = ascending x position of each token.
        text = " ".join(word for _, word in sorted(words))
        avg_conf = float(np.mean(word_confs))
        n_letters = len(text.replace(" ", ""))

        # Per-character bonus beyond 3 letters so that a long low-confidence
        # read (e.g. "SARTHAK VETA" at conf=1.5) outranks a single-character
        # high-confidence fragment (e.g. "S" at conf=44).
        score = avg_conf * min(n_letters, 12) + max(0, n_letters - 3) * 4
        if score > top_score:
            top_text, top_score, top_conf = text, score, avg_conf

        if _NAME_WHITELIST in config and n_letters > longest_wl_len:
            longest_wl, longest_wl_len = text, n_letters

    return top_text, top_score, top_conf, longest_wl, longest_wl_len


def _ocr_name(
    crops: Union[np.ndarray, Tuple[np.ndarray, Optional[np.ndarray]]],
) -> Tuple[str, float]:
    """OCR an isolated handwritten name; return (text, confidence 0-100).

    *crops* is either a bare primary crop or the (primary, adaptive-or-None)
    pair produced by _name_handwriting_crop.

    The primary crop's best-scoring read is replaced by its longest
    whitelist read in several situations (low confidence, stray short
    fragments, very short reads, suspect merged-stroke patterns).  When an
    adaptive crop is present, its read wins if it is at least 8 letters long
    and either recovers more than 12 % extra letters or avoids a suspect
    pattern found in the primary read.  The confidence returned is always
    that of the best-scoring read of whichever crop was chosen.
    """
    def letter_count(text: str) -> int:
        return len(text.replace(" ", ""))

    primary, adaptive = crops if isinstance(crops, tuple) else (crops, None)

    best, _, best_conf, wl_read, wl_len = _ocr_single_crop(primary)

    if wl_len >= 8:
        if best_conf < 15:
            # Low-confidence primary: trust the long whitelist read instead.
            best = wl_read
        else:
            # Short-fragment override: a token of ≤ 2 letters usually means
            # bad word segmentation (e.g. "DHARAMPA L").  Only applies when
            # the read already has a plausible length (≥ 5 letters) so that
            # genuinely short reads are left alone.
            n_best = letter_count(best)
            has_fragment = any(len(tok) <= 2 for tok in best.split())
            if n_best >= 5 and wl_len >= n_best and has_fragment:
                best = wl_read

    # Very-short-read override: the best-scoring read has ≤ 3 letters while
    # a whitelist pass captured a substantially longer one.
    if letter_count(best) <= 3 and wl_len >= 5:
        best = wl_read

    # Suspect-pattern override: the read contains a merged-stroke artifact
    # and the whitelist read is both longer and free of such patterns.
    if (wl_read
            and wl_len > letter_count(best)
            and _NAME_SUSPECT_RE.search(best)
            and not _NAME_SUSPECT_RE.search(wl_read)):
        best = wl_read

    if adaptive is None:
        return best, best_conf

    alt, _, alt_conf, alt_wl, alt_wl_len = _ocr_single_crop(adaptive)

    # The adaptive image is noisier, so its whitelist read is preferred more
    # readily: also whenever it simply holds more letters than the
    # best-scoring read.
    if (alt_conf < 15 and alt_wl_len >= 8) or alt_wl_len > letter_count(alt):
        alt = alt_wl

    n_alt = letter_count(alt)
    n_best = letter_count(best)
    recovers_more = n_alt > max(n_best, 7) * 1.12
    avoids_suspect = (_NAME_SUSPECT_RE.search(best)
                      and not _NAME_SUSPECT_RE.search(alt))
    if n_alt >= 8 and (recovers_more or avoids_suspect):
        return alt, alt_conf

    return best, best_conf


def extract_name(file_path: str, dpi: int = NAME_RENDER_DPI) -> Tuple[str, float]:
    """Read the handwritten pupil name from a sheet's first page.

    Returns (name, confidence 0-100).  This is best-effort: ("", 0.0) comes
    back when Tesseract is unavailable, when the page, name band or
    handwriting crop cannot be obtained, or when anything raises.  An empty
    name or a low confidence should be treated as an unreliable read.  The
    page is rendered at no less than NAME_RENDER_DPI.
    """
    nothing = ("", 0.0)
    if not _ensure_tesseract():
        return nothing
    try:
        page = _get_sheet_page1(file_path, max(dpi, NAME_RENDER_DPI))
        if page is None:
            return nothing
        band = _name_band(page)
        if band is None:
            return nothing
        crops = _name_handwriting_crop(band)
        if crops is None:
            return nothing
        name, conf = _ocr_name(crops)

        # Low-confidence retry on crops enlarged a further 3× (4.5× total):
        # strokes that merge into wrong glyphs at 1.5× tend to separate at
        # that size.  Only runs when conf < 5.0, so confident reads are
        # never touched.  The first retry that is at least as long as the
        # current name and free of suspect patterns wins.
        if conf < 5.0 and isinstance(crops, tuple):
            primary_crop, adaptive_crop = crops
            target_len = len(name.replace(" ", ""))
            for candidate in (primary_crop, adaptive_crop):
                if candidate is None:
                    continue
                try:
                    enlarged = cv2.resize(candidate, None, fx=3.0, fy=3.0,
                                          interpolation=cv2.INTER_CUBIC)
                    read, _, read_conf, wl_read, wl_len = _ocr_single_crop(enlarged)
                    if wl_len < len(read.replace(" ", "")):
                        chosen = read
                    else:
                        chosen = wl_read
                    if (len(chosen.replace(" ", "")) >= target_len
                            and not _NAME_SUSPECT_RE.search(chosen)):
                        return chosen, read_conf
                except Exception:
                    # A failed retry must not discard the first read.
                    continue

        return name, conf
    except Exception:
        return nothing


# ===========================================================================
#  PDF -> image conversion  (pymupdf first, pdf2image fallback)
# ===========================================================================
def pdf_to_images(pdf_path: str, dpi: int = 200,
                  page_num: Optional[int] = None) -> List[np.ndarray]:
    """Render a PDF to a list of BGR images, one per page.

    PyMuPDF (``fitz``) is used when it can be imported; otherwise pdf2image
    is tried.

    Args:
        pdf_path: path of the PDF to render.
        dpi: render resolution.
        page_num: 1-based page to render; None renders every page.

    Returns:
        The rendered pages as BGR arrays, in page order.

    Raises:
        RuntimeError: neither PyMuPDF nor pdf2image is installed.
        Any error raised by the chosen backend while opening or rendering
        the file is propagated unchanged.
    """
    try:
        import fitz  # pymupdf
    except ImportError:
        fitz = None

    if fitz is not None:
        doc = fitz.open(pdf_path)
        try:
            page_indexes = (range(doc.page_count) if page_num is None
                            else [page_num - 1])
            zoom = fitz.Matrix(dpi / 72.0, dpi / 72.0)   # PDF units are 72/inch
            images = []
            for index in page_indexes:
                pix = doc[index].get_pixmap(matrix=zoom, alpha=False)
                arr = np.frombuffer(pix.samples, dtype=np.uint8).reshape(
                    pix.height, pix.width, pix.n)
                conversion = (cv2.COLOR_RGBA2BGR if pix.n == 4
                              else cv2.COLOR_RGB2BGR)
                images.append(cv2.cvtColor(arr, conversion))
            return images
        finally:
            # Release the document even when rendering a page fails.
            doc.close()

    try:
        from pdf2image import convert_from_path  # type: ignore
    except ImportError:
        convert_from_path = None

    if convert_from_path is not None:
        kw: dict = {"dpi": dpi}
        if page_num is not None:
            kw["first_page"] = page_num
            kw["last_page"] = page_num
        return [cv2.cvtColor(np.array(im), cv2.COLOR_RGB2BGR)
                for im in convert_from_path(pdf_path, **kw)]

    raise RuntimeError(
        "PDF support requires either 'pymupdf' or 'pdf2image'.\n"
        "  pip install pymupdf"
    )


# ===========================================================================
#  Geometry helpers
# ===========================================================================
def _cluster_1d(vals: list, tol: float) -> List[list]:
    """Group 1-D values into clusters of neighbours no more than *tol* apart.

    The values are sorted first; a value joins the current cluster when it
    lies within *tol* of that cluster's last (largest) member, otherwise it
    starts a new cluster.  Clusters can therefore span more than *tol* in
    total through chaining.

    Returns the clusters in ascending order; an empty input gives ``[]``.
    """
    ordered = sorted(vals)
    if not ordered:
        return []
    groups: List[list] = [[ordered[0]]]
    for value in ordered[1:]:
        if value - groups[-1][-1] <= tol:
            groups[-1].append(value)
        else:
            groups.append([value])
    return groups


def _tightest_window(rows: list, k: int) -> list:
    """Return the *k* consecutive entries of *rows* with the smallest span.

    The span of a window is its last entry minus its first, so *rows* is
    expected to be in ascending order.  When *rows* has at most *k* entries
    it is returned unchanged; on ties the earliest window wins.
    """
    if len(rows) <= k:
        return rows
    windows = [rows[start:start + k] for start in range(len(rows) - k + 1)]
    return min(windows, key=lambda window: window[-1] - window[0])


def _longest_even_run(centers: list, spacing: float, tol: float) -> list:
    """Return the longest run of sorted *centers* spaced about *spacing* apart.

    Consecutive values belong to the same run while the gap between them is
    within *tol* of *spacing*; any other gap starts a new run.  When several
    runs share the maximum length, the earliest one is returned.  An empty
    input gives ``[]``.
    """
    ordered = sorted(centers)
    if not ordered:
        return []
    best: list = []
    run: list = [ordered[0]]
    for value in ordered[1:]:
        if abs((value - run[-1]) - spacing) <= tol:
            run.append(value)
            continue
        if len(run) > len(best):
            best = run
        run = [value]
    # The final run is never closed inside the loop, so compare it here.
    return run if len(run) > len(best) else best


# ===========================================================================
#  Grid detection  (legacy -- retained for compatibility)
# ===========================================================================
def detect_grid(gray: np.ndarray,
                force_options: Optional[int] = None,
                order: str = "rows") -> Tuple[list, list, float, float]:
    """Legacy grid detector: find answer cells and number the questions.

    Cell candidates are contours of a fixed-threshold (200) inverted binary
    image whose aspect ratio and page-relative size fall inside the module
    constants ASPECT_MIN/MAX, CELL_W_MIN/MAX and CELL_H_MIN/MAX.  Their
    vertical centres are clustered into option rows, the rows are split into
    question blocks at unusually large gaps, and each block's columns become
    one question each.

    Args:
        gray: single-channel page image.
        force_options: number of option rows per question; when falsy it is
            taken as the most common block size.
        order: "cols" numbers questions column-first across blocks; any
            other value numbers them block by block, left to right.

    Returns:
        (questions, cells, mw, mh) where questions is a list of
        (number, column_x, option_row_centres) tuples numbered from 1,
        cells is the list of surviving (x, y, w, h) boxes, and mw / mh are
        the median candidate width and height.

    Raises:
        RuntimeError: fewer than 10 candidate or size-filtered cells, fewer
            than two option rows, or no block with enough rows.
    """
    h, w = gray.shape[:2]

    # Dark pixels (<= 200) become foreground for contour extraction.
    _, bw = cv2.threshold(gray, 200, 255, cv2.THRESH_BINARY_INV)
    cnts, _ = cv2.findContours(bw, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    # Keep bounding boxes whose shape and page-relative size look like a cell.
    raw: list = []
    for c in cnts:
        bx, by, bw_, bh = cv2.boundingRect(c)
        if bh == 0:
            continue
        ar = bw_ / float(bh)
        if (ASPECT_MIN <= ar <= ASPECT_MAX
                and CELL_W_MIN * w <= bw_ <= CELL_W_MAX * w
                and CELL_H_MIN * h <= bh <= CELL_H_MAX * h):
            raw.append((bx, by, bw_, bh))

    if len(raw) < 10:
        raise RuntimeError(
            f"Only {len(raw)} candidate cell(s) found -- cannot build grid.")

    # Second filter: discard boxes far from the median candidate size.
    mw = statistics.median(b[2] for b in raw)
    mh = statistics.median(b[3] for b in raw)
    cells = [b for b in raw
             if 0.55 * mw <= b[2] <= 1.6 * mw and 0.55 * mh <= b[3] <= 1.8 * mh]
    if len(cells) < 10:
        raise RuntimeError("Too few cells survive the size filter.")

    cy_all = [b[1] + b[3] / 2 for b in cells]

    # Option rows = clusters of vertical cell centres.
    row_groups = _cluster_1d(cy_all, mh * 0.7)
    row_centers = sorted(statistics.mean(g) for g in row_groups)
    if len(row_centers) < 2:
        raise RuntimeError("Could not detect multiple option rows.")

    # A gap larger than 1.8x the median row gap starts a new question block.
    rgaps = [row_centers[i + 1] - row_centers[i] for i in range(len(row_centers) - 1)]
    typ_gap = statistics.median(rgaps)
    blocks: List[list] = [[row_centers[0]]]
    for i, g in enumerate(rgaps):
        if g > typ_gap * 1.8:
            blocks.append([row_centers[i + 1]])
        else:
            blocks[-1].append(row_centers[i + 1])

    # Options per question: forced, else the most common block size
    # (single-row blocks are ignored unless there is nothing else).
    real_sizes = [len(b) for b in blocks if len(b) >= 2]
    options = force_options or statistics.mode(real_sizes or [len(b) for b in blocks])

    # Trim over-long blocks to their tightest run of `options` rows and drop
    # blocks with too few rows.
    qblocks = [_tightest_window(b, options) for b in blocks if len(b) >= options]
    if not qblocks:
        raise RuntimeError(
            f"No complete question blocks found (options={options}).")

    def block_columns(rows: list) -> list:
        # Column centres for one block: x-centres of the cells lying on the
        # block's rows, clustered and reduced to the longest evenly spaced run.
        xs = [b[0] + b[2] / 2 for b in cells
              if any(abs((b[1] + b[3] / 2) - ry) <= mh * 0.8 for ry in rows)]
        if not xs:
            return []
        groups = _cluster_1d(xs, mw * 1.2)
        centers = sorted(statistics.mean(g) for g in groups)
        if len(centers) < 2:
            return centers
        spacing = statistics.median(
            centers[i + 1] - centers[i] for i in range(len(centers) - 1))
        main = _longest_even_run(centers, spacing, spacing * 0.35)
        # Retry with a looser tolerance when the run covers under half the columns.
        if len(main) < max(2, len(centers) // 2):
            main = _longest_even_run(centers, spacing, spacing * 0.55)
        return main

    block_cols = [block_columns(blk) for blk in qblocks]

    # One entry per (block, column) pair = one question.
    questions: list = []
    entries: list = []
    for bi, (blk, cols) in enumerate(zip(qblocks, block_cols)):
        for colx in cols:
            entries.append((bi, colx, blk))

    if order == "cols":
        # Column-major: bucket x by median cell width, then by block index.
        entries.sort(key=lambda e: (round(e[1] / max(mw, 1)), e[0]))
    else:
        entries.sort(key=lambda e: (e[0], e[1]))

    for n, (_, colx, blk) in enumerate(entries, start=1):
        questions.append((n, colx, blk))

    return questions, cells, mw, mh


# ===========================================================================
#  Legacy ink-fill decision  (kept for reference)
# ===========================================================================
def decide(q_num: int, fills: list) -> OmrAnswer:
    """Legacy rule: choose the option whose ink fill stands out from the rest.

    Every fill is converted to its excess over the question's median fill.
    The option with the largest excess is selected when that excess reaches
    EXCESS_MIN and beats the runner-up by at least EXCESS_MARGIN.  If the
    margin is not met but the runner-up also reaches EXCESS_MIN, the top
    option is returned with the multi-mark flag set.  In every other case the
    question is reported with no choice.

    The excess values stored on the result are rounded to 3 decimals.
    """
    baseline = statistics.median(fills)
    excess = [value - baseline for value in fills]
    rounded = [round(value, 3) for value in excess]

    # Option indices ranked by descending excess (stable for ties).
    ranking = sorted(range(len(excess)), key=excess.__getitem__, reverse=True)
    top = ranking[0]
    runner_up = ranking[1] if len(ranking) > 1 else top

    if excess[top] < EXCESS_MIN:
        return OmrAnswer(q_num, None, rounded, False)
    if excess[top] >= excess[runner_up] + EXCESS_MARGIN:
        return OmrAnswer(q_num, LETTERS[top], rounded, False)
    if excess[runner_up] >= EXCESS_MIN:
        return OmrAnswer(q_num, LETTERS[top], rounded, True)
    return OmrAnswer(q_num, None, rounded, False)


# ===========================================================================
#  New OMR detection engine
#  (adaptive threshold + contour grouping + local-background ink scoring)
# ===========================================================================

def prepare_threshold(gray: np.ndarray) -> np.ndarray:
    """Binarise a grayscale page with an inverted adaptive mean threshold.

    The image is lightly blurred (3x3 Gaussian) first; the threshold is then
    computed per pixel from a 21x21 neighbourhood mean minus 10, so the
    result adapts to local brightness instead of using one global cut-off.
    """
    neighbourhood = 21
    mean_offset = 10
    smoothed = cv2.GaussianBlur(gray, (3, 3), 0)
    binary = cv2.adaptiveThreshold(
        smoothed, 255,
        cv2.ADAPTIVE_THRESH_MEAN_C,
        cv2.THRESH_BINARY_INV,
        neighbourhood, mean_offset,
    )
    return binary


def dedupe_boxes(boxes: List[Box]) -> List[Box]:
    """Drop boxes whose centre lies within 5 px (in both x and y) of a larger one.

    Boxes are visited from largest to smallest area, so of any cluster of
    near-coincident boxes only the largest survives.  The result is ordered
    by descending area.
    """
    kept: List[Box] = []
    largest_first = sorted(boxes, key=lambda b: b.area, reverse=True)
    for candidate in largest_first:
        for existing in kept:
            if (abs(candidate.cx - existing.cx) <= 5
                    and abs(candidate.cy - existing.cy) <= 5):
                break  # a larger box already occupies this spot
        else:
            kept.append(candidate)
    return kept


def detect_option_boxes(gray: np.ndarray) -> List[Box]:
    """Find candidate answer-option rectangles on a grayscale page.

    Contours of the adaptively thresholded page are reduced to bounding
    boxes and kept only when they are:
      * within page-relative width/height limits,
      * wider than tall (aspect ratio 1.45 .. 6.00), and
      * reasonably solid (contour area >= 18 % of the bounding box).
    Near-duplicate detections are removed before returning.
    """
    binary = prepare_threshold(gray)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    page_h, page_w = gray.shape

    # Size limits scale with the page but never drop below fixed floors.
    w_lo = max(10, int(page_w * 0.010))
    w_hi = max(30, int(page_w * 0.045))
    h_lo = max(5,  int(page_h * 0.003))
    h_hi = max(14, int(page_h * 0.018))

    found: List[Box] = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if h == 0:
            continue
        if w < w_lo or w > w_hi or h < h_lo or h > h_hi:
            continue
        ratio = w / float(h)
        if ratio < 1.45 or ratio > 6.00:
            continue
        # Most expensive test last: reject thin/hollow squiggles.
        if cv2.contourArea(contour) < w * h * 0.18:
            continue
        found.append(Box(x, y, w, h))

    return dedupe_boxes(found)


def detect_digit_cells(gray: np.ndarray) -> List[Box]:
    """Locate the digit cells holding the Unique Pupil Number.

    Steps:
      1. threshold the page and collect roughly square contour boxes whose
         centre lies above PUPIL_NUMBER_Y_MAX_RATIO of the page height;
      2. de-duplicate, then keep only cells within 45 % of the median size;
      3. cluster the cells into horizontal rows and take the fullest row;
      4. split that row wherever the gap between neighbours is unusually
         wide, then re-join groups separated by at most 2.5 cell widths.

    Returns the leftmost group with at least PUPIL_NUMBER_MIN_CELLS cells,
    sorted left-to-right and truncated to PUPIL_NUMBER_CELLS (later groups
    are the School Number / Date of Birth cells).  Returns [] when no
    suitable group is found.
    """
    page_h, page_w = gray.shape
    y_limit = int(page_h * PUPIL_NUMBER_Y_MAX_RATIO)

    binary = prepare_threshold(gray)
    contours, _ = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    side_lo = max(18, int(page_w * 0.018))
    side_hi = max(70, int(page_w * 0.065))

    raw: List[Box] = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        if h == 0 or y + h / 2 > y_limit:
            continue
        if not (side_lo <= w <= side_hi) or not (side_lo <= h <= side_hi):
            continue
        if not 0.50 <= w / float(h) <= 2.0:
            continue
        box_area = w * h
        if box_area > 0 and cv2.contourArea(contour) < box_area * 0.25:
            continue
        raw.append(Box(x, y, w, h))

    cells = dedupe_boxes(raw)
    if len(cells) < PUPIL_NUMBER_MIN_CELLS:
        return []

    # Size-consistency filter: keep cells close to the median dimensions.
    typical_w = float(np.median([c.w for c in cells]))
    typical_h = float(np.median([c.h for c in cells]))
    tolerance = 0.45
    uniform = [
        c for c in cells
        if abs(c.w - typical_w) <= typical_w * tolerance
        and abs(c.h - typical_h) <= typical_h * tolerance
    ]
    if len(uniform) < PUPIL_NUMBER_MIN_CELLS:
        return []

    # Cluster into horizontal rows; a cell joins the first row whose running
    # mean centre-y is within 40 % of the median cell height.
    y_tol = float(np.median([c.h for c in uniform])) * 0.4
    rows: List[List[Box]] = []
    for cell in sorted(uniform, key=lambda b: b.cy):
        for members in rows:
            mean_cy = sum(b.cy for b in members) / len(members)
            if abs(cell.cy - mean_cy) <= y_tol:
                members.append(cell)
                break
        else:
            rows.append([cell])

    if not rows:
        return []

    fullest = max(rows, key=len)
    if len(fullest) < PUPIL_NUMBER_MIN_CELLS:
        return []

    # Split the row into contiguous groups at unusually wide x-gaps.
    ordered = sorted(fullest, key=lambda b: b.x)
    gaps = [right.x - (left.x + left.w)
            for left, right in zip(ordered, ordered[1:])]
    if not gaps:
        return ordered

    split_at = max(float(np.median(gaps)) * 3.0, 10.0)
    groups: List[List[Box]] = [[ordered[0]]]
    for gap, cell in zip(gaps, ordered[1:]):
        if gap > split_at:
            groups.append([cell])
        else:
            groups[-1].append(cell)

    # Stitch adjacent groups whose separation is <= 2.5 cell widths.  Such a
    # small gap means a single cell border went undetected (e.g. covered by a
    # pen stroke), not a boundary between the UPN and the school/DOB fields;
    # real section boundaries are >= 3x the cell width.
    stitch_limit = float(np.median([b.w for b in ordered])) * 2.5
    stitched: List[List[Box]] = [list(groups[0])]
    for grp in groups[1:]:
        prev_right = max(b.x + b.w for b in stitched[-1])
        grp_left = min(b.x for b in grp)
        if grp_left - prev_right <= stitch_limit:
            stitched[-1].extend(grp)
        else:
            stitched.append(list(grp))

    for grp in stitched:
        if len(grp) >= PUPIL_NUMBER_MIN_CELLS:
            # Cap at PUPIL_NUMBER_CELLS so School Number / Date of Birth
            # cells stitched in across a small section gap are not returned.
            return sorted(grp, key=lambda b: b.x)[:PUPIL_NUMBER_CELLS]

    return []


# ---------------------------------------------------------------------------
#  Question-card grouping
# ---------------------------------------------------------------------------

def card_center_x(card: List[Box]) -> float:
    """Return the mean x-centre of the boxes in *card* (card must be non-empty)."""
    centres = [box.cx for box in card]
    return sum(centres) / len(centres)


def card_center_y(card: List[Box]) -> float:
    """Return the mean y-centre of the boxes in *card* (card must be non-empty)."""
    centres = [box.cy for box in card]
    return sum(centres) / len(centres)


def maybe_add_card(cards: List[List[Box]], boxes: List[Box]) -> None:
    """Append *boxes* (sorted top-to-bottom) to *cards* if its size is plausible.

    A run of boxes only counts as a question card when it holds between
    MIN_OPTIONS and MAX_OPTIONS boxes (inclusive); otherwise it is ignored.
    """
    count = len(boxes)
    if count < MIN_OPTIONS or count > MAX_OPTIONS:
        return
    cards.append(sorted(boxes, key=lambda b: b.cy))


def group_question_cards(option_boxes: List[Box],
                         page_shape: Tuple[int, int]) -> List[List[Box]]:
    """Turn loose option boxes into per-question cards in reading order.

    Boxes in the ignored left/top/bottom page margins are discarded.  The
    rest are binned into columns by x-centre, each column is cut into
    vertical runs wherever consecutive boxes are too far apart, and each run
    of acceptable size becomes a card.  Cards with clearly fewer options
    than the dominant count are dropped, and the survivors are returned via
    sort_cards_reading_order().
    """
    page_h, page_w = page_shape
    x_min = page_w * LEFT_IGNORE_RATIO
    y_min = page_h * TOP_IGNORE_RATIO
    y_max = page_h * BOTTOM_IGNORE_RATIO
    usable = [b for b in option_boxes
              if b.cx > x_min and b.cy > y_min and b.cy < y_max]
    if not usable:
        return []

    typical_w = float(np.median([b.w for b in usable]))
    typical_h = float(np.median([b.h for b in usable]))
    x_tol = max(14.0, typical_w * 0.70)
    run_gap = max(60.0, typical_h * 4.00)

    # Columns are [running mean x-centre, member boxes]; a box joins the
    # first column within x_tol, otherwise it opens a new one.
    columns: list = []
    for box in sorted(usable, key=lambda b: b.cx):
        for column in columns:
            if abs(box.cx - float(column[0])) <= x_tol:
                members = column[1]
                members.append(box)
                column[0] = sum(b.cx for b in members) / len(members)
                break
        else:
            columns.append([box.cx, [box]])

    # Cut every column into contiguous vertical runs; each run is one card.
    cards: List[List[Box]] = []
    for _, members in columns:
        run: List[Box] = []
        for box in sorted(members, key=lambda b: b.cy):
            if run and box.cy - run[-1].cy > run_gap:
                maybe_add_card(cards, run)
                run = []
            run.append(box)
        maybe_add_card(cards, run)

    # Drop outlier cards whose option count is well below the dominant one
    # (e.g. text-input boxes that resemble answer boxes but occur in
    # isolated pairs rather than a full column).
    if cards:
        from collections import Counter as _Counter
        dominant = _Counter(len(c) for c in cards).most_common(1)[0][0]
        floor = max(MIN_OPTIONS, dominant - 1)
        cards = [c for c in cards if len(c) >= floor]

    return sort_cards_reading_order(cards)


def sort_cards_reading_order(cards: List[List[Box]]) -> List[List[Box]]:
    """Order cards top-to-bottom by row, left-to-right within each row.

    Cards are clustered into rows by their mean y-centre (tolerance: the
    larger of 45 px and 80 % of the median card height).  Each row is sorted
    by x-centre and passed through normalize_row_cards() before being
    appended to the result.
    """
    if not cards:
        return []

    spans = [max(b.cy for b in c) - min(b.cy for b in c) for c in cards]
    row_tol = max(45.0, float(np.median(spans)) * 0.80)

    # Each band is [running mean y-centre, member cards].
    bands: list = []
    for card in sorted(cards, key=card_center_y):
        cy = card_center_y(card)
        for band in bands:
            if abs(cy - float(band[0])) <= row_tol:
                members = band[1]
                members.append(card)
                band[0] = sum(card_center_y(c) for c in members) / len(members)
                break
        else:
            bands.append([cy, [card]])

    bands.sort(key=lambda band: float(band[0]))

    result: List[List[Box]] = []
    for _, members in bands:
        left_to_right = sorted(members, key=card_center_x)
        result.extend(normalize_row_cards(left_to_right))
    return result


def normalize_row_cards(cards: List[List[Box]]) -> List[List[Box]]:
    """Give every card in a row the same number of uniformly sized boxes.

    The target option count is the most common card size in the row (ties
    go to the larger size).  If the target is below MIN_OPTIONS the input is
    returned untouched.  Otherwise each card is rebuilt from boxes of the
    row's median width/height, centred on the original box centres:

    * complete cards are simply re-boxed;
    * short cards receive synthetic placeholder boxes:
        - an interior hole (a gap of at least 1.5x the typical step) is
          filled on the option pitch inside that gap;
        - otherwise an edge option is missing.  The card's top is compared
          with the top of the row's complete cards: if it starts more than
          half a step lower, the placeholder is prepended, else appended;
    * over-full cards keep the `target` boxes whose vertical spacing is the
      most even (minimum variance of consecutive gaps).

    Args:
        cards: the cards of one row, each a list of option boxes.

    Returns:
        A new list of cards, each sorted top-to-bottom with exactly
        `target` boxes (or `cards` itself when no normalisation applies).
    """
    if not cards:
        return cards

    counts: Dict[int, int] = {}
    for card in cards:
        counts[len(card)] = counts.get(len(card), 0) + 1

    target = max(counts, key=lambda n: (counts[n], n))
    if target < MIN_OPTIONS:
        return cards

    all_boxes = [b for card in cards for b in card]
    median_w = int(round(float(np.median([b.w for b in all_boxes]))))
    median_h = int(round(float(np.median([b.h for b in all_boxes]))))

    def uniform_box(cx: float, cy: float) -> Box:
        """Median-sized box centred on (cx, cy)."""
        return Box(int(round(cx - median_w / 2)),
                   int(round(cy - median_h / 2)),
                   median_w, median_h)

    # `target` is a key of `counts`, so at least one card is complete; those
    # cards supply the option pitch and the reference top of the row.
    complete = [sorted(c, key=lambda b: b.cy) for c in cards if len(c) == target]
    steps = [lower.cy - upper.cy
             for sc in complete
             for upper, lower in zip(sc, sc[1:])]
    typical_step = float(np.median(steps)) if steps else median_h * 2.0
    tops = [sc[0].cy for sc in complete if sc]
    ref_top = float(np.median(tops)) if tops else None

    # A gap must exceed the pitch by this factor to count as a missing box.
    # (A strict "gap >= pitch" test misfires on perfectly even spacing and
    # on +/-1 px detection jitter.)
    hole_ratio = 1.5

    normalized: List[List[Box]] = []
    for card in cards:
        sc = sorted(card, key=lambda b: b.cy)

        if len(sc) == target:
            chosen = sc

        elif len(sc) < target:
            filled = list(sc)
            while len(filled) < target:
                fs = sorted(filled, key=lambda b: b.cy)
                card_x = float(np.median([b.cx for b in filled]))

                gap_idx, gap = 0, 0.0
                for i in range(1, len(fs)):
                    g = fs[i].cy - fs[i - 1].cy
                    if g > gap:
                        gap, gap_idx = g, i

                if (gap_idx and typical_step > 0
                        and gap >= typical_step * hole_ratio):
                    # Interior hole: place one box on the pitch.  For a
                    # single missing option this is the gap midpoint; wider
                    # holes are filled one pitch at a time.
                    slots = max(2, int(gap / typical_step + 0.5))
                    new_cy = fs[gap_idx - 1].cy + gap / slots
                elif (ref_top is None
                        or fs[0].cy - ref_top > typical_step * 0.5):
                    # Card starts lower than its neighbours: top is missing.
                    new_cy = fs[0].cy - typical_step
                else:
                    # Top is aligned with the row: bottom is missing.
                    new_cy = fs[-1].cy + typical_step
                filled.append(uniform_box(card_x, new_cy))
            chosen = sorted(filled, key=lambda b: b.cy)[:target]

        else:
            from itertools import combinations
            best_combo = None
            best_var = float("inf")
            for combo in combinations(sc, target):
                ys = [b.cy for b in sorted(combo, key=lambda b: b.cy)]
                diffs = [ys[i + 1] - ys[i] for i in range(len(ys) - 1)]
                var = float(np.var(diffs)) if diffs else 0.0
                if var < best_var:
                    best_var, best_combo = var, combo
            chosen = sorted(best_combo, key=lambda b: b.cy)  # type: ignore[arg-type]

        normalized.append([uniform_box(b.cx, b.cy) for b in chosen])

    return normalized


# ---------------------------------------------------------------------------
#  Image enhancement (make phone photos look like clean scans)
# ---------------------------------------------------------------------------

def enhance_for_scoring(gray: np.ndarray) -> np.ndarray:
    """Flatten illumination and stretch contrast for mark scoring.

    Intended for dull, unevenly lit phone photos, where genuine pen marks
    are too faint to clear the mark threshold.  Two steps:

    1. Flat-field correction: a heavily blurred copy of the page serves as
       the illumination estimate and the page is divided by it (scaled to
       255), which evens out shadows.
    2. Robust contrast stretch: the 2nd..99th percentile range is mapped to
       0..255 and clipped.

    A non-2-D input is converted from BGR to grayscale first.  When the
    percentile range is narrower than one grey level the flat-fielded image
    is returned without stretching.  The output is used for scoring only.
    """
    if gray.ndim != 2:
        gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)

    # Odd kernel of roughly 1/8 of the short side, never smaller than 31.
    short_side = min(gray.shape)
    kernel = max(31, (short_side // 8) | 1)
    illumination = cv2.GaussianBlur(gray, (kernel, kernel), 0)
    flat = cv2.divide(gray, illumination, scale=255)

    # Percentiles keep a few very dark or very bright pixels from setting
    # the stretch range.
    lo, hi = (float(v) for v in np.percentile(flat, (2, 99)))
    if hi - lo < 1.0:
        return flat

    gain = 255.0 / (hi - lo)
    stretched = (flat.astype(np.float32) - lo) * gain
    return np.clip(stretched, 0, 255).astype(np.uint8)


# ---------------------------------------------------------------------------
#  Robust mark scoring
# ---------------------------------------------------------------------------

def _inner_roi(gray: np.ndarray, box: Box,
               shrink_x: float = 0.25, shrink_y: float = 0.30) -> np.ndarray:
    """Extract the interior of a box, avoiding the printed border lines.

    The box is shrunk on every side by `shrink_x` / `shrink_y` of its
    width / height (at least 3 px horizontally and 2 px vertically), and the
    result is clamped to the image.  A box fully inside the image always
    yields at least one row and one column.

    A box whose shrunken interior lies entirely outside the image yields an
    empty array on every side.  Left/top coordinates are clamped at 0 so a
    box that starts off-image (e.g. a synthetic placeholder) never produces
    a negative slice index, which NumPy would interpret as counting from the
    opposite edge of the image.

    Args:
        gray: 2-D image to cut from.
        box: object exposing integer `x`, `y`, `w`, `h`.
        shrink_x: fraction of the width trimmed from the left and right.
        shrink_y: fraction of the height trimmed from the top and bottom.

    Returns:
        A view into `gray`; may be empty.
    """
    img_h, img_w = gray.shape[0], gray.shape[1]
    mx = max(3, int(round(box.w * shrink_x)))
    my = max(2, int(round(box.h * shrink_y)))

    left, right = box.x + mx, box.x + box.w - mx
    top, bottom = box.y + my, box.y + box.h - my

    # Interior entirely left of / above the image: nothing to sample.
    if max(left + 1, right) <= 0 or max(top + 1, bottom) <= 0:
        return gray[0:0, 0:0]

    x1 = min(img_w, max(0, left))
    x2 = max(x1 + 1, min(img_w, right))
    y1 = min(img_h, max(0, top))
    y2 = max(y1 + 1, min(img_h, bottom))
    return gray[y1:y2, x1:x2]


def _estimate_local_background(gray: np.ndarray, box: Box, margin: int = 8) -> float:
    """Estimate the paper brightness around a box.

    Two strips as wide as the box and up to `margin` pixels tall are taken
    directly above and directly below it (clipped to the image).  The 85th
    percentile of each non-empty strip is computed and the available values
    are averaged.  If neither strip can be sampled, 220.0 is returned.
    """
    img_h, img_w = gray.shape
    x_lo = max(0, box.x)
    x_hi = min(img_w, box.x + box.w)

    below_start = min(img_h, box.y + box.h)
    bands = (
        (max(0, box.y - margin), box.y),                    # above the box
        (below_start, min(img_h, below_start + margin)),    # below the box
    )

    readings = []
    for y_lo, y_hi in bands:
        if y_hi <= y_lo or x_hi <= x_lo:
            continue
        strip = gray[y_lo:y_hi, x_lo:x_hi]
        if strip.size:
            readings.append(float(np.percentile(strip, 85)))

    if not readings:
        return 220.0
    return float(np.mean(readings))


#  Gap-spillover signal: a mark counts even when drawn slightly outside the
#  box.  _gap_spillover_score() reports a spill-over score only when it is at
#  least this dark (normalised darkness, 0..1); anything weaker is returned
#  as 0 so the clean inter-box gaps never produce a false mark.
GAP_SPILLOVER_MIN = 0.50


def _gap_spillover_score(gray: np.ndarray, box: Box, bg: float) -> float:
    """Score a strong horizontal stroke drawn just outside the box.

    A line drawn a little too high or too low crosses the box border and
    ends up in the gap above or below, where the inner ROI cannot see it.
    This inspects one strip above and one below the box: each starts past
    the printed border (18 % of the box height, min 2 px) and reaches 40 %
    of the box height (min 4 px) into the gap, spanning the middle 60 % of
    the box width.

    A strip contributes only when its darkest row is more than 20 % darker
    than its median row, i.e. the darkness is a localised stroke.  The
    result is the darkest such row relative to `bg`, or 0.0 when that value
    is below GAP_SPILLOVER_MIN (or when `bg` <= 1 / the span is empty).
    """
    box_h, box_w = box.h, box.w
    inset = int(round(box_w * 0.20))
    x_lo = box.x + inset
    x_hi = box.x + box_w - inset
    if x_hi <= x_lo or bg <= 1.0:
        return 0.0

    border_skip = max(2, int(round(box_h * 0.18)))   # clear the printed border edge
    reach = max(4, int(round(box_h * 0.40)))         # reach into the gap, not the neighbour
    box_bottom = box.y + box_h
    bands = (
        (box.y - reach, box.y - border_skip),
        (box_bottom + border_skip, box_bottom + reach),
    )

    strongest = 0.0
    for band_top, band_bottom in bands:
        band_top = max(0, band_top)
        band_bottom = min(gray.shape[0], band_bottom)
        if band_bottom - band_top < 1:
            continue
        strip = gray[band_top:band_bottom, x_lo:x_hi].astype(np.float64)
        if strip.size == 0:
            continue
        row_means = np.mean(strip, axis=1)
        median_row = float(np.median(row_means))
        darkest_row = float(np.min(row_means))
        localised = (median_row - darkest_row) / (median_row + 1e-6) > 0.20
        if not localised:
            continue
        strongest = max(strongest, (bg - darkest_row) / bg)

    if strongest >= GAP_SPILLOVER_MIN:
        return strongest
    return 0.0


def score_option_mark(gray: np.ndarray, box: Box) -> float:
    """Return a normalised darkness score for one option box.

    The score is the strongest of three signals, each measured relative to
    the local paper brightness around the box:

    * **Darkest 30 %** — mean of the darkest 30 % of interior pixels;
      responds to any concentrated fill (heavy pen, pencil blobs, ticks).

    * **Thin-line row** — darkness of the darkest interior pixel row, used
      only when that row is more than 20 % darker than the median row.  This
      catches a thin horizontal line through the box while ignoring border
      bleed that darkens every row about equally.  Needs >= 3 rows.

    * **Gap spill-over** — a strong stroke drawn slightly outside the box,
      as reported by _gap_spillover_score().

    The interior is taken with x-shrink 0.20 and y-shrink 0.30.  Returns 0.0
    when the interior is empty or the background estimate is <= 1.
    """
    interior = _inner_roi(gray, box, shrink_x=0.20, shrink_y=0.30)
    if interior.size == 0:
        return 0.0

    bg = _estimate_local_background(gray, box)
    if bg <= 1.0:
        return 0.0

    pixels = interior.astype(np.float64)

    # Signal 1: mean of the darkest 30 % of pixels.
    ascending = np.sort(pixels, axis=None)
    take = max(1, int(len(ascending) * 0.30))
    darkest_mean = float(np.mean(ascending[:take]))
    fill_score = max(0.0, (bg - darkest_mean) / bg)

    # Signal 2: thin-line row projection, gated on row contrast so uniform
    # border influence (every row equally darkened) is rejected.
    line_score = 0.0
    if pixels.shape[0] >= 3:
        row_means = np.mean(pixels, axis=1)
        median_row = float(np.median(row_means))
        darkest_row = float(np.min(row_means))
        contrast = (median_row - darkest_row) / (median_row + 1e-6)
        if contrast > 0.20:
            line_score = max(0.0, (bg - darkest_row) / bg)

    # Signal 3: mark drawn slightly outside the box.
    spill_score = _gap_spillover_score(gray, box, bg)

    return max(fill_score, line_score, spill_score)


def extract_card_answer(gray: np.ndarray, card: List[Box]) -> str:
    """Decide which option of a question card is marked.

    Returns one uppercase letter ("A", "B", …) for a single mark, "blank"
    when nothing qualifies, or " multiple(A,B)" (note the leading space)
    when several strong marks are found.
    """
    scores = [score_option_mark(gray, box) for box in card]
    if not scores:
        return "blank"

    # Genuine marks score >= ~0.48 on these sheets while printed-border bleed
    # tops out at ~0.43; ABS_MIN sits in that empirical gap.
    ABS_MIN = 0.45
    top_score = max(scores)
    if top_score < ABS_MIN:
        return "blank"

    median_score = float(np.median(scores))
    min_gap = max(0.04, median_score * 0.30)
    threshold = median_score + max(min_gap, (top_score - median_score) * 0.40)
    hard_floor = max(ABS_MIN, median_score + 0.025)

    # Indices (in option order) of boxes that clear both limits.
    picked = []
    for idx, score in enumerate(scores):
        if score >= threshold and score >= hard_floor:
            picked.append(idx)

    if not picked:
        return "blank"
    if len(picked) == 1:
        return string.ascii_uppercase[picked[0]]

    # Several candidates: keep those at least 55 % as dark as the darkest.
    multi_threshold = max(scores[idx] for idx in picked) * 0.55
    strong = [idx for idx in picked if scores[idx] >= multi_threshold]

    if len(strong) == 1:
        return string.ascii_uppercase[strong[0]]

    labels = [string.ascii_uppercase[idx] for idx in strong]
    return f" multiple({','.join(labels)})"


# ---------------------------------------------------------------------------
#  Digit recognition (CV-based, Tesseract optional)
# ---------------------------------------------------------------------------

def _compute_digit_features(binary: np.ndarray) -> Dict[str, float]:
    """Describe a binarised digit image (white strokes on black) by 8 features.

    Features (all roughly in 0..1):
      density         fraction of non-zero pixels
      cx, cy          centroid of the non-zero pixels, as fractions of the
                      width / height (0.5 each when the image is all black)
      tb_ratio        top-half density share of (top + bottom) density
      lr_ratio        left-half density share of (left + right) density
      centre_density  density of the central half-width x half-height window
      n_contours      contour count, capped at 5 and divided by 5
      h_crossings     black/white transitions along the middle row, capped
                      at 8 and divided by 8

    Returns {} for an image with no pixels.
    """
    n_rows, n_cols = binary.shape
    n_pixels = n_rows * n_cols
    if n_pixels == 0:
        return {}

    def fill(region: np.ndarray, area: int) -> float:
        """Non-zero pixel count of *region* over *area* (area floored at 1)."""
        return float(np.count_nonzero(region)) / max(1, area)

    ink = float(np.count_nonzero(binary))
    if ink > 0:
        ink_rows, ink_cols = np.where(binary > 0)
        centroid_x = float(np.mean(ink_cols)) / n_cols
        centroid_y = float(np.mean(ink_rows)) / n_rows
    else:
        centroid_x, centroid_y = 0.5, 0.5

    half_row = n_rows // 2
    upper = fill(binary[:half_row, :], half_row * n_cols)
    lower = fill(binary[half_row:, :], (n_rows - half_row) * n_cols)

    half_col = n_cols // 2
    left = fill(binary[:, :half_col], n_rows * half_col)
    right = fill(binary[:, half_col:], n_rows * (n_cols - half_col))

    window = binary[n_rows // 4:3 * n_rows // 4, n_cols // 4:3 * n_cols // 4]

    contours, _ = cv2.findContours(binary.copy(), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)

    # Widen to int16 first so differences of 0/255 pixels do not wrap.
    middle = binary[half_row, :].astype(np.int16)
    flips = int(np.sum(np.abs(np.diff(middle)) > 127))

    return {
        "density":        ink / n_pixels,
        "cx":             centroid_x,
        "cy":             centroid_y,
        "tb_ratio":       upper / (upper + lower + 1e-9),
        "lr_ratio":       left / (left + right + 1e-9),
        "centre_density": fill(window, window.size),
        "n_contours":     min(len(contours), 5) / 5.0,
        "h_crossings":    min(flips, 8) / 8.0,
    }


def _recognise_digit_features(binary: np.ndarray) -> str:
    """Classify a binarised digit ROI by nearest feature profile.

    The ROI is resized to 28x28, re-binarised and summarised with
    _compute_digit_features().  The digit whose stored profile has the
    smallest weighted L1 distance to those features is returned.  Returns
    "?" for an empty ROI or one with less than 4 % ink.
    """
    if binary.size == 0:
        return "?"

    small = cv2.resize(binary, (28, 28), interpolation=cv2.INTER_AREA)
    _, small = cv2.threshold(small, 127, 255, cv2.THRESH_BINARY)

    feats = _compute_digit_features(small)
    if not feats or feats.get("density", 0) < 0.04:
        return "?"

    # Reference feature vector per digit.
    profiles: Dict[str, Dict[str, float]] = {
        "0": {"density": 0.1745, "cx": 0.4001, "cy": 0.6492, "tb_ratio": 0.2517, "lr_ratio": 0.6137, "centre_density": 0.3493, "n_contours": 0.2883, "h_crossings": 0.4854},
        "1": {"density": 0.1160, "cx": 0.4936, "cy": 0.5431, "tb_ratio": 0.4303, "lr_ratio": 0.4086, "centre_density": 0.2917, "n_contours": 0.2000, "h_crossings": 0.2500},
        "2": {"density": 0.1471, "cx": 0.3959, "cy": 0.6038, "tb_ratio": 0.3708, "lr_ratio": 0.7885, "centre_density": 0.2309, "n_contours": 0.3067, "h_crossings": 0.2500},
        "3": {"density": 0.1652, "cx": 0.4920, "cy": 0.4909, "tb_ratio": 0.5306, "lr_ratio": 0.4188, "centre_density": 0.2488, "n_contours": 0.2000, "h_crossings": 0.2500},
        "4": {"density": 0.1653, "cx": 0.4557, "cy": 0.4949, "tb_ratio": 0.4723, "lr_ratio": 0.4744, "centre_density": 0.4588, "n_contours": 0.2000, "h_crossings": 0.2500},
        "5": {"density": 0.2243, "cx": 0.4605, "cy": 0.5457, "tb_ratio": 0.4589, "lr_ratio": 0.5287, "centre_density": 0.3472, "n_contours": 0.2000, "h_crossings": 0.3083},
        "6": {"density": 0.1742, "cx": 0.2542, "cy": 0.5810, "tb_ratio": 0.3203, "lr_ratio": 1.0000, "centre_density": 0.1730, "n_contours": 0.3533, "h_crossings": 0.2500},
        "7": {"density": 0.1730, "cx": 0.4761, "cy": 0.4858, "tb_ratio": 0.4467, "lr_ratio": 0.5140, "centre_density": 0.4111, "n_contours": 0.5533, "h_crossings": 0.3083},
        "8": {"density": 0.2843, "cx": 0.5319, "cy": 0.4715, "tb_ratio": 0.4942, "lr_ratio": 0.4493, "centre_density": 0.4469, "n_contours": 0.4000, "h_crossings": 0.2500},
        "9": {"density": 0.2348, "cx": 0.5764, "cy": 0.4767, "tb_ratio": 0.5366, "lr_ratio": 0.3148, "centre_density": 0.4570, "n_contours": 0.6000, "h_crossings": 0.5000},
    }

    # Per-feature importance in the distance (unlisted features weigh 1.0).
    weights: Dict[str, float] = {
        "centre_density": 2.0,
        "cx":             2.5,
        "cy":             3.0,
        "density":        3.0,
        "h_crossings":    2.0,
        "lr_ratio":       2.5,
        "n_contours":     1.0,
        "tb_ratio":       3.5,
    }

    def distance(profile: Dict[str, float]) -> float:
        """Weighted L1 distance between the measured features and *profile*."""
        return sum(
            abs(feats[name] - profile[name]) * weights.get(name, 1.0)
            for name in feats if name in profile
        )

    winner = "?"
    lowest = float("inf")
    for digit in profiles:
        d = distance(profiles[digit])
        if d < lowest:          # strict: the earliest digit wins a tie
            winner, lowest = digit, d
    return winner


def _extract_digit_roi(gray: np.ndarray, box: Box) -> Optional[np.ndarray]:
    """Cut out the interior of a digit cell and binarise it.

    18 % of the cell width/height (at least 3 px) is trimmed from every side
    to stay clear of the cell border, and the remainder is clipped to the
    image.  The crop is thresholded with Otsu's method, inverted, so dark
    strokes become white on black.  Returns None when nothing remains after
    trimming and clipping.
    """
    img_h, img_w = gray.shape[0], gray.shape[1]
    pad_x = max(3, int(box.w * 0.18))
    pad_y = max(3, int(box.h * 0.18))

    left = max(0, box.x + pad_x)
    right = min(img_w, box.x + box.w - pad_x)
    top = max(0, box.y + pad_y)
    bottom = min(img_h, box.y + box.h - pad_y)
    if right <= left or bottom <= top:
        return None

    crop = gray[top:bottom, left:right]
    _, strokes = cv2.threshold(crop, 0, 255,
                               cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    return strokes


def recognise_digit(gray: np.ndarray, box: Box) -> str:
    """Recognise the handwritten digit inside one cell.

    Tesseract is tried first when pytesseract can be imported and
    _ensure_tesseract() reports it usable; its answer is accepted only if
    the first character it returns is a digit.  In every other case —
    including any error raised along the way — the built-in feature-profile
    classifier decides.  Returns "?" when the cell has no usable interior.
    """
    strokes = _extract_digit_roi(gray, box)
    if strokes is None:
        return "?"

    ocr_digit = None
    try:
        import pytesseract
        if _ensure_tesseract():
            # Tesseract is fed dark-on-light text with a white margin.
            dark_on_light = cv2.bitwise_not(strokes)
            padded = cv2.copyMakeBorder(dark_on_light, 10, 10, 10, 10,
                                        cv2.BORDER_CONSTANT, value=255)
            config = "--psm 10 -c tessedit_char_whitelist=0123456789"
            text = pytesseract.image_to_string(padded, config=config).strip()
            if text[:1].isdigit():
                ocr_digit = text[0]
    except Exception:
        # Deliberately best-effort: OCR is optional, so any failure simply
        # falls through to the built-in classifier.
        ocr_digit = None

    if ocr_digit is not None:
        return ocr_digit
    return _recognise_digit_features(strokes)


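# ===========================================================================
#  Process a single page image
#  Returns (answers, debug_image, qmeta)
#  qmeta[i] = {"q": page_qnum, "rects": [...], "choice_idx": int|None,
#              "colx": float, "rows": [...]}
#  Question numbers here are 1-based per page; scan_sheet() offsets them so
#  they become global across the pages of a sheet.
# ===========================================================================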
def process_image(src: np.ndarray,
                  force_options: Optional[int] = None,
                  order: str = "rows") -> Tuple[List[OmrAnswer], np.ndarray, list]:
    """Detect, order and score every question card on one page image.

    Args:
        src: page image; it is converted with COLOR_BGR2GRAY, so a BGR
            colour image is expected.
        force_options: when given, only cards with exactly this many option
            boxes are kept.
        order: "cols" renumbers cards column by column (left-to-right, then
            top-to-bottom); any other value keeps the reading order produced
            by group_question_cards().

    Returns:
        (answers, debug, qmeta) where
          answers : one OmrAnswer per card, numbered 1..N on this page;
          debug   : copy of `src` with every option box outlined in grey;
          qmeta   : one dict per card with keys "q", "rects", "choice_idx",
                    "colx" and "rows".

    Raises:
        RuntimeError: no cards were found, or none match `force_options`.
    """
    # NOTE(review): h and w are not used anywhere below.
    h, w = src.shape[:2]
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)

    # Detect option boxes and group them into per-question cards.
    option_boxes = detect_option_boxes(gray)
    cards        = group_question_cards(option_boxes, gray.shape)

    if not cards:
        raise RuntimeError(
            f"Only {len(option_boxes)} candidate cell(s) found -- cannot build grid.")

    # Apply force_options filter.
    if force_options is not None:
        cards = [c for c in cards if len(c) == force_options]
        if not cards:
            raise RuntimeError(
                f"No question cards with exactly {force_options} options found.")

    # Re-sort into column-first order when requested.
    if order == "cols" and cards:
        med_w = float(np.median([b.w for card in cards for b in card]))
        x_tol = med_w * 0.7
        # Cluster cards by x-centre; each group tracks the running mean
        # x-centre of its members.
        col_groups: list = []
        for card in sorted(cards, key=card_center_x):
            cx = card_center_x(card)
            for grp in col_groups:
                if abs(cx - float(grp["cx"])) <= x_tol:
                    grp["cards"].append(card)
                    grp["cx"] = sum(card_center_x(c) for c in grp["cards"]) / len(grp["cards"])
                    break
            else:
                col_groups.append({"cx": cx, "cards": [card]})
        # Emit columns left-to-right, each column top-to-bottom.
        cards = []
        for grp in sorted(col_groups, key=lambda g: float(g["cx"])):
            cards.extend(sorted(grp["cards"], key=card_center_y))

    debug    = src.copy()
    answers: List[OmrAnswer] = []
    qmeta:   list = []

    # Decide whether this page needs illumination normalisation: score on the
    # raw grayscale first and look at how strong the detected marks are.  A
    # clean scan yields strong marks (median ~0.6-0.73) and is left as-is; a
    # dull phone photo yields weak marks (~0.48) and is enhanced so genuine
    # marks reach scan-like darkness before the real scoring pass.
    raw_max = [max(score_option_mark(gray, b) for b in card) for card in cards]
    # Only cards whose best box scores above 0.25 take part in the median.
    strong  = [m for m in raw_max if m > 0.25]
    score_gray = gray
    if strong and float(np.median(strong)) < LOW_CONTRAST_MEDIAN:
        score_gray = enhance_for_scoring(gray)

    for q_num, card in enumerate(cards, start=1):
        # Per-option scores are kept for reporting; extract_card_answer()
        # scores the card again internally to make the decision.
        scores     = [score_option_mark(score_gray, box) for box in card]
        answer_str = extract_card_answer(score_gray, card)

        # Parse the answer string into OmrAnswer fields.
        if answer_str == "blank":
            choice, multi_mark = None, False
        elif "multiple(" in answer_str:
            # The multi-mark string carries a leading space, hence strip().
            # The first listed letter is recorded as the choice.
            inner   = answer_str.strip().removeprefix("multiple(").removesuffix(")")
            letters = inner.split(",")
            choice     = letters[0] if letters else None
            multi_mark = True
        else:
            choice, multi_mark = answer_str, False

        # Stored in OmrAnswer's third field: the raw scores, rounded.
        excess = [round(s, 3) for s in scores]
        answers.append(OmrAnswer(q_num, choice, excess, multi_mark))

        rects      = [(box.x, box.y, box.w, box.h) for box in card]
        choice_idx = LETTERS.index(choice) if choice and choice in LETTERS else None
        colx       = sum(box.cx for box in card) / len(card)

        # Outline every option box in light grey on the debug image.
        for box in card:
            cv2.rectangle(debug, (box.x, box.y),
                          (box.x + box.w, box.y + box.h), (200, 200, 200), 1)

        qmeta.append({"q": q_num, "rects": rects, "choice_idx": choice_idx,
                      "colx": colx, "rows": [box.cy for box in card]})

    return answers, debug, qmeta


# ===========================================================================
#  Scan a whole sheet (all pages), renumbering questions 1..N continuously
# ===========================================================================
def scan_sheet(file_path: str,
               force_options: Optional[int] = None,
               order: str = "rows",
               dpi: int = 200) -> Tuple[List[OmrAnswer], list]:
    """
    Scan every page of one sheet and number its questions 1..N across pages.

    A path whose suffix is ".pdf" (case-insensitive) is rendered through
    pdf_to_images at `dpi`; any other path is read as one image with
    cv2.imread.  `force_options` and `order` are handed to process_image
    unchanged.

    Returns (answers, pages) where:
      answers : OmrAnswer list with GLOBAL question numbers 1..N
      pages   : list of {"debug": img, "qmeta": [...]}  (qmeta q's are global)

    A page on which process_image raises RuntimeError is reported on stderr
    and skipped: it adds no answers and no entry to `pages`.
    Raises FileNotFoundError when a non-PDF path cannot be read as an image.
    """
    if Path(file_path).suffix.lower() == ".pdf":
        page_images = pdf_to_images(file_path, dpi=dpi)
    else:
        single = cv2.imread(file_path, cv2.IMREAD_COLOR)
        if single is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        page_images = [single]

    all_answers: List[OmrAnswer] = []
    page_records: list = []
    for page_no, page_img in enumerate(page_images, start=1):
        try:
            found, overlay, meta = process_image(page_img, force_options, order)
        except RuntimeError as e:
            print(f"  ! page {page_no}: detection failed ({e})", file=sys.stderr)
            continue

        # Every answer collected so far precedes this page, so that count is
        # the amount by which this page's own numbering must be shifted.
        shift = len(all_answers)
        for ans in found:
            ans.question += shift
        for entry in meta:
            entry["q"] += shift

        all_answers.extend(found)
        page_records.append({"debug": overlay, "qmeta": meta})
    return all_answers, page_records


# ===========================================================================
#  Grading
# ===========================================================================
def grade_sheet(sheet_name: str,
                answers: List[OmrAnswer],
                key: Optional[Dict[int, str]],
                key_name: Optional[str]) -> SheetResult:
    """
    Compare one sheet's scanned answers with an answer key.

    With a non-empty `key`, exactly the key's question numbers are graded (a
    question missing from `answers` is treated as unmarked).  Without one,
    every scanned question is listed with no expected answer.

    Status per question, first match wins:
      "multi"   the answer is flagged multi_mark
      "blank"   nothing was marked
      "no-key"  a mark exists but there is no expected answer
      "correct" / "wrong"  the mark equals / differs from the expected answer

    Returns a SheetResult built from `sheet_name`, `key_name` and the
    per-question GradedQuestion list, in ascending question order.
    """
    answer_for = {ans.question: ans for ans in answers}
    has_key = bool(key)
    question_ids = sorted(key.keys()) if has_key else sorted(answer_for.keys())

    rows: List[GradedQuestion] = []
    for qid in question_ids:
        ans = answer_for.get(qid)
        expected = key.get(qid) if has_key else None
        shown = ans.choice if ans else None

        if ans is not None and ans.multi_mark:
            status = "multi"
            # Report every option whose stored score reaches EXCESS_MIN; if
            # fewer than two qualify, the single detected choice is kept.
            dark = [letter for letter, score in zip(LETTERS, ans.excess)
                    if score >= EXCESS_MIN]
            if len(dark) > 1:
                shown = ", ".join(dark)
        elif shown is None:
            status = "blank"
        elif expected is None:
            status = "no-key"
        else:
            status = "correct" if shown == expected else "wrong"

        rows.append(GradedQuestion(qid, shown, expected, status))

    return SheetResult(sheet_name, key_name, rows)


# ===========================================================================
#  Graded overlay annotation
# ===========================================================================
# Drawing colour for each grading status, looked up by annotate_pages and
# passed to the cv2 drawing calls.  Going by the colour names noted on each
# entry, the tuples are in BGR order (e.g. (0, 0, 230) is red).
# NOTE(review): the file header describes multi-marked questions as amber,
# but "multi" is mapped to red here -- confirm which one is intended.
GRADE_COLOR = {
    "correct": (0, 170, 0),     # green
    "wrong":   (0, 0, 230),     # red
    "blank":   (0, 170, 230),   # amber
    "multi":   (0, 0, 230),     # red
    "no-key":  (180, 180, 0),   # teal
}


def annotate_pages(pages: list, result: SheetResult) -> None:
    """
    Draw the grading result of every question onto the pages' debug images.

    pages  : list of {"debug": img, "qmeta": [...]} as returned by scan_sheet;
             each "debug" image is modified in place.
    result : graded sheet; questions are matched to qmeta entries by number,
             and qmeta entries without a graded question are left untouched.

    For each graded question:
      * the detected choice (if any) is outlined in the status colour;
      * for "wrong" and "blank" questions the correct option is outlined in
        thin green, when the key letter maps to one of the detected boxes;
      * a small "<q>:<marked>" label (plus ">correct" when wrong) is written
        just above the first option box.
    """
    status_by_q = {g.question: g for g in result.graded}
    for pg in pages:
        debug = pg["debug"]
        for m in pg["qmeta"]:
            g = status_by_q.get(m["q"])
            if g is None:
                continue
            # Unknown statuses fall back to a neutral grey.
            color = GRADE_COLOR.get(g.status, (120, 120, 120))

            if m["choice_idx"] is not None:
                rx, ryy, rww, rhh = m["rects"][m["choice_idx"]]
                cv2.rectangle(debug, (rx, ryy), (rx + rww, ryy + rhh), color, 2)

            # A blank question graded without a key has correct=None.  The
            # truthiness guard keeps the membership test safe in that case
            # (if LETTERS is a str, `None in LETTERS` raises TypeError and an
            # empty string would count as a member).
            if (g.status in ("wrong", "blank")
                    and g.correct and g.correct in LETTERS):
                ci = LETTERS.index(g.correct)
                if ci < len(m["rects"]):
                    rx, ryy, rww, rhh = m["rects"][ci]
                    cv2.rectangle(debug, (rx, ryy), (rx + rww, ryy + rhh),
                                  (0, 170, 0), 1)

            # Label sits 3 px above the top-left corner of the first option.
            x0 = m["rects"][0][0]
            y0 = m["rects"][0][1] - 3
            lab = f"{m['q']}:{g.marked or '-'}"
            if g.status == "wrong":
                lab += f">{g.correct}"
            cv2.putText(debug, lab, (int(x0), int(y0)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.28, color, 1, cv2.LINE_AA)


# ===========================================================================
#  Input gathering & key selection
# ===========================================================================
def gather_sheets(inputs: List[str]) -> List[str]:
    """
    Expand the command-line inputs into a sorted, de-duplicated file list.

    Each input may be:
      * a directory -- every non-hidden regular file directly inside it whose
        name ends with ".pdf" or one of IMG_EXTS is taken.  The match is
        case-insensitive, so "a.Pdf" and "b.PNG" are both picked up;
      * an existing file -- taken as-is, whatever its extension;
      * anything else -- treated as a glob pattern; if it matches nothing a
        "not found" message is written to stderr.

    The result is sorted by path.  Paths resolving to the same real file are
    listed once (the first in sort order wins).
    """
    files: List[str] = []
    for inp in inputs:
        if os.path.isdir(inp):
            # The directory is listed directly rather than spliced into a glob
            # pattern: a folder name containing glob metacharacters (e.g.
            # "scans[1]") would otherwise silently match nothing.
            exts = tuple(e.lower() for e in (".pdf", *IMG_EXTS))
            for entry in os.listdir(inp):
                if entry.startswith("."):
                    continue   # hidden files were never matched by "*<ext>"
                full = os.path.join(inp, entry)
                if entry.lower().endswith(exts) and os.path.isfile(full):
                    files.append(full)
        elif os.path.isfile(inp):
            files.append(inp)
        else:
            matches = glob.glob(inp)
            if matches:
                files += matches
            else:
                print(f"  ! not found: {inp}", file=sys.stderr)

    seen, out = set(), []
    for f in sorted(files):
        rp = os.path.realpath(f)
        if rp not in seen:
            seen.add(rp)
            out.append(f)
    return out


def pick_key(sheet_path: str,
             keys: Dict[str, Dict[int, str]],
             forced_name: Optional[str]) -> Tuple[Optional[str], Optional[Dict[int, str]]]:
    """
    Choose which answer key applies to `sheet_path`.

    Returns (key_name, key), or (None, None) when `keys` is empty.

    Selection order:
      1. `forced_name` given: exact name match (case-insensitive, also tried
         with a "page" prefix), then a match on the digits of the name.
      2. Only one key loaded: that key.
      3. A token such as "fam2", "nvr-2", "paper2" in the sheet's file stem
         that identifies exactly one key.
      4. A standalone number in the file stem that equals exactly one key's
         number.

    Raises RuntimeError when `forced_name` matches no key or when the
    filename does not identify a single key.
    """
    if not keys:
        return None, None

    if forced_name is not None:
        fn = forced_name.lower()
        for name in keys:
            if name.lower() in (fn, f"page{fn}"):
                return name, keys[name]
        # Digit fallback only when the forced name actually contains digits;
        # otherwise "" == "" would make any digit-less key name "match".
        forced_digits = re.sub(r"\D", "", forced_name)
        if forced_digits:
            for name in keys:
                if re.sub(r"\D", "", name) == forced_digits:
                    return name, keys[name]
        raise RuntimeError(f"--key-name '{forced_name}' not found. "
                           f"Available: {', '.join(keys)}")

    if len(keys) == 1:
        name = next(iter(keys))
        return name, keys[name]

    base        = Path(sheet_path).stem.lower()
    key_numbers = {name: re.sub(r"\D", "", name) for name in keys}
    hits        = []
    for name, num in key_numbers.items():
        if not num:
            continue
        toks = [f"fam{num}", f"fam {num}", f"fam-{num}", f"fam_{num}",
                f"nvr{num}", f"nvr {num}", f"nvr-{num}", f"nvr_{num}",
                f"paper{num}", f"test{num}", f"key{num}",
                f"familiarisation {num}", f"familiarisation{num}"]
        # The token must not be followed by another digit, so that "fam1"
        # does not claim a sheet called "fam12".
        if any(re.search(re.escape(t) + r"(?!\d)", base) for t in toks):
            hits.append(name)
    if len(hits) == 1:
        return hits[0], keys[hits[0]]

    nums_in_name = set(re.findall(r"\d+", base))
    num_hits = [name for name, num in key_numbers.items() if num in nums_in_name]
    if len(num_hits) == 1:
        return num_hits[0], keys[num_hits[0]]

    raise RuntimeError(
        f"Could not decide which key to use for '{Path(sheet_path).name}'. "
        f"Pass --key-name (available: {', '.join(keys)}).")


# ===========================================================================
#  Reporting
# ===========================================================================
def print_sheet_report(res: SheetResult, verbose_wrong: bool = True) -> None:
    """
    Print the per-sheet report block to stdout.

    Always printed: a header with the sheet's file name and key, the pupil
    line (only when a name or UPN was read) and the score line.

    A sheet with a non-zero total additionally gets, when `verbose_wrong` is
    true, the lists of correct / blank / multi-marked / wrong questions.
    A sheet with a zero total (scan only) gets its blank and multi counts
    followed by one line listing every question with its detected mark.
    """
    def q_list(items) -> str:
        # " (Q1, Q2, ...)" for a non-empty group, nothing for an empty one.
        if not items:
            return ""
        return f" ({', '.join('Q' + str(g.question) for g in items)})"

    has_total = bool(res.total)
    key_label = f"key {res.key_name}" if res.key_name else "no key"
    if has_total:
        score_text = f"{res.score}/{res.total}  ({res.percent:.1f}%)"
    else:
        score_text = "scan only (no key)"

    print(f"\n=== {Path(res.sheet).name}   [{key_label}] ===")
    if res.name or res.upn:
        print(f"  Pupil : {res.name or '(unread)'}   UPN: {res.upn or '-'}")
    print(f"  Score : {score_text}")

    if not has_total:
        print(f"  Blank : {res.blanks}   Ambiguous(multi): {res.multis}")
        print("  " + "  ".join(f"Q{g.question}:{g.marked or '-'}"
                               for g in res.graded))
        return
    if not verbose_wrong:
        return

    # Bucket the graded questions by the four statuses that are reported.
    groups = {"correct": [], "wrong": [], "multi": [], "blank": []}
    for g in res.graded:
        if g.status in groups:
            groups[g.status].append(g)
    right, wrong = groups["correct"], groups["wrong"]
    multi, blank = groups["multi"], groups["blank"]

    print(f"  Correct : {len(right)}{q_list(right)}")
    print(f"  Blank : {res.blanks}{q_list(blank)}   "
          f"Ambiguous(multi): {res.multis}{q_list(multi)}")
    if wrong:
        details = ", ".join(f"Q{g.question}({g.marked or '-'}->{g.correct})"
                            for g in wrong)
        print(f"  Wrong : {details}")
    if not (wrong or blank or multi):
        print("  All correct!")


def print_summary(results: List[SheetResult]) -> None:
    """
    Print the closing summary table for a batch, one row per sheet.

    Columns: pupil name, "UPN  sheet-file" (or just the file name when no UPN
    was read), key, score and percentage.  Over-long pupil names and sheet
    labels are cut and end in an ellipsis.  Sheets with a zero total show
    "scan" instead of a score.  When at least one sheet has a non-zero total,
    an AVERAGE row with the mean percentage of those sheets is appended.
    Nothing is printed for an empty `results` list.
    """
    if not results:
        return

    def sheet_label(r: SheetResult) -> str:
        file_name = Path(r.sheet).name
        return f"{r.upn}  {file_name}" if r.upn else file_name

    def pupil_name(r: SheetResult) -> str:
        return r.name or "-"

    def clip(text: str, width: int) -> str:
        # Keep the column width by trading the overflow for an ellipsis.
        return text if len(text) <= width else text[:width - 1] + "…"

    banner = "=" * 60
    print("\n" + banner)
    print("SUMMARY")
    print(banner)

    # Sheet column: widest label, clamped to 12..52 characters.
    name_w = min(max(max(len(sheet_label(r)) for r in results), 12), 52)
    # Pupil column: at least as wide as its heading, at most 22 characters.
    pup_w = min(max(len("Pupil Name"),
                    max(len(pupil_name(r)) for r in results)), 22)
    rule = "-" * (name_w + pup_w + 24)

    print(f"{'Pupil Name'.ljust(pup_w)}  {'UPN / Sheet'.ljust(name_w)}  Key   Score    %")
    print(rule)

    percents = []
    for r in results:
        nm = clip(sheet_label(r), name_w)
        pn = clip(pupil_name(r), pup_w)
        left = f"{pn.ljust(pup_w)}  {nm.ljust(name_w)}  "
        if r.total:
            percents.append(r.percent)
            print(left + f"{str(r.key_name or '-').ljust(3)}  "
                  f"{str(r.score)+'/'+str(r.total):>6}  {r.percent:5.1f}")
        else:
            print(left + f"{'-':<3}  {'scan':>6}    -")

    if percents:
        avg = statistics.mean(percents)
        print(rule)
        print(f"{'AVERAGE'.ljust(pup_w)}  {''.ljust(name_w)}  {'':3}  {'':6}  {avg:5.1f}")


def write_csv(results: List[SheetResult], path: str) -> None:
    """
    Write the batch results as two CSV files and print where they went.

    `path`                  one row per graded question of every sheet
                            (sheet, key, question, marked, correct, status,
                            is_correct as 0/1).
    `<stem>_summary.csv`    next to `path`, one row per sheet (pupil name,
                            name confidence, UPN, sheet, key, score, total,
                            percent, blank and multi counts).

    The parent directory of `path` is created if needed.  Both files are
    written as UTF-8 regardless of the platform's default encoding, so a
    pupil or file name with non-ASCII characters cannot abort the export at
    the very end of a batch run.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)) or ".", exist_ok=True)
    # newline="" is what the csv module asks for; it does its own line endings.
    with open(path, "w", newline="", encoding="utf-8") as f:
        wr = csv.writer(f)
        wr.writerow(["sheet", "key", "question", "marked", "correct",
                     "status", "is_correct"])
        for r in results:
            for g in r.graded:
                wr.writerow([Path(r.sheet).name, r.key_name or "", g.question,
                             g.marked or "", g.correct or "", g.status,
                             int(g.is_correct)])
    summ = str(Path(path).with_name(Path(path).stem + "_summary.csv"))
    with open(summ, "w", newline="", encoding="utf-8") as f:
        wr = csv.writer(f)
        wr.writerow(["pupil_name", "name_confidence", "upn", "sheet", "key",
                     "score", "total", "percent", "blank", "multi"])
        for r in results:
            wr.writerow([r.name, f"{r.name_conf:.0f}", r.upn,
                         Path(r.sheet).name, r.key_name or "", r.score,
                         r.total, f"{r.percent:.1f}", r.blanks, r.multis])
    print(f"\nCSV  -> {path}")
    print(f"CSV  -> {summ}")


# ===========================================================================
#  CLI
# ===========================================================================
def main() -> None:
    """
    Command-line entry point: load the key(s), scan and grade every sheet.

    Flow:
      1. Parse arguments; when --key is given, load the answer key file and
         print each parsed key.
      2. Expand the inputs into a list of sheet files.
      3. Per sheet: scan, pick a key, grade, read UPN and pupil name, print
         the report and (unless --no-overlay) write the overlay image(s).
      4. Print the summary table and, when keys were loaded, write the CSVs.

    Exits with status 1 when the key file does not exist or no input sheet
    is found.  A sheet whose scan raises is reported on stderr and skipped.
    """
    ap = argparse.ArgumentParser(
        description="Batch OMR grader for GL Assessment answer sheets. "
                    "Reads an answer key, then scores every filled sheet.")
    ap.add_argument("inputs", nargs="+",
                    help="Sheet file(s), folder(s), or glob(s) to grade.")
    ap.add_argument("--key", default=None,
                    help="Answer-key file (PDF or .txt). Omit to only scan.")
    ap.add_argument("--key-name", default=None,
                    help="Force which key to use for ALL sheets (number/name).")
    ap.add_argument("--out-dir", default="omr_out",
                    help="Directory for graded overlay images (default omr_out).")
    ap.add_argument("--csv", default=None,
                    help="Combined results CSV (default <out-dir>/results.csv).")
    ap.add_argument("--no-overlay", action="store_true",
                    help="Do not write overlay images.")
    ap.add_argument("--options", type=int, default=None,
                    help="Force options-per-question (default: auto).")
    ap.add_argument("--order", choices=["rows", "cols"], default="rows",
                    help="Question numbering order (default rows).")
    ap.add_argument("--dpi", type=int, default=200,
                    help="DPI for PDF rendering (default 200).")
    ap.add_argument("--questions", type=int, default=None,
                    help="Expected question count per sheet (sanity warning).")
    args = ap.parse_args()

    # An empty dict means "scan only": no grading and no CSV export below.
    keys: Dict[str, Dict[int, str]] = {}
    if args.key:
        if not os.path.isfile(args.key):
            print(f"Error: key file not found: {args.key}", file=sys.stderr)
            sys.exit(1)
        keys = load_answer_keys(args.key, dpi=args.dpi)
        print(f"Loaded {len(keys)} answer key(s) from {Path(args.key).name}: "
              f"{', '.join(f'{k}({len(v)}q)' for k, v in keys.items())}")
        # Log every parsed key as  Q1-A, Q2-B, ...  so the extraction is auditable.
        for name, block in keys.items():
            answers = ", ".join(f"Q{q}-{block[q]}" for q in sorted(block))
            print(f"  Key {name}: {answers}")

    sheets = gather_sheets(args.inputs)
    if not sheets:
        print("Error: no input sheets found.", file=sys.stderr)
        sys.exit(1)
    print(f"Found {len(sheets)} sheet(s) to process.")

    if not args.no_overlay:
        os.makedirs(args.out_dir, exist_ok=True)

    results: List[SheetResult] = []
    for sheet in sheets:
        print(f"\n--- {Path(sheet).name} ---")
        # Any scan failure is reported and the sheet is left out of the
        # results, so one bad file does not stop the batch.
        try:
            answers, pages = scan_sheet(sheet, args.options, args.order, args.dpi)
        except Exception as e:
            print(f"  ! scan failed: {e}", file=sys.stderr)
            continue

        # Only a warning: grading still proceeds with what was detected.
        if args.questions and len(answers) != args.questions:
            print(f"  ! warning: detected {len(answers)} questions "
                  f"(expected {args.questions})", file=sys.stderr)

        # If no key can be chosen the error is printed and the sheet is
        # passed to grade_sheet with key=None.
        key_name, key = (None, None)
        if keys:
            try:
                key_name, key = pick_key(sheet, keys, args.key_name)
            except RuntimeError as e:
                print(f"  ! {e}", file=sys.stderr)

        res = grade_sheet(sheet, answers, key, key_name)
        res.upn = extract_upn(sheet, args.dpi)
        res.name, res.name_conf = extract_name(sheet, args.dpi)
        _page1_cache.clear()   # free cached page image before next sheet
        results.append(res)
        print_sheet_report(res)

        if not args.no_overlay:
            annotate_pages(pages, res)
            stem = Path(sheet).stem
            # Multi-page sheets get a _p<N> suffix per page; single-page
            # sheets get a plain <stem>_graded.png.
            for i, pg in enumerate(pages, start=1):
                suffix   = f"_p{i}" if len(pages) > 1 else ""
                out_path = os.path.join(args.out_dir, f"{stem}_graded{suffix}.png")
                cv2.imwrite(out_path, pg["debug"])
            print(f"  overlay(s) -> {args.out_dir}/{stem}_graded*.png")

    print_summary(results)

    # CSV export happens only when an answer key file was loaded.
    if keys:
        csv_path = args.csv or os.path.join(args.out_dir, "results.csv")
        write_csv(results, csv_path)

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
