import cv2
import numpy as np
import os
import sys

VIDEO_PATH = "videoplayback.mp4"
OUTPUT_DIR = "./output"
FRAME_START = 13020   # 7:14
FRAME_END   = 14580   # 8:06
REINIT_INTERVAL = 200   # tracking frames between forced tracker re-inits
SPEED_THRESHOLDS = [25, 40, 60, 80]   # speed percentiles reported after tracking
DRIFT_WINDOW = 500   # default moving-average window (frames) for high_pass
os.makedirs(OUTPUT_DIR, exist_ok=True)

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    print(f"ERROR: Cannot open {VIDEO_PATH}")
    sys.exit(1)

total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS)
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
# Some containers/backends can report 0 for these properties. Everything
# below divides by fps, w or h (timestamps, SCALE) or indexes up to
# total_frames, so fail early with a clear message instead of a
# ZeroDivisionError deep inside the UI code.
if not (fps > 0 and w > 0 and h > 0 and total_frames > 0):
    print(f"ERROR: Unusable stream properties for {VIDEO_PATH}: "
          f"{w}x{h}, fps={fps}, frames={total_frames}")
    cap.release()
    sys.exit(1)
if FRAME_END < FRAME_START:
    print(f"ERROR: FRAME_END ({FRAME_END}) is before FRAME_START ({FRAME_START})")
    cap.release()
    sys.exit(1)
print(f"Video: {w}x{h}, {fps:.1f} fps, {total_frames} frames")
print(f"Processing frames {FRAME_START}–{FRAME_END}  "
      f"({FRAME_START/fps:.1f}s – {FRAME_END/fps:.1f}s)")

# The ROI polygons drawn on the frame chosen in the browser are later used to
# start tracking at FRAME_START, so open the browser on that frame: simply
# pressing ENTER then selects a consistent frame.
browser_frame_idx = [max(0, min(FRAME_START, total_frames - 1))]
cap.set(cv2.CAP_PROP_POS_FRAMES, browser_frame_idx[0])
playing = [False]

# Integer display upscale: 2x when the doubled frame fits within 960x540,
# otherwise 1x.
SCALE = max(1, min(2, 960 // w, 540 // h))

cv2.namedWindow("Browse Video", cv2.WINDOW_AUTOSIZE)

def read_frame_at(idx):
    """Seek the shared capture to frame *idx* and decode it.

    *idx* is clamped to ``[0, total_frames - 1]`` before seeking.
    Returns the decoded frame, or None when ``cap.read()`` fails.
    """
    upper = total_frames - 1
    target = max(0, min(idx, upper))
    cap.set(cv2.CAP_PROP_POS_FRAMES, target)
    ok, image = cap.read()
    if not ok:
        return None
    return image

selected_frame = None
while True:
    frm = read_frame_at(browser_frame_idx[0])
    if frm is None:
        # Decoding can fail near the end of the stream because
        # CAP_PROP_FRAME_COUNT may only be an estimate. Stop playback
        # (otherwise playing would re-advance onto the bad frame forever) and
        # step back one frame. If frame 0 itself is unreadable there is nothing
        # to browse, so bail out instead of spinning in this loop.
        playing[0] = False
        if browser_frame_idx[0] <= 0:
            cv2.destroyAllWindows()
            cap.release()
            print(f"ERROR: Cannot read any frame from {VIDEO_PATH}")
            sys.exit(1)
        browser_frame_idx[0] -= 1
        continue
    disp = cv2.resize(frm, (w * SCALE, h * SCALE), interpolation=cv2.INTER_LINEAR)
    t_sec = browser_frame_idx[0] / fps
    t_min = int(t_sec // 60)
    t_s = t_sec % 60
    status = "PLAYING" if playing[0] else "PAUSED"
    cv2.putText(disp, f"Frame {browser_frame_idx[0]}/{total_frames}  {t_min}:{t_s:05.2f}  [{status}]",
                (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 255, 255), 1)
    cv2.putText(disp, "SPACE:play/pause  </>:step  A/D:+-30  Q/W:+-300  ENTER:select",
                (10, h * SCALE - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (200, 200, 200), 1)
    cv2.imshow("Browse Video", disp)

    # Poll quickly while playing, more lazily while paused.
    wait_ms = 1 if playing[0] else 50
    key = cv2.waitKey(wait_ms) & 0xFF

    if key == 32:  # SPACE — toggle play
        playing[0] = not playing[0]
    elif key == 13:  # ENTER — select this frame
        selected_frame = frm.copy()
        break
    elif key == 27 or key == ord('x'):  # ESC or 'x' — quit
        cv2.destroyAllWindows()
        cap.release()
        print("Cancelled.")
        sys.exit(0)
    elif key == 81 or key == 2:  # LEFT arrow (code differs by platform/backend)
        playing[0] = False
        browser_frame_idx[0] = max(0, browser_frame_idx[0] - 1)
    elif key == 83 or key == 3:  # RIGHT arrow (code differs by platform/backend)
        playing[0] = False
        browser_frame_idx[0] = min(total_frames - 1, browser_frame_idx[0] + 1)
    elif key == ord('a'):
        playing[0] = False
        browser_frame_idx[0] = max(0, browser_frame_idx[0] - 30)
    elif key == ord('d'):
        playing[0] = False
        browser_frame_idx[0] = min(total_frames - 1, browser_frame_idx[0] + 30)
    elif key == ord('q'):
        playing[0] = False
        browser_frame_idx[0] = max(0, browser_frame_idx[0] - 300)
    elif key == ord('w'):
        playing[0] = False
        browser_frame_idx[0] = min(total_frames - 1, browser_frame_idx[0] + 300)
    elif playing[0]:
        # Advance one frame per loop while playing; stop at the last frame.
        browser_frame_idx[0] = min(total_frames - 1, browser_frame_idx[0] + 1)
        if browser_frame_idx[0] >= total_frames - 1:
            playing[0] = False

cv2.destroyWindow("Browse Video")
roi_frame_idx = browser_frame_idx[0]
first_frame = selected_frame
print(f"\nSelected frame {roi_frame_idx} ({roi_frame_idx/fps:.1f}s) for ROI selection")

def select_polygon(title, frame):
    """Interactively collect polygon corners on an upscaled copy of *frame*.

    Left-click adds a corner and right-click removes the most recent one.
    ENTER or SPACE confirms once at least three corners exist. 'c' or ESC
    aborts the whole script via ``sys.exit(1)``.

    Returns ``(poly, (x, y, w, h))``: ``poly`` is an int32 array of shape
    (N, 2) in original-frame pixel coordinates (display coordinates divided
    by SCALE), and the tuple is ``cv2.boundingRect(poly)``.
    """
    view_size = (w * SCALE, h * SCALE)
    base = cv2.resize(frame, view_size, interpolation=cv2.INTER_LINEAR)
    corners = []         # clicked points, in upscaled display coordinates
    finished = [False]   # flag shared with the mouse callback

    def on_mouse(event, x, y, flags, param):
        if finished[0]:
            return
        if event == cv2.EVENT_LBUTTONDOWN:
            corners.append((x, y))
        elif event == cv2.EVENT_RBUTTONDOWN and corners:
            corners.pop()  # undo the last corner

    cv2.namedWindow(title, cv2.WINDOW_AUTOSIZE)
    cv2.setMouseCallback(title, on_mouse)

    green = (0, 255, 0)
    yellow = (0, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    while not finished[0]:
        view = base.copy()
        # Numbered corner markers joined by the open polyline.
        prev = None
        for num, pt in enumerate(corners, start=1):
            cv2.circle(view, pt, 5, green, -1)
            cv2.putText(view, str(num), (pt[0] + 8, pt[1] - 4), font, 0.5, green, 1)
            if prev is not None:
                cv2.line(view, prev, pt, green, 2)
            prev = pt
        # Thin closing edge once the outline forms a polygon.
        if len(corners) >= 3:
            cv2.line(view, corners[-1], corners[0], green, 1)
        cv2.putText(view, "L-click: add corner | R-click: undo | ENTER: confirm | C: cancel",
                    (10, 20), font, 0.45, yellow, 1)
        cv2.putText(view, f"{len(corners)} corners selected (min 3)",
                    (10, 40), font, 0.45, yellow, 1)
        cv2.imshow(title, view)

        key = cv2.waitKey(30) & 0xFF
        if key in (ord('c'), 27):
            cv2.destroyWindow(title)
            print(f"Selection cancelled for '{title}', exiting.")
            sys.exit(1)
        if key in (13, 32) and len(corners) >= 3:
            finished[0] = True

    cv2.destroyWindow(title)

    # Map display coordinates back to the original frame resolution.
    poly = np.array([(int(px / SCALE), int(py / SCALE)) for px, py in corners],
                    dtype=np.int32)
    bx, by, bw, bh = cv2.boundingRect(poly)
    return poly, (bx, by, bw, bh)

def poly_centroid(poly):
    """Return the area centroid ``(x, y)`` of an (N, 2) polygon.

    Uses image moments. When the polygon has zero area (``m00 == 0``), it
    falls back to the plain mean of the vertices.
    """
    moments = cv2.moments(poly)
    area = moments['m00']
    if area == 0:
        return (float(poly[:, 0].mean()), float(poly[:, 1].mean()))
    return (moments['m10'] / area, moments['m01'] / area)

def make_masked_template(gray, poly, bbox):
    """Crop *bbox* out of *gray* and blank every pixel outside *poly*.

    Args:
        gray: single-channel image to crop from.
        poly: (N, 2) polygon vertices in *gray* coordinates.
        bbox: ``(x, y, w, h)`` region to crop.

    Returns:
        ``(template, mask)``: ``template`` is a copy of the crop with pixels
        outside the polygon set to 0. ``mask`` is a uint8 array of shape
        ``(h, w)`` that is 255 inside the polygon and 0 elsewhere; it is
        passed to ``cv2.matchTemplate`` as its mask.
    """
    x0, y0, bw, bh = bbox
    # Express the polygon relative to the crop's top-left corner.
    local_poly = poly - np.array([x0, y0], dtype=poly.dtype)
    mask = np.zeros((bh, bw), dtype=np.uint8)
    cv2.fillPoly(mask, [local_poly], 255)
    template = gray[y0:y0 + bh, x0:x0 + bw].copy()
    template[mask == 0] = 0
    return template, mask

def _select_named_polygon(step, target, window_title, label):
    """Print the instructions for one ROI step, run select_polygon on
    first_frame, echo the resulting polygon and bounding rect, and return
    both."""
    print("\n" + "=" * 60)
    print(f"STEP {step}: Click corners around {target}")
    print("  L-click to add corner, R-click to undo, ENTER to confirm")
    print("=" * 60)
    poly, roi = select_polygon(window_title, first_frame)
    print(f"{label} polygon: {poly.tolist()}")
    print(f"{label} bounding rect: {roi}")
    return poly, roi


nozzle_poly, nozzle_roi = _select_named_polygon(
    1, "the NOZZLE / carriage", "Select NOZZLE polygon", "Nozzle")
bed_poly, bed_roi = _select_named_polygon(
    2, "a BED feature", "Select BED polygon", "Bed")

# ─── Tracker helpers ─────────────────────────────────────────────────────────
def _create_one_tracker():
    """Return a new OpenCV tracker instance, or None if none can be built.

    Tries CSRT first, then KCF, then legacy MOSSE, across the factory
    spellings that differ between OpenCV builds (``cv2.X.create``,
    ``cv2.X_create``, ``cv2.legacy.X_create``). Factories that are missing
    (AttributeError) or raise ``cv2.error`` are skipped.
    """
    candidates = (
        lambda: cv2.TrackerCSRT.create(),
        lambda: cv2.TrackerCSRT_create(),
        lambda: cv2.legacy.TrackerCSRT_create(),
        lambda: cv2.TrackerKCF.create(),
        lambda: cv2.TrackerKCF_create(),
        lambda: cv2.legacy.TrackerKCF_create(),
        lambda: cv2.legacy.TrackerMOSSE_create(),
    )
    for make in candidates:
        try:
            tracker = make()
        except (AttributeError, cv2.error):
            continue
        if tracker is not None:
            return tracker
    return None

# Probe once up front; the create_* helpers return None when this is False.
TRACKER_AVAILABLE = not (_create_one_tracker() is None)
print("Tracker available:", TRACKER_AVAILABLE)

def create_trackers(frame, nozzle_bbox, bed_bbox):
    """Create one tracker per ROI and initialise both on *frame*.

    Returns ``(nozzle_tracker, bed_tracker)``, or ``(None, None)`` when
    TRACKER_AVAILABLE is false.
    """
    if not TRACKER_AVAILABLE:
        return None, None
    nozzle_t, bed_t = _create_one_tracker(), _create_one_tracker()
    for tracker, bbox in ((nozzle_t, nozzle_bbox), (bed_t, bed_bbox)):
        tracker.init(frame, bbox)
    return nozzle_t, bed_t

def create_single_tracker(frame, bbox):
    """Create one tracker initialised on *frame* at *bbox*.

    Returns None when TRACKER_AVAILABLE is false.
    """
    if TRACKER_AVAILABLE:
        tracker = _create_one_tracker()
        tracker.init(frame, bbox)
        return tracker
    return None

def get_center(bbox):
    """Return the ``(x, y)`` centre of an ``(x, y, w, h)`` box as floats."""
    left, top, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
    return (left + width / 2.0, top + height / 2.0)

def template_match(gray, template, tmpl_mask, last_bbox, margin=40, min_score=0.3):
    """Re-locate *template* near *last_bbox* in *gray* by normalised correlation.

    The search window is *last_bbox* grown by *margin* pixels on every side,
    clipped to the image. Scores are ``cv2.TM_CCOEFF_NORMED``; when
    *tmpl_mask* is given, only masked-in template pixels contribute.

    Args:
        gray: 2-D grayscale image to search.
        template: grayscale template, same dtype as *gray*.
        tmpl_mask: uint8 mask of the template's shape, or None.
        last_bbox: previous ``(x, y, w, h)`` of the target (may be floats).
        margin: search-window padding in pixels.
        min_score: a match is accepted only if its score exceeds this.

    Returns:
        ``(x, y, w, h)`` of the best match in *gray* coordinates, where
        ``w, h`` is the template size. Returns None if the search window is
        smaller than the template or no finite score exceeds *min_score*.
    """
    lx, ly, lw, lh = (int(v) for v in last_bbox)
    sx1, sy1 = max(0, lx - margin), max(0, ly - margin)
    sx2 = min(gray.shape[1], lx + lw + margin)
    sy2 = min(gray.shape[0], ly + lh + margin)
    region = gray[sy1:sy2, sx1:sx2]
    th, tw = template.shape[:2]
    if region.shape[0] < th or region.shape[1] < tw:
        return None
    if tmpl_mask is not None:
        result = cv2.matchTemplate(region, template, cv2.TM_CCOEFF_NORMED, mask=tmpl_mask)
    else:
        result = cv2.matchTemplate(region, template, cv2.TM_CCOEFF_NORMED)
    # With a mask, TM_CCOEFF_NORMED can produce inf/NaN scores (e.g. in flat,
    # zero-variance windows). An inf would win minMaxLoc and pass the
    # threshold, reporting a bogus match, so treat non-finite scores as the
    # worst possible correlation.
    result = np.where(np.isfinite(result), result, -1.0).astype(np.float32)
    _, max_val, _, max_loc = cv2.minMaxLoc(result)
    if max_val > min_score:
        return (sx1 + max_loc[0], sy1 + max_loc[1], tw, th)
    return None

# ─── Initialize trackers ────────────────────────────────────────────────────
cap.set(cv2.CAP_PROP_POS_FRAMES, FRAME_START)
# Reading this frame also positions `cap`, so the main loop's first read
# returns frame FRAME_START + 1.
ret, init_frame = cap.read()
if not ret or init_frame is None:
    print(f"ERROR: Cannot read frame {FRAME_START} (FRAME_START), stopping.")
    cap.release()
    cv2.destroyAllWindows()
    sys.exit(1)

# The ROI polygons/bboxes were drawn on `first_frame` (frame roi_frame_idx),
# so their coordinates are only valid on that frame. Build the trackers and
# the masked templates from it. When roi_frame_idx == FRAME_START this is the
# same frame as init_frame.
if roi_frame_idx != FRAME_START:
    print(f"WARNING: ROIs were drawn on frame {roi_frame_idx}, but tracking "
          f"starts at frame {FRAME_START}; the targets may have moved in between.")
ref_frame = first_frame

nozzle_tracker, bed_tracker = create_trackers(ref_frame, nozzle_roi, bed_roi)

n_frames = FRAME_END - FRAME_START + 1

# Per-frame centre positions (pixels) and tracking-success flags.
nozzle_cx = np.zeros(n_frames)
nozzle_cy = np.zeros(n_frames)
bed_cx = np.zeros(n_frames)
bed_cy = np.zeros(n_frames)
track_ok_nozzle = np.ones(n_frames, dtype=bool)
track_ok_bed = np.ones(n_frames, dtype=bool)

# Frame 0 uses the same reference point as every later frame (the bbox centre
# from get_center), not the polygon centroid. Mixing the two would inject a
# spurious jump between samples 0 and 1 into the speed/segmentation stage.
nc = get_center(nozzle_roi); bc = get_center(bed_roi)
nozzle_cx[0], nozzle_cy[0] = nc
bed_cx[0], bed_cy[0] = bc

last_nozzle_bbox = nozzle_roi
last_bed_bbox = bed_roi

init_gray = cv2.cvtColor(ref_frame, cv2.COLOR_BGR2GRAY)
nozzle_template, nozzle_tmask = make_masked_template(init_gray, nozzle_poly, nozzle_roi)
bed_template, bed_tmask = make_masked_template(init_gray, bed_poly, bed_roi)

cv2.namedWindow("Video + Tracking", cv2.WINDOW_NORMAL)
cv2.namedWindow("Reconstructed Path", cv2.WINDOW_NORMAL)
cv2.resizeWindow("Video + Tracking", 960, 540)
cv2.resizeWindow("Reconstructed Path", 800, 600)
canvas = np.ones((600, 800, 3), dtype=np.uint8) * 255

# ─── Main tracking loop ─────────────────────────────────────────────────────
print("\nStarting tracking...")
print("Press 'q' to quit, 'p' to pause/resume\n")

frames_done = n_frames  # lowered below if the video ends early or the user quits
for i in range(1, n_frames):
    frame_num = FRAME_START + i
    ret, frame = cap.read()
    if not ret:
        print(f"Frame read failed at {frame_num}, stopping.")
        frames_done = i
        break

    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Periodically rebuild both trackers from their latest boxes.
    if i % REINIT_INTERVAL == 0:
        nozzle_tracker, bed_tracker = create_trackers(frame, last_nozzle_bbox, last_bed_bbox)

    # Nozzle: tracker first. On failure, fall back to masked template matching
    # near the last box (and restart the tracker there). If that fails too,
    # hold the previous position and flag the frame as lost.
    ok_n, bbox_n = False, None
    if nozzle_tracker is not None:
        ok_n, bbox_n = nozzle_tracker.update(frame)
    if ok_n:
        c = get_center(bbox_n)
        nozzle_cx[i], nozzle_cy[i] = c
        last_nozzle_bbox = tuple(int(v) for v in bbox_n)
    else:
        fb = template_match(gray, nozzle_template, nozzle_tmask, last_nozzle_bbox)
        if fb is not None:
            c = get_center(fb)
            nozzle_cx[i], nozzle_cy[i] = c
            last_nozzle_bbox = fb
            nozzle_tracker = create_single_tracker(frame, fb)
        else:
            nozzle_cx[i] = nozzle_cx[i-1]; nozzle_cy[i] = nozzle_cy[i-1]
            track_ok_nozzle[i] = False

    # Bed: same strategy as the nozzle.
    ok_b, bbox_b = False, None
    if bed_tracker is not None:
        ok_b, bbox_b = bed_tracker.update(frame)
    if ok_b:
        c = get_center(bbox_b)
        bed_cx[i], bed_cy[i] = c
        last_bed_bbox = tuple(int(v) for v in bbox_b)
    else:
        fb = template_match(gray, bed_template, bed_tmask, last_bed_bbox)
        if fb is not None:
            c = get_center(fb)
            bed_cx[i], bed_cy[i] = c
            last_bed_bbox = fb
            bed_tracker = create_single_tracker(frame, fb)
        else:
            bed_cx[i] = bed_cx[i-1]; bed_cy[i] = bed_cy[i-1]
            track_ok_bed[i] = False

    if i % 500 == 0:
        pct = 100 * i / n_frames
        print(f"  Frame {frame_num}/{FRAME_END} ({pct:.1f}%) | "
              f"nozzle=({nozzle_cx[i]:.1f},{nozzle_cy[i]:.1f}) "
              f"bed=({bed_cx[i]:.1f},{bed_cy[i]:.1f}) | "
              f"n_fail={(~track_ok_nozzle[:i+1]).sum()} "
              f"b_fail={(~track_ok_bed[:i+1]).sum()}")

    if i % 3 == 0:
        disp = frame.copy()
        # Draw each ROI polygon translated to its current tracked position.
        # The shift is the offset between the tracked box centre and the
        # centre of the box the polygon was drawn in, so it stays meaningful
        # even if the tracker changes the box size.
        lnb = last_nozzle_bbox
        n_off = np.rint(np.subtract(get_center(lnb), get_center(nozzle_roi))).astype(np.int32)
        cv2.polylines(disp, [nozzle_poly + n_off], True, (0,255,0), 2)
        cv2.rectangle(disp, (lnb[0], lnb[1]), (lnb[0]+lnb[2], lnb[1]+lnb[3]), (0,200,0), 1)
        lbb = last_bed_bbox
        b_off = np.rint(np.subtract(get_center(lbb), get_center(bed_roi))).astype(np.int32)
        cv2.polylines(disp, [bed_poly + b_off], True, (255,0,0), 2)
        cv2.rectangle(disp, (lbb[0], lbb[1]), (lbb[0]+lbb[2], lbb[1]+lbb[3]), (200,0,0), 1)
        cv2.circle(disp, (int(nozzle_cx[i]), int(nozzle_cy[i])), 3, (0,255,0), -1)
        cv2.circle(disp, (int(bed_cx[i]), int(bed_cy[i])), 3, (255,0,0), -1)
        st_n = "OK" if track_ok_nozzle[i] else "LOST"
        st_b = "OK" if track_ok_bed[i] else "LOST"
        cv2.putText(disp, f"Frame {frame_num} t={frame_num/fps:.1f}s",
                    (10,20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,255), 1)
        cv2.putText(disp, f"Nozzle [{st_n}]", (10,h-40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0,255,0), 1)
        cv2.putText(disp, f"Bed [{st_b}]", (10,h-20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255,0,0), 1)
        cv2.imshow("Video + Tracking", disp)

        # Live preview of (nozzle_cx - bed_cx, bed_cx) scaled to fit the
        # 800x600 canvas.
        if i > 20:
            px = nozzle_cx[:i+1] - bed_cx[:i+1]
            py = bed_cx[:i+1]
            rng_x = px.max() - px.min()
            rng_y = py.max() - py.min()
            if rng_x > 1 and rng_y > 1:
                canvas[:] = 255
                sx = 750/rng_x; sy = 550/rng_y; s = min(sx, sy)
                cx_off = 400 - s*(px.max()+px.min())/2
                cy_off = 300 - s*(py.max()+py.min())/2
                step = max(1, len(px)//8000)
                for j in range(step, len(px), step):
                    x1 = int(cx_off + s*px[j-step]); y1 = int(cy_off + s*py[j-step])
                    x2 = int(cx_off + s*px[j]);      y2 = int(cy_off + s*py[j])
                    if (0 <= x1 < 800 and 0 <= y1 < 600 and
                        0 <= x2 < 800 and 0 <= y2 < 600):
                        cv2.line(canvas, (x1, y1), (x2, y2), (0,0,0), 1)
                cv2.imshow("Reconstructed Path", canvas)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('p'):
            # Pause until 'p' resumes or 'q' quits.
            while True:
                key = cv2.waitKey(100) & 0xFF
                if key in (ord('p'), ord('q')):
                    break
        if key == ord('q'):
            print("User quit")
            frames_done = i + 1
            break

# Trim the per-frame result arrays to the frames actually processed.
if frames_done < n_frames:
    n_frames = frames_done
    nozzle_cx = nozzle_cx[:n_frames]; nozzle_cy = nozzle_cy[:n_frames]
    bed_cx = bed_cx[:n_frames]; bed_cy = bed_cy[:n_frames]
    track_ok_nozzle = track_ok_nozzle[:n_frames]
    track_ok_bed = track_ok_bed[:n_frames]

cap.release()
cv2.destroyAllWindows()

actual_n = len(nozzle_cx)
print(f"\nDone tracking. {actual_n} frames.")
print(f"  Nozzle failures: {np.count_nonzero(~track_ok_nozzle)}")
print(f"  Bed failures:    {np.count_nonzero(~track_ok_bed)}")

# ─── Compute positions ──────────────────────────────────────────────────────
# X: nozzle x relative to bed x; Y: the bed's own x coordinate.
# NOTE(review): presumably the bed moves along the printer's Y axis while
# appearing horizontal in the camera view — confirm against the setup.
print_Y_raw = np.copy(bed_cx)
print_X_raw = nozzle_cx - bed_cx

def high_pass(signal, window=None):
    """Remove slow drift from *signal* by subtracting its moving average.

    The moving average uses a box kernel of *window* samples. The first and
    last ``window // 2`` samples are replaced by the mean of the truncated
    window that fits inside the signal, so the ends are not pulled towards
    the zero padding of the convolution.

    Args:
        signal: 1-D array-like of samples.
        window: averaging window length in samples. Defaults to
            DRIFT_WINDOW. It is clamped to ``[1, len(signal)]`` so that short
            signals (e.g. after an early quit) still work: ``np.convolve``
            with ``mode='same'`` returns ``max(M, N)`` samples, which would
            otherwise not match the signal length.

    Returns:
        float ndarray with the same length as *signal*.
    """
    if window is None:
        window = DRIFT_WINDOW
    signal = np.asarray(signal, dtype=float)
    n = len(signal)
    if n == 0:
        return signal.copy()
    window = max(1, min(int(window), n))
    half = window // 2
    kernel = np.ones(window) / window
    # With window <= n, mode='same' yields exactly n samples.
    smoothed = np.convolve(signal, kernel, mode='same')
    for j in range(half):
        smoothed[j] = np.mean(signal[:j + half + 1])
        smoothed[-(j + 1)] = np.mean(signal[-(j + half + 1):])
    return signal - smoothed

print_X = high_pass(print_X_raw)
print_Y = high_pass(print_Y_raw)

# Per-frame speed on ALL raw signals (for segmentation). Frame-to-frame
# differences are computed once and reused; the first sample's difference is
# 0 because each diff is prepended with the series' first value.
d_nozzle = np.diff(nozzle_cx, prepend=nozzle_cx[0])
d_bed = np.diff(bed_cx, prepend=bed_cx[0])
raw_speed_nx = np.abs(d_nozzle)
raw_speed_bc = np.abs(d_bed)
raw_speed_comb = np.sqrt(d_nozzle**2 + d_bed**2)

# Speed along the drift-corrected print path.
dpx = np.diff(print_X, prepend=print_X[0])
dpy = np.diff(print_Y, prepend=print_Y[0])
speed = np.sqrt(dpx**2 + dpy**2)

# The non-zero speeds do not depend on the percentile, so select them once
# instead of on every iteration. Each mask keeps frames slower than that
# percentile (or slower than 1.0 if no frame moved at all).
nz = speed[speed > 0]
speed_percentiles = {}
for pct in SPEED_THRESHOLDS:
    thresh = np.percentile(nz, pct) if len(nz) > 0 else 1.0
    mask = speed < thresh
    speed_percentiles[pct] = mask
    print(f"  Speed p{pct}: thresh={thresh:.4f}, keep {mask.sum()}/{len(mask)}")

# ─── Save only the requested plot ───────────────────────────────────────────
import matplotlib
# Select the non-interactive Agg backend (renders to files only) before
# pyplot is imported, so saving the figure does not require a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def get_segments(x, y, speed_arr, threshold):
    """Split the path (x, y) into strokes at every "travel" frame.

    A frame k >= 1 is a travel frame when ``speed_arr[k] > threshold``; the
    jump into it is not drawn, so frame k starts a new stroke. Only strokes
    with at least two points are returned, as ``(x_slice, y_slice)`` pairs.
    """
    travel = speed_arr > threshold
    breaks = [k for k in range(1, len(x)) if travel[k]]
    bounds = [0] + breaks + [len(x)]
    segments = []
    for lo, hi in zip(bounds[:-1], bounds[1:]):
        if hi - lo >= 2:
            segments.append((x[lo:hi], y[lo:hi]))
    return segments

def plot_segments(ax, segments, lw, color='black', alpha=0.7):
    """Draw each ``(xs, ys)`` stroke in *segments* as a separate line on *ax*,
    all with the same width, colour, transparency and round caps."""
    line_style = dict(lw=lw, c=color, alpha=alpha, solid_capstyle='round')
    for seg_x, seg_y in segments:
        ax.plot(seg_x, seg_y, **line_style)

nz_speed = raw_speed_comb[raw_speed_comb > 0]
# Candidate travel (stroke-break) thresholds: percentiles of the non-zero
# combined raw speed, plus fixed absolute values in pixels per frame.
travel_thresholds = {
    tp: (np.percentile(nz_speed, tp) if len(nz_speed) > 0 else 1.0)
    for tp in (50, 65, 75, 85, 92)
}
travel_thresholds.update(
    {f"abs{abs_t:.0f}": abs_t for abs_t in (1.0, 2.0, 3.0, 5.0)}
)

target_name = "result.png"
target_path = os.path.join(OUTPUT_DIR, target_name)

# Plot nozzle_cx against bed_cx with both axes negated ("flipXY"). The path
# is broken into strokes wherever the combined raw speed exceeds the "abs2"
# threshold (2 px/frame).
xdata, ydata = nozzle_cx, bed_cx
fx, fy = -xdata, -ydata
tp_val = travel_thresholds["abs2"]
segs = get_segments(fx, fy, raw_speed_comb, tp_val)

fig, ax = plt.subplots(figsize=(16, 10))
plot_segments(ax, segs, lw=2.0)
ax.set_xlabel("X (flipXY)")
ax.set_ylabel("Y (flipXY)")
ax.set_title("nozzle_cx vs bed_cx flipXY (tpabs2, lw=2.0)")
ax.set_aspect('equal')
plt.tight_layout()
plt.savefig(target_path, dpi=200)
plt.close()

# Rotate the saved image 90 degrees counter-clockwise, overwriting the file.
rendered = cv2.imread(target_path, cv2.IMREAD_UNCHANGED)
if rendered is None:
    raise RuntimeError(f"Failed to load generated image for rotation: {target_path}")
if not cv2.imwrite(target_path, cv2.rotate(rendered, cv2.ROTATE_90_COUNTERCLOCKWISE)):
    raise RuntimeError(f"Failed to save rotated image: {target_path}")

print(f"\nSaved {target_name} (rotated 90 deg CCW)\nAll done!")
