#!/usr/bin/env python3
"""
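compare_pdb.py — semantic PDB comparator for unit tests.

Compares the ATOM/HETATM records of two PDB files pairwise in file order:
atom identity (record, name, altLoc, residue name, chain, resSeq, iCode,
element), serial number, occupancy, B-factor, and coordinates (within an
absolute distance tolerance). Exits with status 0 when the files match and 1
otherwise.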

Usage:
  python compare_pdb.py ref.pdb out.pdb --atol 1e-4
  python compare_pdb.py ref.pdb out.pdb --atol 1e-4 --ignore-serial --ignore-bfactor --ignore-occupancy
"""

from __future__ import annotations
import argparse
import sys
from dataclasses import dataclass
from typing import List, Tuple, Optional
import math

@dataclass(frozen=True)
class AtomKey:
    """Identity of a single ATOM/HETATM record.

    Two atoms are considered "the same atom" when every field here matches.
    The class is frozen, so instances compare by value and are hashable.
    """

    # Record type: "ATOM" or "HETATM".
    record: str
    # Atom name, PDB columns 13-16.
    name: str
    # Alternate location indicator, PDB column 17.
    altloc: str
    # Residue name, PDB columns 18-20.
    resname: str
    # Chain identifier, PDB column 22.
    chain: str
    # Residue sequence number, PDB columns 23-26.
    resseq: int
    # Insertion code, PDB column 27.
    icode: str
    # Element symbol, PDB columns 77-78; may be blank.
    element: str

@dataclass
class AtomRec:
    """One parsed ATOM/HETATM record: its identity key plus per-atom values."""

    key: AtomKey              # identity fields, compared as a unit
    serial: Optional[int]     # atom serial number; None when the column is blank
    x: float                  # Cartesian coordinates
    y: float
    z: float
    occ: Optional[float]      # occupancy; None when the column is blank
    bfac: Optional[float]     # B-factor; None when the column is blank

def _slice(line: str, a: int, b: int) -> str:
    """Return the text in 1-based, inclusive columns *a* through *b* of *line*.

    Columns beyond the end of *line* are simply missing from the result, so a
    short line yields a shorter (possibly empty) string instead of an error.
    """
    start = a - 1  # PDB column numbers are 1-based; Python indices are 0-based
    return line[start:b]

def parse_pdb(path: str) -> List[AtomRec]:
    """Read the ATOM/HETATM records of a PDB file, in file order.

    All other record types are skipped. Each field is cut from its fixed PDB
    column range. Atom name, altLoc, residue name, chain and iCode keep their
    column padding; the record name and element symbol are whitespace-stripped.

    Args:
        path: PDB file to read. It is decoded as UTF-8, with undecodable bytes
            replaced rather than raising.

    Returns:
        One AtomRec per ATOM/HETATM line. ``serial``, ``occ`` and ``bfac`` are
        None when their columns are blank; a blank resSeq becomes 0.

    Raises:
        ValueError: if a numeric column of an ATOM/HETATM line cannot be
            parsed (including missing coordinates on a short line). The message
            carries the file name, line number and the offending line.
    """

    def opt_int(text: str) -> Optional[int]:
        # Blank optional integer column -> None.
        text = text.strip()
        return int(text) if text else None

    def opt_float(text: str) -> Optional[float]:
        # Blank optional float column -> None.
        text = text.strip()
        return float(text) if text else None

    atoms: List[AtomRec] = []
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for ln, line in enumerate(f, 1):
            rec = _slice(line, 1, 6).strip()
            if rec not in ("ATOM", "HETATM"):
                continue

            # Only the int()/float() conversions can fail (slicing never
            # raises), so catch exactly ValueError and re-raise with context.
            # Conversion order matches the column order so the first bad field
            # is the one reported.
            try:
                serial = opt_int(_slice(line, 7, 11))
                resseq = int(_slice(line, 23, 26).strip() or "0")
                x = float(_slice(line, 31, 38).strip())
                y = float(_slice(line, 39, 46).strip())
                z = float(_slice(line, 47, 54).strip())
                occ = opt_float(_slice(line, 55, 60))
                bfac = opt_float(_slice(line, 61, 66))
            except ValueError as e:
                raise ValueError(f"{path}:{ln}: failed to parse ATOM/HETATM line: {e}\n{line}") from e

            key = AtomKey(
                record=rec,
                name=_slice(line, 13, 16),
                altloc=_slice(line, 17, 17),
                resname=_slice(line, 18, 20),
                chain=_slice(line, 22, 22),
                resseq=resseq,
                icode=_slice(line, 27, 27),
                element=_slice(line, 77, 78).strip(),
            )
            atoms.append(AtomRec(key=key, serial=serial, x=x, y=y, z=z, occ=occ, bfac=bfac))
    return atoms

def dist(a: AtomRec, b: AtomRec) -> float:
    """Return the Euclidean distance between the coordinates of *a* and *b*."""
    dx, dy, dz = a.x - b.x, a.y - b.y, a.z - b.z
    return math.sqrt(dx ** 2 + dy ** 2 + dz ** 2)

def main() -> int:
    """Command-line entry point: compare two PDB files atom by atom.

    Atoms are paired by position (the i-th ATOM/HETATM record of ``ref`` with
    the i-th of ``out``). For each pair the identity key is compared exactly;
    serial, occupancy and B-factor are compared exactly unless ignored; and
    coordinates must lie within ``--atol`` (Euclidean distance).

    Returns:
        0 if no differences were found, 1 otherwise.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("ref")
    ap.add_argument("out")
    ap.add_argument("--atol", type=float, default=1e-4, help="Absolute tolerance for coordinate differences (Å)")
    ap.add_argument("--ignore-serial", action="store_true")
    ap.add_argument("--ignore-occupancy", action="store_true")
    ap.add_argument("--ignore-bfactor", action="store_true")
    ap.add_argument("--max-report", type=int, default=20,
                    help="Stop after collecting this many issues (<= 0: report all)")
    args = ap.parse_args()

    # A non-positive limit previously stopped the scan after the first atom
    # pair even when nothing had been found (a false OK); treat it as
    # "no limit" instead.
    limit = args.max_report if args.max_report > 0 else None

    A = parse_pdb(args.ref)
    B = parse_pdb(args.out)

    diffs: List[str] = []

    if len(A) != len(B):
        diffs.append(f"Atom count differs: ref={len(A)} out={len(B)}")

    # zip() stops at the shorter list; the count mismatch is reported above.
    for i, (a, b) in enumerate(zip(A, B)):
        if a.key != b.key:
            diffs.append(f"[{i}] Atom identity differs:\n  ref={a.key}\n  out={b.key}")

        if (not args.ignore_serial) and (a.serial != b.serial):
            diffs.append(f"[{i}] serial differs: ref={a.serial} out={b.serial}")

        if (not args.ignore_occupancy) and (a.occ != b.occ):
            diffs.append(f"[{i}] occupancy differs: ref={a.occ} out={b.occ}")

        if (not args.ignore_bfactor) and (a.bfac != b.bfac):
            diffs.append(f"[{i}] bfactor differs: ref={a.bfac} out={b.bfac}")

        d = dist(a, b)
        # NaN compares False against everything, so `d > atol` alone would let
        # a NaN coordinate (or inf - inf) pass silently; check it explicitly.
        if math.isnan(d) or d > args.atol:
            diffs.append(f"[{i}] coord mismatch |Δr|={d:.6g} Å > atol={args.atol} "
                         f"(ref=({a.x:.3f},{a.y:.3f},{a.z:.3f}) out=({b.x:.3f},{b.y:.3f},{b.z:.3f}))")

        if limit is not None and len(diffs) >= limit:
            break

    if diffs:
        if limit is None:
            print(f"[compare_pdb] FAIL: showing all {len(diffs)} issues")
        else:
            print(f"[compare_pdb] FAIL: showing up to {limit} issues")
        # One atom pair can add several issues, so trim to the advertised limit.
        for msg in diffs[:limit]:
            print(" -", msg)
        return 1

    print("[compare_pdb] OK")
    return 0

if __name__ == "__main__":
    # main() returns the process exit status (0 = files match, 1 = mismatch).
    sys.exit(main())
