#!/usr/bin/env python3
"""
Turing patterns — diffusion-driven instability, verified from scratch.

Model: the Schnakenberg system (Murray, Mathematical Biology, II §2.4), the
canonical two-species Turing reaction-diffusion model, in dimensionless form:

    u_t = gamma * (a - u + u^2 v) +     laplacian(u)
    v_t = gamma * (b      - u^2 v) + d * laplacian(v)

u is the (slow-diffusing) activator, v the (fast-diffusing) inhibitor, d = D_v/D_u.

What this script checks, in order:
  (1) the homogeneous steady state (u*, v*),
  (2) reaction-only stability: with NO diffusion the state is stable
      (Tr J < 0, Det J > 0),
  (3) the dispersion relation lambda(k): WITH diffusion a band of wavenumbers
      grows -> Turing instability; the critical ratio d_c; the fastest-growing
      mode k* and its wavelength 2*pi/k*,
  (4) a real 2D finite-difference simulation: does the emergent pattern's
      measured wavelength match the linear prediction 2*pi/k* ?
  (5) the control: set d = 1 (equal diffusion) -> NO instability, NO pattern.

Every number the web page states is reproduced here. Run: python3 verify.py
"""
import numpy as np

# ---- parameters (a standard point in Schnakenberg Turing space) -------------
# Kinetic constants of f = a - u + u^2 v and g = b - u^2 v.
a = 0.1
b = 0.9
# Diffusion ratio d = D_v / D_u: the inhibitor v spreads 40x faster than the
# activator u.
d = 40.0
# Reaction-rate / domain-size scale. It sets how many features fit in the unit
# cell; this value puts ~8 across it, matching the web page.
gamma = 14300.0

def kinetics_jacobian(a, b):
    """Homogeneous steady state and reaction Jacobian of the Schnakenberg kinetics.

    Kinetics: f(u, v) = a - u + u^2 v,  g(u, v) = b - u^2 v.

    Returns (u*, v*, J) where u* = a + b, v* = b / (a + b)^2 and
    J = [[f_u, f_v], [g_u, g_v]] evaluated at (u*, v*), i.e.
    f_u = (b - a)/(a + b), f_v = u*^2, g_u = -2b/(a + b), g_v = -u*^2.
    """
    u_star = a + b
    v_star = b / u_star**2
    u_sq = u_star**2
    two_uv = 2 * u_star * v_star
    jac = np.array([[two_uv - 1, u_sq],
                    [-two_uv, -u_sq]])
    return u_star, v_star, jac

# Homogeneous steady state (u*, v*) and the kinetics Jacobian J evaluated there.
# fu, fv, gu, gv are module-level and are read by the dispersion helpers below.
u_s, v_s, J = kinetics_jacobian(a, b)
fu, fv, gu, gv = J[0,0], J[0,1], J[1,0], J[1,1]

# ---- (1) steady state: both reaction terms should vanish at (u*, v*) -------
print("="*64)
print("(1) STEADY STATE")
print(f"    u* = {u_s:.6f}   v* = {v_s:.6f}")
print(f"    check f = a-u+u^2 v = {a - u_s + u_s**2*v_s:.2e}  (should be 0)")
print(f"    check g = b  -u^2 v = {b       - u_s**2*v_s:.2e}  (should be 0)")

# ---- (2) stability of the reaction-only linearisation (matrix gamma*J) -----
# For a 2x2 linear system, Tr < 0 and Det > 0 together mean both eigenvalues
# have negative real part, i.e. the uniform state is stable without diffusion.
print("="*64)
print("(2) REACTION-ONLY STABILITY  (no diffusion)")
trJ = gamma*(fu+gv); detJ = gamma**2*(fu*gv-fv*gu)
print(f"    kinetics J = [[{fu:+.4f}, {fv:+.4f}], [{gu:+.4f}, {gv:+.4f}]]")
print(f"    Tr(gamma J) = {trJ:+.4f}   (< 0 required)  -> {'STABLE' if trJ<0 else 'unstable'}")
print(f"    Det(gamma J)= {detJ:+.4f}  (> 0 required)  -> {'OK' if detJ>0 else 'BAD'}")
print(f"    eigenvalues of gamma J: {np.linalg.eigvals(gamma*J)}")
print("    => with no diffusion every small perturbation decays. The uniform")
print("       state is stable. (This is the surprise to come.)")

# ---- (3) dispersion relation ------------------------------------------------
# A perturbation proportional to exp(lambda t + i k.x) evolves under
#   M(k^2) = gamma J - k^2 * diag(1, d),
# and lambda_+(k^2) is the eigenvalue of M with the larger real part. Because
# Tr M < 0 for every k here, a mode grows exactly when det M(k^2) < 0.
def lambda_plus(k2):
    """Growth rate of the mode with squared wavenumber k2.

    Returns the largest real part among the eigenvalues of
    M(k^2) = gamma*J - k^2 * diag(1, d), built from the module-level gamma,
    kinetics Jacobian J and diffusion ratio d, via np.linalg.eigvals.
    """
    diffusion = np.diag([1.0, d])
    eigenvalues = np.linalg.eigvals(gamma*J - k2*diffusion)
    return eigenvalues.real.max()

def detM(k2):
    """det M(k^2), written as a quadratic in k^2.

    det M = d k^4 - gamma (d f_u + g_v) k^2 + gamma^2 det(J_kin),
    using the module-level d, gamma and Jacobian entries fu, fv, gu, gv.
    """
    linear_coeff = gamma*(d*fu + gv)
    constant = gamma**2*(fu*gv - fv*gu)
    return d*k2**2 - linear_coeff*k2 + constant

# Turing conditions (Murray): with Tr J<0, Det J>0 already, also need
#   (i)  d*fu + gv > 0
#   (ii) (d*fu+gv)^2 > 4 d Det(J_kin)
cond_i  = d*fu + gv
detJkin = fu*gv - fv*gu
cond_ii = (d*fu+gv)**2 - 4*d*detJkin
print("="*64)
print("(3) DISPERSION RELATION  (with diffusion, d = %.1f)" % d)
print(f"    Turing cond (i)  d*fu+gv = {cond_i:+.4f}  (> 0 ?) -> {cond_i>0}")
print(f"    Turing cond (ii) (d*fu+gv)^2 - 4 d Det = {cond_ii:+.4f} (> 0 ?) -> {cond_ii>0}")

# critical diffusion ratio d_c: solve (d fu + gv)^2 = 4 d Det(J_kin)  in d
#   fu^2 d^2 + (2 fu gv - 4 Det) d + gv^2 = 0
# Its discriminant simplifies to 16 Det(J_kin) * (-fv gu), which is positive
# whenever Det(J_kin) > 0 (fv = u*^2 > 0 and gu = -2 u* v* < 0 always).
A_ = fu**2; B_ = 2*fu*gv - 4*detJkin; C_ = gv**2
disc = B_**2 - 4*A_*C_
dc1 = (-B_ - np.sqrt(disc))/(2*A_); dc2 = (-B_ + np.sqrt(disc))/(2*A_)
d_c = max(dc1, dc2)   # the physical root with d*fu+gv>0
print(f"    critical ratio d_c = {d_c:.4f}  (Turing needs d > d_c; here d={d})")

# critical wavenumber at onset: k_c^2 = gamma (d fu + gv)/(2 d), the k^2 that
# minimizes det M(k^2) (an upward parabola in k^2).
k_c2 = gamma*(d*fu+gv)/(2*d)
print(f"    k_c^2 (min of det M) = {k_c2:.3f}")

# fastest-growing mode: maximize lambda_+(k^2) numerically on a fine grid.
# det M < 0 only between its two roots, which straddle k_c^2 and (for
# Det(J_kin) > 0) lie below 2 k_c^2, so 4 k_c^2 safely brackets the band.
# lambda_+ is evaluated on this grid by the vectorized closed form
# lambda_plus_vec; a per-point eigvals loop over 200k points is needlessly slow.
k2grid = np.linspace(1e-6, 4*k_c2, 200001)
# Vectorized closed form of lambda_plus: the larger root of
#   lambda^2 - Tr(M) lambda + det(M) = 0
# evaluated for a whole array of k^2 values at once.
def lambda_plus_vec(k2):
    a11 = gamma*fu - k2
    a22 = gamma*gv - d*k2
    trace = a11 + a22
    determinant = a11*a22 - gamma**2*fv*gu
    # "+ 0j" keeps the square root complex where the discriminant is negative
    # (complex-conjugate pair); the real part there is simply trace/2.
    root = np.sqrt(trace**2 - 4*determinant + 0j)
    return (trace + root).real/2
lam = lambda_plus_vec(k2grid)
i_star = np.argmax(lam)
k2_star = k2grid[i_star]; k_star = np.sqrt(k2_star)
lam_star = lam[i_star]
wavelength = 2*np.pi/k_star
# Spot-check the closed form against a direct eigenvalue solve of M(k^2) on a
# sparse subsample (every 2000th grid point) -- cheap, and catches algebra slips.
spot = slice(None, None, 2000)
spot_err = np.max(np.abs(lam[spot] - np.array([lambda_plus(x) for x in k2grid[spot]])))
# unstable band: where lambda_+ > 0 (empty when the parameters are outside
# Turing space, so guard before taking min/max)
pos = k2grid[lam > 0]
print(f"    fastest-growing k* = {k_star:.4f}  (k*^2={k2_star:.3f}), growth lambda* = {lam_star:.3f}")
if pos.size:
    k_minus = np.sqrt(pos.min()); k_plus = np.sqrt(pos.max())
    print(f"    unstable band: k in [{k_minus:.3f}, {k_plus:.3f}]")
else:
    k_minus = k_plus = np.nan
    print("    unstable band: none (lambda_+ <= 0 for every k: no Turing instability)")
print(f"    PREDICTED pattern wavelength  L* = 2 pi / k* = {wavelength:.5f}")
print(f"    (domain is [0,1]^2, so expect ~ {1/wavelength:.2f} wavelengths across)")
print(f"    closed-form lambda_+ vs np.linalg.eigvals (every 2000th k^2): max |diff| = {spot_err:.2e}")
# ---- (4) real 2D simulation -------------------------------------------------
def predict(a, b, d, gamma):
    """Linear-theory prediction of the fastest-growing Turing mode.

    lambda_+(k^2) is maximized on a fine k^2 grid spanning
    [1e-6, 4 * max(k_c^2, 1)], where k_c^2 = gamma (d f_u + g_v) / (2 d).

    Parameters
    ----------
    a, b  : Schnakenberg kinetic constants.
    d     : diffusion ratio D_v / D_u.
    gamma : reaction-rate / domain scale.

    Returns
    -------
    (k_star, wavelength, lam_star)
        k_star     -- wavenumber of the fastest-growing mode on the grid,
        wavelength -- 2*pi / k_star,
        lam_star   -- its growth rate lambda_+(k_star^2).
        If lam_star <= 0 no mode grows (no Turing instability for these
        parameters); k_star is then just the least-damped grid point and the
        wavelength has no physical meaning.
    """
    _, _, J = kinetics_jacobian(a, b)
    fu, fv, gu, gv = J[0,0], J[0,1], J[1,0], J[1,1]
    k_c2 = gamma*(d*fu+gv)/(2*d)
    # max(.., 1.0) keeps a usable grid when k_c^2 <= 0 (no Turing band).
    k2 = np.linspace(1e-6, 4*max(k_c2,1.0), 400001)
    A11 = gamma*fu - k2; A22 = gamma*gv - d*k2
    tr = A11 + A22; det = A11*A22 - gamma**2*fv*gu
    # Larger eigenvalue of the 2x2 growth matrix; +0j keeps sqrt complex when
    # the discriminant is negative, so the real part is then tr/2.
    lam = (tr + np.sqrt(tr**2 - 4*det + 0j)).real/2
    i = np.argmax(lam)
    ks = np.sqrt(k2[i])
    return ks, 2*np.pi/ks, lam[i]

def simulate(d_sim, gamma_sim, ppw=22, steps=None, seed=1, L=1.0):
    """Integrate the Schnakenberg PDE from a noisy uniform state.

    Forward-Euler time stepping with a 5-point finite-difference Laplacian on
    a periodic [0,L]^2 grid, resolved to ~ppw points per predicted
    wavelength. With steps=None it runs until lambda* t ~ 14 (saturation).
    The intrinsic wavelength 2pi/k* is set by (gamma,d) only; L>1 just fits
    more wavelengths (finer quantization), so it is the honest test of the
    prediction.

    Uses the module-level a, b and steady state (u_s, v_s).

    Returns (u, dx, N, steps): final activator field, grid spacing, grid
    size (N x N) and number of steps taken.
    """
    # The grid is sized from the linear prediction. For d_sim <= 1 there is no
    # Turing band, so the d = 40 prediction is borrowed just to pick a grid.
    _, wl, lam = predict(a, b, d_sim if d_sim>1 else 40.0, gamma_sim)
    dx = wl/ppw
    N = int(round(L/dx)); N += N % 2             # force an even grid size
    dx = L/N                                     # snap dx so N cells tile [0, L]
    dt = 0.8*dx*dx/(4*max(1.0, d_sim))           # explicit-diffusion stability
    if steps is None:                            # run ~ to lam*t = 14 (saturated)
        steps = int(14.0/max(lam,1.0)/dt) if lam>0 else int(0.05/dt)
    rng = np.random.default_rng(seed)
    u = u_s + 0.01*(rng.random((N,N))-0.5)
    v = v_s + 0.01*(rng.random((N,N))-0.5)
    inv = 1.0/dx**2

    def lap(z):
        # Periodic 5-point Laplacian via np.roll.
        return (np.roll(z,1,0)+np.roll(z,-1,0)+np.roll(z,1,1)+np.roll(z,-1,1)-4*z)*inv

    for _ in range(steps):
        # Forward Euler: both right-hand sides are evaluated from the SAME
        # time level before either field is updated.
        uuv = u*u*v
        du = gamma_sim*(a - u + uuv) + lap(u)
        dv = gamma_sim*(b - uuv) + d_sim*lap(v)
        u += dt*du
        v += dt*dv
    return u, dx, N, steps

def measure_rowcol(field, L):
    """Dominant wavelength from 1-D spectra along rows and columns (web-page method).

    Mirrors the browser measurement. It takes the 1-D power spectra of the
    mean-removed field along every row and every column and adds them into one
    spectrum. The strongest non-zero integer mode m (1 <= m <= N/2) gives the
    wavelength L/m. On L = 1 this is the quantized figure the page reports live.

    Returns (L/m, m). Assumes a square field: the spectrum length is taken
    from shape[0].
    """
    centred = field - field.mean()
    n = field.shape[0]
    spectrum = np.zeros(n)
    # fft_axis=1: transform each row, sum over rows;
    # fft_axis=0: transform each column, sum over columns.
    for fft_axis in (1, 0):
        spectrum += (np.abs(np.fft.fft(centred, axis=fft_axis))**2).sum(axis=1 - fft_axis)
    m = np.argmax(spectrum[1:n//2 + 1]) + 1
    return L/m, m

def measure_radial(field, dx):
    """Dominant wavelength of a 2-D field from its radially binned power spectrum.

    The 2-D power spectrum of the mean-removed field is summed over annuli of
    width dk = 2*pi/(N*dx), centred on integer multiples of dk. The peak
    annulus is then refined with a three-point parabolic fit, so the estimate
    is not limited to the bin spacing. This is used for the large-domain,
    quantization-free test.

    Parameters
    ----------
    field : 2-D array, assumed square (N x N); only shape[0] is read.
    dx    : grid spacing.

    Returns
    -------
    (wavelength, k_peak) with wavelength = 2*pi / k_peak.
    """
    f = field - field.mean()
    P = np.abs(np.fft.fft2(f))**2
    N = field.shape[0]
    k1d = np.fft.fftfreq(N, d=dx)*2*np.pi
    KX, KY = np.meshgrid(k1d, k1d)
    Kr = np.sqrt(KX**2+KY**2)
    dk = 2*np.pi/(N*dx)                       # spacing of the discrete wavenumbers
    kbin = np.arange(dk/2, Kr.max(), dk)      # annulus edges at (j + 1/2) dk
    # np.digitize: 0 -> DC cell (|k| < dk/2); i -> kbin[i-1] <= |k| < kbin[i];
    # len(kbin) -> corner cells beyond the last edge (not used).
    idx = np.digitize(Kr.ravel(), kbin)
    # Sum the power in every annulus in one O(N^2) pass (rather than one
    # boolean mask per annulus), keeping annuli 1 .. len(kbin)-1.
    radial = np.bincount(idx, weights=P.ravel(),
                         minlength=len(kbin)+1)[1:len(kbin)]
    kcent = 0.5*(kbin[:-1]+kbin[1:])          # radial[j] is centred on (j+1) dk
    i = np.argmax(radial)
    if 0 < i < len(radial)-1:                 # parabolic sub-bin peak
        y0, y1, y2 = radial[i-1], radial[i], radial[i+1]
        curv = y0 - 2*y1 + y2
        off = 0.5*(y0-y2)/curv if curv != 0 else 0
    else:
        off = 0                               # peak at an edge: no neighbours to fit
    kpeak = kcent[i] + off*dk
    return 2*np.pi/kpeak, kpeak

print("="*64)
print("(4) 2D SIMULATION  -> measure emergent wavelength, compare to prediction")
print("    (each run compares its OWN linear prediction to its OWN measurement;")
print("     wavelength scales with gamma, so the absolute number differs by row)")
print("  (4a) THE PAGE'S CONFIG (gamma=14300, d=40, 128^2), measured exactly as")
print("       the browser does (row+col FFT, integer mode m -> wavelength 1/m):")
ks, wl, lam = predict(a, b, 40.0, 14300.0)
uf, dx2, N2, st = simulate(40.0, 14300.0, L=1.0, ppw=16, steps=22000)
mwl, m = measure_rowcol(uf, 1.0)
print(f"       {st} steps · predicted 2pi/k* = {wl:.4f} ({1/wl:.2f} features)"
      f" · measured = 1/{m} = {mwl:.4f} · amplitude {uf.std():.3f}")
# Claim agreement only when the measured integer mode really equals the
# predicted feature count (1/wl rounded); otherwise report the discrepancy.
n_pred = int(round(1/wl))
if m == n_pred:
    print(f"       k*={ks:.3f}; the cell holds ~{n_pred} features and {n_pred} fits exactly, so the")
    print(f"       page shows a clean predicted=measured match ({m} features).")
else:
    print(f"       k*={ks:.3f}; MISMATCH: linear theory predicts ~{n_pred} features")
    print(f"       but the measured peak mode is m={m}.")
print("  (4b) PREDICTED ONSET feature-count vs d on the page (linear k*); the")
print("       saturated pattern tracks this within ~one feature (nonlinear")
print("       selection settles a touch shorter -> +0..1 feature):")
for dd in (12.0, 20.0, 40.0):
    _, w, _ = predict(a, b, dd, 14300.0)
    # int() so the count prints as an integer whatever round() returns for
    # a NumPy float.
    print(f"       d={dd:>4.0f}: predicted onset wl {w:.4f}  ->  ~{int(round(1.0/w))} features")
print("  (4c) CLEAN, quantization-free check (gamma=1000, d=40, large field L=8,")
print("       ~17 wavelengths so no rounding; sub-bin radial peak):")
ksb, wlb, lamb = predict(a, b, 40.0, 1000.0)
ufb, dxb, Nb, stb = simulate(40.0, 1000.0, L=8.0, ppw=14)
mwlb, kpb = measure_radial(ufb, dxb)
print(f"       grid {Nb}^2, {stb} steps · predicted {wlb:.4f} · measured {mwlb:.4f}"
      f" · err {abs(mwlb-wlb)/wlb*100:.1f}% · amplitude {ufb.std():.3f}")

# ---- (5) control: equal diffusion, no Turing instability --------------------
# With d = 1 the growth matrix is gamma*J - k^2 * I, so every mode's eigenvalues
# are those of the (stable) kinetics shifted down by k^2: diffusion can only damp.
print("="*64)
print("(5) CONTROL  d = 1 (equal diffusion): linear theory predicts NO pattern")
def lambda_plus_vec_d(k2, dd):
    """Vectorized lambda_+(k^2) for an arbitrary diffusion ratio dd.

    Same closed form as lambda_plus_vec, with dd in place of the module-level
    d. Uses the module-level gamma and Jacobian entries fu, fv, gu, gv.
    """
    diag_u = gamma*fu - k2
    diag_v = gamma*gv - dd*k2
    trace = diag_u + diag_v
    determinant = diag_u*diag_v - gamma**2*fv*gu
    discriminant = trace**2 - 4*determinant + 0j   # complex so sqrt never NaNs
    return (trace + np.sqrt(discriminant)).real/2
# Linear check: largest growth rate with d = 1 over section (3)'s k^2 grid.
# lambda_plus_vec_d reads the module-level gamma.
lam_eq = lambda_plus_vec_d(k2grid, 1.0)
print(f"    max growth rate over all k (d=1): {lam_eq.max():.4f}  (<=0 => stable)")
# Nonlinear check: simulate with d = 1 from simulate()'s small random
# perturbation; a std of u near zero means the perturbation decayed to flat.
# NOTE(review): this run uses gamma = 1000 while the linear check above uses the
# module-level gamma. With d = 1, Turing condition (i) reads fu + gv = Tr(J) < 0,
# so no gamma produces a band and the conclusion holds for both -- but the two
# lines do not describe literally the same system.
u_eq, _, _, _ = simulate(1.0, 1000.0, steps=40000)
print(f"    simulated pattern amplitude (std of u): {u_eq.std():.2e}  (~0 => flat)")
print("="*64)
print("SUMMARY: a uniform state, provably stable without diffusion, is")
print("destabilized BY diffusion into a pattern whose wavelength the linear")
print("theory predicts. Turn the diffusion difference off and the pattern dies.")
