#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
restauration.py — general photo restoration using LaMa inpainting only.

Key idea (Option A):
- We detect the largest face (if any) and build a large "protected" region around it (hair + shoulders).
- Scratches are detected everywhere, BUT we REMOVE all mask pixels inside the protected region.
- Result: background/clothes get restored, face stays untouched (no hallucinated face).

Usage:
  ./.venv/bin/python restauration.py -i restoreme.jpg -o out.png --passes 2 --save-mask mask.png

Optional:
  --no-protect-face        Disable the protection
  --save-mask-steps        Save mask_pass1.png etc.
  --protect-expand 0.75    Stronger/weaker expansion around the face
"""

import argparse
import sys
from pathlib import Path

import cv2
import numpy as np
from PIL import Image

from simple_lama_inpainting import SimpleLama


# ----------------------------
# IO helpers
# ----------------------------
def load_bgr(path: Path) -> np.ndarray:
    """Read *path* as a 3-channel BGR image with OpenCV.

    Raises:
        RuntimeError: when OpenCV cannot decode the file; ``cv2.imread``
            signals that by returning None instead of raising.
    """
    image = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if image is not None:
        return image
    raise RuntimeError(f"Could not read image: {path}")


def write_image(path: Path, bgr: np.ndarray) -> None:
    """Save *bgr* to *path*, creating any missing parent directories first.

    OpenCV chooses the encoder from the file extension.

    Raises:
        RuntimeError: when ``cv2.imwrite`` reports failure.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    if cv2.imwrite(str(path), bgr):
        return
    raise RuntimeError(f"Could not write: {path}")


def ensure_uint8_mask(mask01: np.ndarray) -> np.ndarray:
    """Binarize *mask01* into a 0/255 ``uint8`` mask.

    Every strictly positive value (True, 1, 255, 0.25, ...) becomes 255;
    zero, negative values and False become 0. This covers bool, 0/1 and
    0..255 masks alike, for any shape, including empty arrays.

    The input is never modified; a new array is always returned.
    """
    # The former uint8 / non-uint8 / max()<=1 branches all reduced to this
    # single expression. Dropping the ``max()`` probe also fixes the
    # ValueError it raised on empty (zero-size) uint8 arrays.
    return (np.asarray(mask01) > 0).astype(np.uint8) * 255


def resize_to(h_w_src, img_or_mask, is_mask=False):
    """Return *img_or_mask* at size ``h_w_src = (height, width)``.

    When the first two dimensions already match, the very same object is
    returned (no copy). Otherwise it is resized with nearest-neighbour
    interpolation for masks (values stay binary) or area interpolation for
    images.
    """
    target_h, target_w = h_w_src
    same_size = img_or_mask.shape[0] == target_h and img_or_mask.shape[1] == target_w
    if not same_size:
        method = cv2.INTER_NEAREST if is_mask else cv2.INTER_AREA
        # cv2.resize takes dsize as (width, height): note the swapped order.
        return cv2.resize(img_or_mask, (target_w, target_h), interpolation=method)
    return img_or_mask


# ----------------------------
# Protection mask (Option A)
# ----------------------------
def build_protect_mask_face(
    bgr: np.ndarray,
    expand: float = 0.70,
    min_face: int = 40
) -> np.ndarray:
    """
    Build the "do not inpaint" mask around the main face of the photo.

    The largest frontal face found by OpenCV's bundled Haar cascade is grown
    into a rectangle that also covers the sides of the head, the hair above
    it, and the neck/shoulders below it. *expand* enlarges all three paddings.

    Returns a uint8 mask with the height/width of *bgr*: 255 marks pixels
    LaMa must NOT touch, 0 elsewhere. The mask is all zeros when the cascade
    cannot be loaded or no face of at least *min_face* x *min_face* pixels
    is detected.
    """
    height, width = bgr.shape[:2]
    mask = np.zeros((height, width), dtype=np.uint8)

    # Cascade file shipped with OpenCV itself, so nothing is downloaded.
    detector = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )
    if detector.empty():
        return mask  # cascade unavailable: behave as if no face was found

    # Histogram equalisation helps the detector on faded, low-contrast scans.
    equalized = cv2.equalizeHist(cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY))
    detections = detector.detectMultiScale(
        equalized,
        scaleFactor=1.08,
        minNeighbors=5,
        minSize=(min_face, min_face),
    )
    if len(detections) == 0:
        return mask

    # Keep the largest box by area (most likely the main subject). max()
    # returns the first of equally large boxes, as a stable sort would.
    fx, fy, fw, fh = max(detections, key=lambda box: box[2] * box[3])

    side = int(fw * (0.60 + expand))           # generous margin left/right
    above = int(fh * (0.80 + expand))          # hair above the forehead
    below = int(fh * (1.70 + 1.2 * expand))    # neck, shoulders, upper chest

    top, bottom = max(0, fy - above), min(height, fy + fh + below)
    left, right = max(0, fx - side), min(width, fx + fw + side)
    mask[top:bottom, left:right] = 255

    # Blur followed by a low threshold: the result is still a hard 0/255
    # mask, but the rectangle grows by a few pixels and its corners round off.
    mask = cv2.GaussianBlur(mask, (0, 0), 6.0)
    return (mask > 16).astype(np.uint8) * 255


# ----------------------------
# Scratch/crack detection
# ----------------------------
def clamp_mask_ratio(
    mask: np.ndarray,
    max_mask_ratio: float,
    max_iters: int = 10,
) -> np.ndarray:
    """Shrink *mask* by erosion until at most *max_mask_ratio* of it is set.

    Huge masks make LaMa repaint large areas, which tends to look blurry.
    Each iteration erodes the mask once with a 3x3 elliptical kernel. The
    loop stops as soon as the nonzero fraction (nonzero count / (H*W)) is
    <= *max_mask_ratio*, or after *max_iters* erosions, so the result may
    still exceed the ratio.

    A *max_mask_ratio* <= 0 disables the clamp. Zero-area masks, and masks
    already within the limit, are returned as-is (same object, no copy).
    """
    if max_mask_ratio <= 0:
        return mask
    h, w = mask.shape[:2]
    area = float(h * w)
    if area == 0:
        # Nothing to measure; avoids the ZeroDivisionError of the ratio below.
        return mask

    def _ratio(m: np.ndarray) -> float:
        """Fraction of nonzero entries relative to the image area."""
        return np.count_nonzero(m) / area

    if _ratio(mask) <= max_mask_ratio:
        return mask

    # Built once instead of on every iteration.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    shrunk = mask
    for _ in range(max_iters):
        if _ratio(shrunk) <= max_mask_ratio:
            break
        shrunk = cv2.erode(shrunk, kernel, iterations=1)
    return shrunk


def multiscale_scratch_mask(
    bgr: np.ndarray,
    sensitivity: float = 0.55,
    scales=(7, 11, 17),
    dilate: int = 1,
    close_iters: int = 0,
    blur_sigma: float = 0.6,
    max_mask_ratio: float = 0.10,
) -> np.ndarray:
    """
    Detect thin scratches/cracks and return them as a 0/255 uint8 mask
    (255 = pixel to inpaint) with the height/width of *bgr*.

    For each kernel size in *scales* (forced to be odd and >= 3), the white
    top-hat (thin bright lines) and black-hat (thin dark lines) of the
    optionally blurred grayscale image are each scaled to [0, 1] and summed
    into an accumulator. The accumulator is rescaled to [0, 1] and
    thresholded robustly at ``median + k * MAD`` with
    ``k = 2.6 - 1.6 * clip(sensitivity, 0, 1)``.

    sensitivity:
      ~0.35  conservative
      ~0.55  balanced
      ~0.75  aggressive
    (higher sensitivity -> lower threshold -> larger mask)

    The mask is then optionally closed (*close_iters*) and dilated
    (*dilate*) with a 3x3 ellipse, and finally limited by
    ``clamp_mask_ratio`` to roughly *max_mask_ratio* of the image.
    """
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    if blur_sigma and blur_sigma > 0:
        # Light pre-blur so grain/noise does not register as "thin lines".
        gray = cv2.GaussianBlur(gray, (0, 0), blur_sigma)

    response = np.zeros_like(gray, dtype=np.float32)
    for size in scales:
        size = int(max(3, size)) | 1  # at least 3, and odd (even -> +1)
        disk = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

        bright = cv2.morphologyEx(gray, cv2.MORPH_TOPHAT, disk).astype(np.float32)
        dark = cv2.morphologyEx(gray, cv2.MORPH_BLACKHAT, disk).astype(np.float32)
        for resp in (bright, dark):
            peak = resp.max()
            if peak > 0:
                resp /= peak  # in place: rescales bright/dark to [0, 1]

        response += bright + dark

    top = response.max()
    if top > 0:
        response /= top

    # Median + k * MAD (median absolute deviation); +1e-6 keeps MAD nonzero.
    median = float(np.median(response))
    spread = float(np.median(np.abs(response - median))) + 1e-6
    factor = 2.6 - (np.clip(sensitivity, 0.0, 1.0) * 1.6)
    cutoff = median + factor * spread

    mask = (response >= cutoff).astype(np.uint8) * 255

    ellipse3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    if int(close_iters) > 0:
        mask = cv2.morphologyEx(
            mask, cv2.MORPH_CLOSE, ellipse3, iterations=int(close_iters)
        )
    if int(dilate) > 0:
        mask = cv2.dilate(mask, ellipse3, iterations=int(dilate))

    return clamp_mask_ratio(mask, max_mask_ratio=max_mask_ratio).astype(np.uint8)


# ----------------------------
# LaMa inpainting
# ----------------------------
def lama_inpaint(bgr: np.ndarray, mask: np.ndarray, lama: SimpleLama) -> np.ndarray:
    """Repaint the nonzero pixels of *mask* in *bgr* with LaMa.

    Args:
        bgr: HxWx3 uint8 BGR image.
        mask: HxW mask; every nonzero pixel is repainted.
        lama: an initialised ``SimpleLama`` model (called with PIL images).

    Returns:
        HxWx3 uint8 BGR image with the same size as *bgr*. When *mask* has
        the same size as *bgr*, only pixels where it is nonzero differ from
        the input.
    """
    h, w = bgr.shape[:2]
    mask_u8 = ensure_uint8_mask(mask)

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # A 2-D uint8 array maps to PIL mode "L" on its own; passing ``mode=``
    # to Image.fromarray is deprecated in recent Pillow releases.
    out_pil = lama(Image.fromarray(rgb), Image.fromarray(mask_u8))  # PIL RGB
    out_rgb = np.array(out_pil)

    # SimpleLama pads its input (bottom/right) up to a multiple of 8 and
    # returns the padded result. Crop back instead of resizing, so the image
    # is not resampled (which would blur/shift every pixel, face included).
    out_rgb = out_rgb[:h, :w]
    if out_rgb.shape[:2] != (h, w):
        # Unexpectedly smaller output: fall back to resizing.
        out_rgb = resize_to((h, w), out_rgb, is_mask=False)
    out_bgr = cv2.cvtColor(np.ascontiguousarray(out_rgb), cv2.COLOR_RGB2BGR)

    # The model output is not guaranteed to reproduce unmasked pixels
    # exactly; keep the original values outside the mask (e.g. the
    # protected face region).
    if mask_u8.shape == (h, w):
        keep = mask_u8 == 0
        out_bgr[keep] = bgr[keep]
    return out_bgr


def _parse_scales(spec: str, default=(7, 11, 17)) -> tuple:
    """Parse a comma-separated list of kernel sizes such as ``"7,11,17"``.

    Falls back to *default* when the list is empty. Also falls back, with a
    warning on stderr, when an entry is not an integer.
    """
    try:
        scales = tuple(int(part.strip()) for part in spec.split(",") if part.strip())
    except ValueError:
        print(f"[WARN] Invalid --scales {spec!r}; using {tuple(default)}", file=sys.stderr)
        return tuple(default)
    return scales or tuple(default)


def main():
    """Command-line entry point: detect scratches and inpaint them with LaMa.

    Runs up to ``--passes`` detect -> inpaint rounds on the image. Unless
    ``--no-protect-face`` is given, pixels inside the protect mask (built
    once from the original image) are removed from every pass's mask, so
    LaMa never repaints them. Stops early when a pass finds nothing to
    repaint.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--input", required=True, help="Input image path")
    ap.add_argument("-o", "--output", required=True, help="Output image path (.png recommended)")
    ap.add_argument("--save-mask", default="", help="Optional: save FINAL combined mask (PNG)")
    ap.add_argument("--save-mask-steps", action="store_true", help="Optional: save mask per pass (mask_passX.png)")

    ap.add_argument("--passes", type=int, default=2, help="How many detect->inpaint passes (2 recommended)")
    ap.add_argument("--sensitivity", type=float, default=0.52, help="0.35 conservative .. 0.75 aggressive")
    ap.add_argument("--scales", default="7,11,17", help="Comma-separated odd kernel sizes (multi-scale)")
    ap.add_argument("--dilate", type=int, default=1, help="Mask dilate iterations")
    ap.add_argument("--close-iters", type=int, default=0, help="Mask close iterations")
    ap.add_argument("--blur-sigma", type=float, default=0.6, help="Pre-blur sigma for detection")
    ap.add_argument("--max-mask-ratio", type=float, default=0.08, help="Safety: max masked area ratio per pass")

    # Option A controls
    ap.add_argument("--no-protect-face", action="store_true", help="Disable face/upper-body protection")
    ap.add_argument("--protect-expand", type=float, default=0.70, help="How much to expand protected region around face")
    ap.add_argument("--save-protect", default="", help="Optional: save protect mask PNG")

    args = ap.parse_args()

    inp = Path(args.input)
    outp = Path(args.output)

    if not inp.exists():
        print(f"[ERROR] Input not found: {inp}", file=sys.stderr)
        sys.exit(1)

    bgr = load_bgr(inp)
    base_h, base_w = bgr.shape[:2]

    scales = _parse_scales(args.scales)
    passes = int(max(1, args.passes))
    sensitivity = float(np.clip(args.sensitivity, 0.0, 1.0))

    # Build protect mask once (based on the original image).
    protect = np.zeros((base_h, base_w), dtype=np.uint8)
    if args.no_protect_face:
        print("[INFO] Face protection: OFF")
    else:
        protect = build_protect_mask_face(bgr, expand=float(args.protect_expand))
        protected_ratio = float(np.count_nonzero(protect)) / float(protect.size)
        print(f"[INFO] Protect mask (face/upper-body): {protected_ratio*100:.2f}% protected")
        if args.save_protect:
            # write_image creates parent dirs and raises instead of failing silently.
            write_image(Path(args.save_protect), protect)
            print(f"[INFO] Protect mask written: {args.save_protect}")

    # Complement of the protect mask (255 = allowed to inpaint), or None if
    # nothing is protected.
    allowed = cv2.bitwise_not(protect) if np.any(protect) else None

    print("[INFO] LaMa init… (CUDA if available, else CPU)")
    lama = SimpleLama()

    combined_mask = np.zeros((base_h, base_w), dtype=np.uint8)
    cur = bgr.copy()

    for p in range(1, passes + 1):
        print(f"[INFO] Pass {p}/{passes}: detecting scratches…")
        mask = multiscale_scratch_mask(
            cur,
            sensitivity=sensitivity,
            scales=scales,
            dilate=args.dilate,
            close_iters=args.close_iters,
            blur_sigma=args.blur_sigma,
            max_mask_ratio=args.max_mask_ratio,
        )
        mask = resize_to((base_h, base_w), mask, is_mask=True)

        # OPTION A: remove mask inside protected region (no LaMa on faces)
        if allowed is not None:
            mask = cv2.bitwise_and(mask, allowed)

        masked_ratio = float(np.count_nonzero(mask)) / float(mask.size)
        print(f"[INFO] Pass {p}: masked ratio = {masked_ratio*100:.2f}%")

        # An empty mask means nothing to repaint on any pass; running LaMa
        # anyway would only risk altering the image.
        if not np.any(mask):
            print("[INFO] Empty mask → nothing to inpaint, stopping.")
            break
        if masked_ratio < 0.001 and p > 1:
            print("[INFO] Mask very small → stopping early.")
            break

        combined_mask = cv2.bitwise_or(combined_mask, mask)

        if args.save_mask_steps:
            mp = Path(f"mask_pass{p}.png")
            write_image(mp, mask)
            print(f"[INFO] Wrote: {mp}")

        print(f"[INFO] Pass {p}: inpainting…")
        cur = lama_inpaint(cur, mask, lama)
        # Safety net: keep the working image at the original size.
        cur = resize_to((base_h, base_w), cur, is_mask=False)

    if args.save_mask:
        mpath = Path(args.save_mask)
        write_image(mpath, combined_mask)
        print(f"[INFO] Mask written: {mpath}")

    write_image(outp, cur)
    print(f"[OK] Wrote: {outp}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    sys.exit(main())

