#!/usr/bin/env python3
"""
fc_chroma_key.py — key the studio background out of the FC frames -> RGBA.

ROBUST + ADAPTIVE + SAFE:
  1. Auto-sample the real background colour from each frame's TOP corners (a portrait's
     head is centred, so the top corners are reliably background).
  2. Mark every pixel within a colour-DISTANCE of that bg as "bg-similar."
  3. BORDER-CONNECTED CONSTRAINT (the key safety net): only clear bg-similar pixels that
     are CONNECTED TO THE FRAME EDGE. The true background and the hair-gap channels that
     open to it get cleared; her face/skin are interior islands, so even a grey-ish skin
     patch is PROTECTED and never keyed. This is what stops a naive colour key from eating
     her face. Because her face is safe, the threshold can run hot enough to catch a
     non-uniform background.

Deterministic, no model, ~instant. --lock-bg samples the bg ONCE and reuses it across a
sequence so the matte is temporally STABLE (kills the "hair appearing/reappearing"
flicker that per-frame model mattes caused).

Usage:
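  fc_chroma_key.py <in_dir_or_file> <out_dir_or_file> [--thresh 60] [--feather 1.5] [--lock-bg] [--no-border]
  <in_dir_or_file> : a single image, or a directory whose *.png frames are all keyed
  --thresh    : colour-distance from sampled bg below which a pixel is bg-similar (raise to remove more)
  --feather   : Gaussian sigma (px) that softens the alpha edge; 0 disables feathering
  --lock-bg   : sample the bg once from the first frame and reuse it for the whole sequence
  --no-border : disable the border-connected safety net (key purely by colour — will eat grey-ish skin)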
"""
import argparse, os, glob
import numpy as np
import cv2


def sample_bg(bgr, k=22):
    """Estimate the background colour from the two top corners of a frame.

    Args:
        bgr: BGR image (H x W x 3 — the ``reshape(-1, 3)`` below requires
            three channels).
        k: side length in pixels of the square patch taken from each top
            corner. Clamped to the image size so small frames are handled.

    Returns:
        Length-3 array: the per-channel median (BGR order) over both patches.
    """
    h, w = bgr.shape[:2]
    # Clamp k: for w < k, ``w - k`` is negative and would index from the
    # right edge, silently sampling the wrong columns for the top-right patch.
    k = max(1, min(int(k), h, w))
    top_left = bgr[:k, :k].reshape(-1, 3)
    top_right = bgr[:k, w - k:].reshape(-1, 3)
    corners = np.concatenate([top_left, top_right]).astype(np.float32)
    return np.median(corners, axis=0)            # median BGR of the two top corners


def key_frame(bgr, thresh, feather, bg=None, border_only=True):
    if bg is None:
        bg = sample_bg(bgr)
    dist = np.linalg.norm(bgr.astype(np.float32) - bg, axis=2)
    bgmask = (dist < thresh).astype(np.uint8)        # 1 where colour-similar to bg
    if border_only:
        # keep only bg-similar regions CONNECTED TO THE BORDER (true background +
        # hair-gap channels that open to it). Her grey-ish skin is an interior island
        # -> dropped from the mask -> stays opaque. This protects her face.
        num, labels = cv2.connectedComponents(bgmask, connectivity=8)
        border = np.concatenate([labels[0, :], labels[-1, :], labels[:, 0], labels[:, -1]])
        keep = np.unique(border)
        keep = keep[keep != 0]
        bgmask = (np.isin(labels, keep) & bgmask.astype(bool)).astype(np.uint8)
    alpha = np.where(bgmask.astype(bool), 0, 255).astype(np.uint8)
    alpha = cv2.medianBlur(alpha, 5)             # close stray specks, keep her solid
    if feather > 0:
        alpha = cv2.GaussianBlur(alpha, (0, 0), feather)
    return np.dstack([bgr, alpha]).astype(np.uint8)


def main():
    """Command-line entry point: key one image or every ``*.png`` in a directory.

    Directory mode creates the output directory if needed and writes each
    keyed frame under its original basename. Single-file mode writes to
    ``out``, or into ``out`` when it is an existing directory. Unreadable
    frames are skipped with a warning on stderr. If any write fails, the
    process exits with status 1 after printing the summary.
    """
    import sys  # local: only the CLI needs stderr and exit codes

    ap = argparse.ArgumentParser()
    ap.add_argument("inp")
    ap.add_argument("out")
    ap.add_argument("--thresh", type=float, default=60.0)
    ap.add_argument("--feather", type=float, default=1.5)
    ap.add_argument("--lock-bg", action="store_true",
                    help="sample bg ONCE from the first frame and reuse (steadier across a sequence)")
    ap.add_argument("--no-border", action="store_true",
                    help="disable the border-connected safety net (pure colour key)")
    a = ap.parse_args()
    border_only = not a.no_border

    is_file = os.path.isfile(a.inp)
    if is_file:
        files = [a.inp]
    elif os.path.isdir(a.inp):
        files = sorted(glob.glob(os.path.join(a.inp, "*.png")))
        os.makedirs(a.out, exist_ok=True)
    else:
        # A missing input used to report "keyed 0 frame(s)" and exit 0.
        ap.error(f"input is neither a file nor a directory: {a.inp}")

    bg = None
    written = skipped = failed = 0
    for f in files:
        bgr = cv2.imread(f)
        if bgr is None:
            print(f"warning: could not read {f}, skipping", file=sys.stderr)
            skipped += 1
            continue
        if a.lock_bg and bg is None:
            # Lock on the first frame that actually decodes, so an unreadable
            # leading file doesn't silently fall back to per-frame sampling.
            bg = sample_bg(bgr)
        rgba = key_frame(bgr, a.thresh, a.feather, bg=bg, border_only=border_only)
        if is_file and not os.path.isdir(a.out):
            out = a.out
        else:
            out = os.path.join(a.out, os.path.basename(f))
        if cv2.imwrite(out, rgba):
            written += 1
        else:
            print(f"error: failed to write {out}", file=sys.stderr)
            failed += 1

    tag = ("locked " + str([int(x) for x in bg])) if bg is not None else "per-frame"
    summary = f"keyed {written} frame(s) | bg(BGR)={tag} | border_only={border_only}"
    if skipped or failed:
        summary += f" | skipped={skipped} failed={failed}"
    print(summary)
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module
    # (e.g. to reuse key_frame / sample_bg) has no side effects.
    main()
