import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# ---------- Build a synthetic test image (grayscale, [0,1]) ----------
H, W = 256, 256
yy, xx = np.indices((H, W))  # row index / column index of every pixel

# Layer 1: brightness ramp, 0 at the left edge rising linearly to 0.5 at the right.
grad = 0.5 * (xx / (W - 1))

# Layer 2: checkerboard of tile x tile squares alternating between 0 and 0.4.
tile = 16
checker = 0.4 * ((yy // tile + xx // tile) % 2).astype(float)

# Layer 3: concentric sinusoidal rings (radial period 20 px) with values in
# [0, 0.2], centred at (H/2, W/2).
r = np.sqrt((xx - W / 2) ** 2 + (yy - H / 2) ** 2)
rings = 0.2 * (0.5 * np.sin(2 * np.pi * r / 20) + 0.5)

# Layer 4: Gaussian noise (std 0.08) from a fixed seed, so the image is reproducible.
rng = np.random.default_rng(1)
noise = 0.08 * rng.standard_normal((H, W))

# Stack the layers on a 0.2 pedestal and clip into the displayable range [0, 1].
img = np.clip(0.2 + grad + checker + rings + noise, 0.0, 1.0)

# ---------- Separable 1D IIR (Exponential) along rows and columns ----------
def ema_separable(img, alpha, zero_phase=False, *, steady_state_init=True):
    """
    Smooth an image with a first-order exponential (EMA) IIR filter,
    applied along rows (axis 1) and then along columns (axis 0).

    Each 1-D pass computes ``y[n] = alpha * x[n] + (1 - alpha) * y[n-1]``,
    a low-pass filter with unit DC gain; a smaller ``alpha`` means heavier
    smoothing.

    Parameters
    ----------
    img : array_like, shape (H, W) or (H, W, ...)
        Input image. Trailing axes (e.g. colour channels) are filtered
        independently over H and W. Integer inputs are converted to float,
        so results are never truncated to the input dtype.
    alpha : float
        Smoothing factor in (0, 1]. ``alpha == 1`` is the identity.
    zero_phase : bool, optional
        If True, run the filter forward and then backward along each axis.
        This cancels the spatial shift of the causal filter and squares its
        magnitude response.
    steady_state_init : bool, optional, keyword-only
        If True (default), each 1-D pass starts in the steady state for its
        first sample (``scipy.signal.lfilter_zi``), so constant regions stay
        constant right up to the image border. If False, every pass starts
        from a zero state. That pulls the first samples of each row/column
        toward 0 (dark top/left borders; with ``zero_phase`` also dark
        bottom/right borders).

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``img``.

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1] or ``img`` has fewer than 2 dimensions.
    """
    # `not (a < x <= b)` also rejects NaN.
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must be in (0, 1], got {alpha!r}")

    x = np.asarray(img, dtype=float)
    if x.ndim < 2:
        raise ValueError(f"img must have at least 2 dimensions, got shape {x.shape}")
    if x.size == 0:
        return x.copy()

    b = np.array([alpha], float)
    a = np.array([1.0, -(1.0 - alpha)], float)
    # Filter state giving a steady-state response to a constant unit input.
    zi_unit = signal.lfilter_zi(b, a)

    def _causal(data, axis):
        """Run the EMA once along `axis` of `data`, first sample to last."""
        if not steady_state_init:
            return signal.lfilter(b, a, data, axis=axis)
        # lfilter wants zi shaped like `data` with `axis` replaced by the
        # number of filter delays; scale the unit state by each 1-D signal's
        # first sample so the filter starts "already settled" there.
        shape = [1] * data.ndim
        shape[axis] = zi_unit.size
        zi = zi_unit.reshape(shape) * np.take(data, [0], axis=axis)
        out, _ = signal.lfilter(b, a, data, axis=axis, zi=zi)
        return out

    def _one_axis(data, axis):
        """Causal pass along `axis`, plus a reversed pass when zero_phase is set."""
        out = _causal(data, axis)
        if zero_phase:
            out = np.flip(_causal(np.flip(out, axis=axis), axis), axis=axis)
        return out

    # Rows first (filter along axis 1), then columns (along axis 0).
    return _one_axis(_one_axis(x, axis=1), axis=0)

# EMA weight given to the current pixel: smaller alpha => stronger smoothing.
alpha = 0.25

img_sep, img_sep_zp = (
    ema_separable(img, alpha, zero_phase=False),  # causal pass: rows, then columns
    ema_separable(img, alpha, zero_phase=True),   # forward + backward pass per axis
)

# ---------- Plot figures ----------
def plot_im(data, title):
    """Display `data` as a grayscale image in its own 4.5 x 4.5 inch figure.

    Values are clipped to [0, 1] and drawn with a fixed 0..1 gray scale, so
    brightness is comparable between figures. Axes are hidden, `title` is
    set, and the figure is shown immediately with ``plt.show()``.
    """
    clipped = np.clip(data, 0, 1)
    plt.figure(figsize=(4.5, 4.5))
    plt.imshow(clipped, cmap='gray', vmin=0, vmax=1)
    plt.title(title)
    plt.axis('off')
    plt.show()

# Show the input and both filtered versions, one figure each, in this order.
for _image, _caption in (
    (img, "Original"),
    (img_sep, f"Separable IIR (EMA) row→col, α={alpha}"),
    (img_sep_zp, f"Separable IIR zero-phase (fwd/back), α={alpha}"),
):
    plot_im(_image, _caption)
