#!/usr/bin/env python3
"""
compare_h5.py — semantic HDF5 comparator for unit tests (dtype-agnostic).

Policy:
- Skip eigenvector datasets (gauge dependent): */eigenvectors, */eigenvectors2, ...
- Compare /QMdata/transition_dipoles/ind<N> allowing global sign flip and dedicated tolerances:
    accept if A ~= B OR A ~= -B within (td_rtol, td_atol)

Other:
- Recursive compare groups/datasets/attributes
- Robust compound dtype handling (dtype.names) with coercion and string normalization
- Ignore attrs, ignore path subtrees, ignore datasets by regex
"""

from __future__ import annotations

import argparse
import re
import sys
from dataclasses import dataclass
from typing import Any, Iterable, List, Set, Tuple

import h5py
import numpy as np


@dataclass
class Diff:
    """One reported difference between file A and file B.

    Instances are appended to a shared list by the compare functions and are
    printed by ``main`` as ``<path>: <kind>: <message>``.
    """

    # HDF5 path where the difference was found; some callers append a suffix
    # such as "@<attr>" (attribute) or ":<field>" (compound field).
    path: str
    # Short category tag, e.g. "shape", "dtype", "value", "missing_in_a".
    kind: str
    # Human-readable description of the difference.
    message: str


def should_ignore_path(path: str, ignore_paths: Iterable[str]) -> bool:
    """Return True when *path* equals, or lies beneath, any ignored subtree.

    Trailing slashes on each ignore entry are stripped first; entries that
    become empty (for example ``"/"`` or ``""``) never match anything.
    """
    subtree_roots = (entry.rstrip("/") for entry in ignore_paths)
    return any(
        path == root or path.startswith(f"{root}/")
        for root in subtree_roots
        if root
    )


def should_ignore_dataset(path: str, ignore_dataset_regex: List[re.Pattern[str]]) -> bool:
    """Return True if any compiled pattern is found anywhere in *path* (``search``)."""
    for pattern in ignore_dataset_regex:
        if pattern.search(path):
            return True
    return False


def normalize_bytes(x: Any) -> Any:
    """Decode byte strings to ``str``; return every other value unchanged.

    Undecodable bytes are substituted (``errors="replace"``) rather than raising.
    """
    if not isinstance(x, (bytes, np.bytes_)):
        return x
    return x.decode(errors="replace")


def _as_object_array(x: np.ndarray) -> np.ndarray:
    """Return *x* viewed as an object-dtype array, avoiding a copy when possible."""
    object_dtype = np.dtype(object)
    return x.astype(object_dtype, copy=False)


def _normalize_string_field(arr: np.ndarray) -> np.ndarray:
    """Return an object array with every bytes element of *arr* decoded to ``str``.

    Non-bytes elements are passed through unchanged (see ``normalize_bytes``).
    The result always has object dtype and the same shape as *arr*.
    """
    obj = _as_object_array(arr)
    # otypes is pinned on purpose: without it np.vectorize probes the first
    # element to guess the output dtype, which raises ValueError on size-0
    # input and lets one element dictate the dtype of all the others.
    return np.vectorize(normalize_bytes, otypes=[object])(obj)


def _dtype_is_string_like(dt: np.dtype) -> bool:
    """Return True for byte-string ("S"), unicode ("U") and object ("O") dtypes."""
    kind = dt.kind
    return kind == "S" or kind == "U" or kind == "O"


def compare_scalars_generic(av: Any, bv: Any, rtol: float, atol: float) -> bool:
    """Compare two scalar-like values, using tolerances for inexact numbers.

    Rules, applied in order:
      * If either value is an ndarray (e.g. a variable-length element of an
        object array), shapes must match; numeric arrays are compared with
        ``np.isclose`` when either side is float/complex and exactly otherwise;
        non-numeric arrays are compared element by element with these rules.
      * NumPy scalars are unwrapped with ``.item()`` and bytes are decoded.
      * Equal-length tuples/lists are compared element by element.
      * Two integers are compared exactly.
      * Any other pair of real numbers is compared with ``np.isclose``.
      * If a complex number is involved, real and imaginary parts are each
        compared with ``np.isclose``.
      * Everything else falls back to ``==``.

    NaN compares equal to NaN in the tolerance-based branches.

    Args:
        av: Value from file A.
        bv: Value from file B.
        rtol: Relative tolerance passed to ``np.isclose``.
        atol: Absolute tolerance passed to ``np.isclose``.

    Returns:
        True if the values are considered equal, otherwise False.
    """
    # Array-valued elements: a plain ``av == bv`` would yield an array whose
    # truth value is ambiguous, so reduce to a single bool explicitly.
    if isinstance(av, np.ndarray) or isinstance(bv, np.ndarray):
        a_arr = np.asarray(av)
        b_arr = np.asarray(bv)
        if a_arr.shape != b_arr.shape:
            return False
        numeric_kinds = ("b", "i", "u", "f", "c")
        inexact_kinds = ("f", "c")
        if a_arr.dtype.kind in numeric_kinds and b_arr.dtype.kind in numeric_kinds:
            if a_arr.dtype.kind in inexact_kinds or b_arr.dtype.kind in inexact_kinds:
                close = np.isclose(a_arr, b_arr, rtol=rtol, atol=atol, equal_nan=True)
                return bool(np.all(close))
            return bool(np.array_equal(a_arr, b_arr))
        # Strings, objects, compound records: recurse per element.
        return all(
            compare_scalars_generic(x, y, rtol, atol)
            for x, y in zip(a_arr.flat, b_arr.flat)
        )

    if isinstance(av, np.generic):
        av = av.item()
    if isinstance(bv, np.generic):
        bv = bv.item()

    av = normalize_bytes(av)
    bv = normalize_bytes(bv)

    if isinstance(av, (tuple, list)) and isinstance(bv, (tuple, list)) and len(av) == len(bv):
        return all(compare_scalars_generic(x, y, rtol, atol) for x, y in zip(av, bv))

    int_types = (int, np.integer)
    real_types = (int, float, np.integer, np.floating)
    number_types = real_types + (complex, np.complexfloating)

    # Integers stay exact so large values are not blurred by a tolerance.
    if isinstance(av, int_types) and isinstance(bv, int_types):
        return int(av) == int(bv)

    # Float/float and mixed int/float pairs both honour the tolerances.
    if isinstance(av, real_types) and isinstance(bv, real_types):
        return bool(np.isclose(av, bv, rtol=rtol, atol=atol, equal_nan=True))

    # At least one side is complex here; promote both and compare per component.
    if isinstance(av, number_types) and isinstance(bv, number_types):
        avc = complex(av)
        bvc = complex(bv)
        return bool(
            np.isclose(avc.real, bvc.real, rtol=rtol, atol=atol, equal_nan=True)
            and np.isclose(avc.imag, bvc.imag, rtol=rtol, atol=atol, equal_nan=True)
        )

    return bool(av == bv)


def _report_first_mismatches(
    diffs: List[Diff],
    path: str,
    kind: str,
    coords: List[Tuple[int, ...]],
    a: np.ndarray,
    b: np.ndarray,
    max_report: int,
    prefix: str = "",
) -> None:
    """Append one Diff per mismatching coordinate, capped at *max_report*.

    Each entry shows the repr of the element from *a* and *b* at that
    coordinate. When more coordinates exist than were reported, one extra
    summary Diff states how many were left out. *prefix* is prepended to
    every message.
    """
    reported = coords[:max_report]
    for where in reported:
        left = a[where]
        right = b[where]
        diffs.append(Diff(path, kind, f"{prefix}Mismatch at {where}: A={left!r} vs B={right!r}"))

    total = len(coords)
    if total > max_report:
        diffs.append(Diff(path, kind, f"{prefix}... {total - max_report} more mismatches"))


def compare_arrays_generic(
    a: np.ndarray,
    b: np.ndarray,
    path: str,
    diffs: List[Diff],
    rtol: float,
    atol: float,
    max_report: int,
) -> None:
    """Compare two arrays value by value and append any differences to *diffs*.

    Dispatch, in order:
      * Different shapes: one "shape" Diff, nothing else is compared.
      * Structured (compound) dtype on either side: the unstructured side is
        coerced to the structured dtype, field names must match (including
        order), then each field is compared on its own under the path
        ``"<path>:<field>"``. String-like fields are normalized (bytes
        decoded) and compared exactly; other fields recurse into this function.
      * Arrays with zero elements are considered equal once shape (and, for
        structured arrays, field names) agree.
      * Object dtype: elements are compared with ``compare_scalars_generic``.
      * Byte/unicode string dtype: normalized and compared exactly.
      * Float or complex dtype on either side: ``np.isclose`` with
        ``equal_nan=True``.
      * Anything else: exact equality.

    Args:
        a: Array from file A.
        b: Array from file B.
        path: Label used in the reported Diffs.
        diffs: Output list; Diffs are appended in place.
        rtol: Relative tolerance for inexact comparisons.
        atol: Absolute tolerance for inexact comparisons.
        max_report: Maximum number of individual mismatches reported per array.
    """
    if a.shape != b.shape:
        diffs.append(Diff(path, "shape", f"Shape differs: A {a.shape} vs B {b.shape}"))
        return

    a_struct = a.dtype.names is not None
    b_struct = b.dtype.names is not None

    if a_struct or b_struct:
        # Bring a plain array over to the compound dtype of the other side so
        # both can be walked field by field.
        if a_struct and not b_struct:
            try:
                b = np.array(b, dtype=a.dtype)
            except Exception:
                diffs.append(Diff(path, "dtype", f"One side structured and other not (coercion failed): A {a.dtype} vs B {b.dtype}"))
                return
        if b_struct and not a_struct:
            try:
                a = np.array(a, dtype=b.dtype)
            except Exception:
                diffs.append(Diff(path, "dtype", f"One side structured and other not (coercion failed): A {a.dtype} vs B {b.dtype}"))
                return

        if (a.dtype.names is None) or (b.dtype.names is None):
            diffs.append(Diff(path, "dtype", f"Structured dtype detection mismatch: A {a.dtype} vs B {b.dtype}"))
            return

        if a.dtype.names != b.dtype.names:
            diffs.append(Diff(path, "dtype", f"Field names differ: A {a.dtype.names} vs B {b.dtype.names}"))
            return

        # No records means no field values that could differ; returning here
        # also keeps empty string fields away from the normalization helper.
        if a.size == 0:
            return

        for name in a.dtype.names:
            try:
                af = a[name]
                bf = b[name]

                if _dtype_is_string_like(af.dtype) or _dtype_is_string_like(bf.dtype):
                    afn = _normalize_string_field(af)
                    bfn = _normalize_string_field(bf)
                    if not np.array_equal(afn, bfn):
                        idxs = np.argwhere(afn != bfn)
                        coords = [tuple(map(int, c)) for c in idxs]
                        _report_first_mismatches(diffs, f"{path}:{name}", "value", coords, afn, bfn, max_report)
                    continue

                compare_arrays_generic(af, bf, f"{path}:{name}", diffs, rtol, atol, max_report)
            except Exception as e:
                # Best effort: a failure in one field is reported as a Diff and
                # must not stop the remaining fields from being compared.
                diffs.append(Diff(f"{path}:{name}", "compare_error", f"Error comparing field '{name}': {e}"))
        return

    # Same shape and zero elements: nothing can differ.
    if a.size == 0:
        return

    if a.dtype.kind == "O" or b.dtype.kind == "O":
        a_flat = a.ravel()
        b_flat = b.ravel()
        mismatches: List[int] = []
        truncated = False
        for i, (av, bv) in enumerate(zip(a_flat, b_flat)):
            if compare_scalars_generic(av, bv, rtol, atol):
                continue
            # Stop at the first mismatch beyond the cap, so the truncation
            # notice is only emitted when something was really left out.
            if len(mismatches) >= max_report:
                truncated = True
                break
            mismatches.append(i)
        for i in mismatches:
            diffs.append(Diff(path, "value", f"Mismatch at flat index {i}: A={a_flat[i]!r} vs B={b_flat[i]!r} (rtol={rtol}, atol={atol})"))
        if truncated:
            diffs.append(Diff(path, "value", "… more mismatches may exist (report truncated)"))
        return

    if a.dtype.kind in ("S", "U") or b.dtype.kind in ("S", "U"):
        a2 = _normalize_string_field(a)
        b2 = _normalize_string_field(b)
        if not np.array_equal(a2, b2):
            idxs = np.argwhere(a2 != b2)
            coords = [tuple(map(int, c)) for c in idxs]
            _report_first_mismatches(diffs, path, "value", coords, a2, b2, max_report)
        return

    if (
        np.issubdtype(a.dtype, np.floating)
        or np.issubdtype(a.dtype, np.complexfloating)
        or np.issubdtype(b.dtype, np.floating)
        or np.issubdtype(b.dtype, np.complexfloating)
    ):
        ok = np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=True)
        if not np.all(ok):
            idxs = np.argwhere(~ok)
            for coord in [tuple(map(int, c)) for c in idxs][:max_report]:
                diffs.append(Diff(path, "value", f"Mismatch at {coord}: A={a[coord]} vs B={b[coord]} (rtol={rtol}, atol={atol})"))
            if idxs.shape[0] > max_report:
                diffs.append(Diff(path, "value", f"... {idxs.shape[0] - max_report} more mismatches"))
        return

    # Integers, booleans and other exact dtypes.
    if not np.array_equal(a, b):
        idxs = np.argwhere(a != b)
        coords = [tuple(map(int, c)) for c in idxs]
        _report_first_mismatches(diffs, path, "value", coords, a, b, max_report)


# Matches transition-dipole dataset paths (/QMdata/transition_dipoles/ind<N>).
# Datasets whose path matches are compared with the sign-flip-tolerant routine
# instead of the generic one, and are never skipped by the ignore regexes.
_TDIPOLE_RE = re.compile(r"^/QMdata/transition_dipoles/ind\d+$")


def compare_transition_dipoles_allow_sign(
    a: np.ndarray,
    b: np.ndarray,
    path: str,
    diffs: List[Diff],
    rtol: float,
    atol: float,
    max_report: int,
) -> None:
    """Compare two arrays, accepting a global sign flip of the whole array.

    The arrays match when every element satisfies ``A ~= B`` or every element
    satisfies ``A ~= -B`` (``np.isclose`` with ``equal_nan=True``). When
    neither holds, mismatches are reported against whichever convention leaves
    fewer mismatching elements (ties go to the unflipped comparison).

    Args:
        a: Array from file A.
        b: Array from file B.
        path: Label used in the reported Diffs.
        diffs: Output list; Diffs are appended in place.
        rtol: Relative tolerance for ``np.isclose``.
        atol: Absolute tolerance for ``np.isclose``.
        max_report: Maximum number of individual mismatches reported.
    """
    if a.shape != b.shape:
        diffs.append(Diff(path, "shape", f"Shape differs: A {a.shape} vs B {b.shape}"))
        return

    # Shapes are equal here, so the sizes are too: one emptiness check suffices.
    if a.size == 0:
        return

    ok_same = np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=True)
    if np.all(ok_same):
        return

    # Negate once and reuse it both for the check and for the report below.
    neg_b = -b
    ok_flip = np.isclose(a, neg_b, rtol=rtol, atol=atol, equal_nan=True)
    if np.all(ok_flip):
        return

    # Report mismatches for the better fit (same vs flipped).
    bad_same = np.argwhere(~ok_same)
    bad_flip = np.argwhere(~ok_flip)

    if bad_flip.shape[0] < bad_same.shape[0]:
        bad, other, label, lead = bad_flip, neg_b, "-B", "Mismatch (even allowing sign flip)"
    else:
        bad, other, label, lead = bad_same, b, "B", "Mismatch"

    coords = [tuple(map(int, c)) for c in bad]
    for coord in coords[:max_report]:
        diffs.append(Diff(
            path, "value",
            f"{lead} at {coord}: A={a[coord]} vs {label}={other[coord]} (rtol={rtol}, atol={atol})"
        ))
    if len(coords) > max_report:
        diffs.append(Diff(path, "value", f"... {len(coords) - max_report} more mismatches"))


def compare_attrs(
    a_obj: h5py.Dataset | h5py.Group,
    b_obj: h5py.Dataset | h5py.Group,
    path: str,
    diffs: List[Diff],
    ignore_attrs: Set[str],
    rtol: float,
    atol: float,
    max_report: int,
) -> None:
    """Compare the attributes of two HDF5 objects and record differences.

    Attribute names listed in *ignore_attrs* are excluded on both sides.
    Names present on only one side are reported as missing. Shared attributes
    are compared as arrays when either value is an ndarray (reported under
    ``"<path>@<name>"``) and as scalars otherwise. An exception raised while
    comparing one attribute is recorded as a Diff and does not stop the rest.
    """
    names_a = set(a_obj.attrs.keys()).difference(ignore_attrs)
    names_b = set(b_obj.attrs.keys()).difference(ignore_attrs)

    for name in sorted(names_a - names_b):
        diffs.append(Diff(path, "attr_missing_in_b", f"Attribute '{name}' exists in A but not B"))
    for name in sorted(names_b - names_a):
        diffs.append(Diff(path, "attr_missing_in_a", f"Attribute '{name}' exists in B but not A"))

    for name in sorted(names_a.intersection(names_b)):
        val_a = normalize_bytes(a_obj.attrs.get(name))
        val_b = normalize_bytes(b_obj.attrs.get(name))
        try:
            array_valued = isinstance(val_a, np.ndarray) or isinstance(val_b, np.ndarray)
            if array_valued:
                compare_arrays_generic(
                    np.array(val_a), np.array(val_b), f"{path}@{name}", diffs, rtol, atol, max_report
                )
            elif not compare_scalars_generic(val_a, val_b, rtol, atol):
                diffs.append(Diff(path, "attr_value", f"Attribute '{name}' differs: A={val_a!r} vs B={val_b!r}"))
        except Exception as e:
            diffs.append(Diff(path, "attr_compare_error", f"Attribute '{name}' compare error: {e}"))


def compare_dataset(
    a_ds: h5py.Dataset,
    b_ds: h5py.Dataset,
    path: str,
    diffs: List[Diff],
    rtol: float,
    atol: float,
    td_rtol: float,
    td_atol: float,
    max_report: int,
) -> None:
    """Compare two datasets found at the same *path*.

    A shape difference is reported and ends the comparison. A dtype difference
    is reported unless both dtypes are structured, and the values are still
    compared afterwards. Both datasets are read fully into memory; a read
    failure is reported as a "read_error" Diff. Transition-dipole paths use
    the sign-aware comparison with (*td_rtol*, *td_atol*); every other path
    uses the generic comparison with (*rtol*, *atol*).
    """
    shape_a, shape_b = a_ds.shape, b_ds.shape
    if shape_a != shape_b:
        diffs.append(Diff(path, "shape", f"Shape differs: A {shape_a} vs B {shape_b}"))
        return

    dtype_a, dtype_b = a_ds.dtype, b_ds.dtype
    if dtype_a != dtype_b:
        both_structured = dtype_a.names is not None and dtype_b.names is not None
        if not both_structured:
            diffs.append(Diff(path, "dtype", f"Dtype differs: A {dtype_a} vs B {dtype_b}"))

    try:
        a_arr = np.array(a_ds[...])
        b_arr = np.array(b_ds[...])
    except Exception as e:
        diffs.append(Diff(path, "read_error", f"Failed to read dataset: {e}"))
        return

    if _TDIPOLE_RE.match(path) is not None:
        compare_transition_dipoles_allow_sign(a_arr, b_arr, path, diffs, td_rtol, td_atol, max_report)
    else:
        compare_arrays_generic(a_arr, b_arr, path, diffs, rtol, atol, max_report)


def walk_and_compare(
    a_file: h5py.File,
    b_file: h5py.File,
    ignore_paths: List[str],
    ignore_attrs: Set[str],
    ignore_dataset_regex: List[re.Pattern[str]],
    verbose_skips: bool,
    rtol: float,
    atol: float,
    td_rtol: float,
    td_atol: float,
    max_report: int,
    diffs: List[Diff],
) -> None:
    """Walk both files from the root and append every difference to *diffs*.

    For each path present in both files the attributes are compared first.
    Groups then report members that exist on one side only (unless the member
    lies in an ignored subtree) and recurse into shared members in sorted
    order. Non-group objects are compared as datasets, unless the path matches
    one of *ignore_dataset_regex*; transition-dipole paths are never skipped
    that way. With *verbose_skips*, skipped datasets are printed to stderr.
    """

    def child_path(parent: str, name: str) -> str:
        # Join a member name onto its parent without doubling the slash at the root.
        return f"{parent.rstrip('/')}/{name}" if parent != "/" else f"/{name}"

    def report_one_sided(parent: str, names: Set[str], kind: str, present: str, absent: str) -> None:
        # Members found on only one side, except those inside an ignored subtree.
        for name in sorted(names):
            if should_ignore_path(child_path(parent, name), ignore_paths):
                continue
            diffs.append(Diff(parent, kind, f"Member '{name}' exists in {present} but not {absent}"))

    def recurse(path: str) -> None:
        if should_ignore_path(path, ignore_paths):
            return

        in_a = path in a_file
        in_b = path in b_file
        if not (in_a and in_b):
            if in_b:
                diffs.append(Diff(path, "missing_in_a", "Path exists in B but not A"))
            elif in_a:
                diffs.append(Diff(path, "missing_in_b", "Path exists in A but not B"))
            return

        a_obj = a_file[path]
        b_obj = b_file[path]

        a_is_group = isinstance(a_obj, h5py.Group)
        if a_is_group != isinstance(b_obj, h5py.Group):
            diffs.append(Diff(path, "type", f"Type differs: A {type(a_obj)} vs B {type(b_obj)}"))
            return

        compare_attrs(a_obj, b_obj, path, diffs, ignore_attrs, rtol, atol, max_report)

        if not a_is_group:
            skippable = _TDIPOLE_RE.match(path) is None
            if skippable and should_ignore_dataset(path, ignore_dataset_regex):
                if verbose_skips:
                    print(f"[compare_h5] SKIP dataset {path}", file=sys.stderr)
                return
            compare_dataset(a_obj, b_obj, path, diffs, rtol, atol, td_rtol, td_atol, max_report)
            return

        a_children = set(a_obj.keys())
        b_children = set(b_obj.keys())

        report_one_sided(path, a_children - b_children, "member_missing_in_b", "A", "B")
        report_one_sided(path, b_children - a_children, "member_missing_in_a", "B", "A")

        for name in sorted(a_children & b_children):
            recurse(child_path(path, name))

    recurse("/")


def main() -> int:
    """Command-line entry point: compare two HDF5 files and print the result.

    Returns:
        0 when the files match, 1 when differences were found, 2 when an
        ``OSError`` occurred while opening or reading the files. Invalid
        command-line arguments (including a malformed ignore regex or a
        negative report limit) exit through argparse with status 2.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("file_a", help="Reference / expected .h5")
    ap.add_argument("file_b", help="Output / actual .h5")

    ap.add_argument("--rtol", type=float, default=1e-7, help="Global relative tolerance for floats/complex")
    ap.add_argument("--atol", type=float, default=1e-6, help="Global absolute tolerance for floats/complex")

    ap.add_argument("--td-rtol", type=float, default=None,
                    help="Transition dipole relative tolerance (defaults to --rtol)")
    ap.add_argument("--td-atol", type=float, default=None,
                    help="Transition dipole absolute tolerance (defaults to max(--atol, 2e-6))")

    ap.add_argument("--ignore-attr", action="append", default=[],
                    help="Attribute name to ignore (repeatable)")
    ap.add_argument("--ignore-path", action="append", default=[],
                    help="HDF5 path subtree to ignore (repeatable)")
    ap.add_argument("--ignore-dataset-regex", action="append", default=[],
                    help="Regex to ignore dataset paths (repeatable)")
    ap.add_argument("--verbose-skips", action="store_true", help="Print skipped datasets to stderr.")

    ap.add_argument("--max-report", type=int, default=20)
    ap.add_argument("--max-diffs", type=int, default=500)
    args = ap.parse_args()

    # Negative limits would be interpreted as "drop from the end" by the slices
    # used for reporting, which produces misleading output.
    if args.max_report < 0:
        ap.error("--max-report must be >= 0")
    if args.max_diffs < 0:
        ap.error("--max-diffs must be >= 0")

    td_rtol = args.td_rtol if args.td_rtol is not None else args.rtol
    # default: slightly looser than global, because dipoles often wobble at ~1e-6
    td_atol = args.td_atol if args.td_atol is not None else max(args.atol, 2e-6)

    default_skip = [
        r"/eigenvectors\d*$",
    ]
    ignore_dataset_regex: List[re.Pattern[str]] = []
    for pattern in default_skip + (args.ignore_dataset_regex or []):
        try:
            ignore_dataset_regex.append(re.compile(pattern))
        except re.error as e:
            # A traceback would exit with status 1, which is indistinguishable
            # from "differences found"; report a usage error (status 2) instead.
            ap.error(f"invalid --ignore-dataset-regex {pattern!r}: {e}")

    diffs: List[Diff] = []
    try:
        with h5py.File(args.file_a, "r") as fa, h5py.File(args.file_b, "r") as fb:
            walk_and_compare(
                fa, fb,
                ignore_paths=list(args.ignore_path),
                ignore_attrs=set(args.ignore_attr),
                ignore_dataset_regex=ignore_dataset_regex,
                verbose_skips=args.verbose_skips,
                rtol=args.rtol,
                atol=args.atol,
                td_rtol=td_rtol,
                td_atol=td_atol,
                max_report=args.max_report,
                diffs=diffs,
            )
    except OSError as e:
        print(f"[compare_h5] ERROR opening files: {e}", file=sys.stderr)
        return 2

    if diffs:
        print(f"[compare_h5] FAIL: {len(diffs)} differences found")
        for d in diffs[:args.max_diffs]:
            print(f" - {d.path}: {d.kind}: {d.message}")
        if len(diffs) > args.max_diffs:
            print(f" ... {len(diffs) - args.max_diffs} more differences not shown")
        return 1

    print("[compare_h5] OK: files match (eigenvectors skipped; transition dipoles sign-aware with dedicated tolerances)")
    return 0


# Run the comparison when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
