#!/usr/bin/env catnip
# Détection de perte de végétation par analyse multi-temporelle
# Pipeline : NDVI(T1) vs NDVI(T2) → masque de perte → surface estimée → export
#
# Simule deux acquisitions Sentinel-2 à 5 ans d'intervalle
# avec déforestation partielle et expansion urbaine
#
# Ref: https://en.wikipedia.org/wiki/Change_detection_(GIS)
# Ref: https://en.wikipedia.org/wiki/Normalized_difference_vegetation_index
#
# DEPS: matplotlib numpy rasterio

rasterio = import('rasterio')
numpy = import('numpy')
rio_transform = import('rasterio.transform')
tempfile = import('tempfile')
os = import('os')

# ── Structures ─────────────────────────────────────────────────────────

# Container for one raster band plus its georeferencing: the pixel array,
# its CRS and its affine transform.
# NOTE(review): not referenced anywhere in this script — candidate for removal
# unless other scripts rely on it.
struct BandData {
    array; crs; transform;
}

# One severity class for NDVI loss.
#   label      : display name of the class
#   min_thresh : positive magnitude; pixels need delta < -min_thresh
#   max_thresh : positive magnitude; pixels need delta >= -max_thresh
# A max_thresh of 999.0 or more is used as an "unbounded" sentinel by the
# severity classification loop.
struct Severity {
    label; min_thresh; max_thresh;
}

# ── Outils ──────────────────────────────────────────────────────────────

# Normalized difference (a - b) / (a + b), evaluated element-wise with numpy.
# The denominator is floored at 1e-10 so a zero sum never divides by zero.
# NOTE: the floor also replaces negative sums by 1e-10; harmless for the
# non-negative reflectances used in this script.
normalized_diff = (a, b) => {
    numerator = numpy.subtract(a, b)
    denominator = numpy.maximum(numpy.add(a, b), 1e-10)
    numpy.divide(numerator, denominator)
}

# Build one synthetic spectral band made of 4 reflectance quadrants:
# forest (top-left), water (top-right), urban (bottom-left), bare soil
# (bottom-right). Each quadrant is drawn uniformly in [v - noise, v + noise]
# with numpy's global RNG and cast to float32.
# The result always has shape (rows, cols): when a dimension is odd, the
# extra row/column goes to the bottom/right quadrants.
make_band = (rows, cols, v_forest, v_water, v_urban, v_bare, noise) => {
    half_r = int(rows / 2)
    half_c = int(cols / 2)
    # Remaining extent; equals half_r / half_c for even sizes, so even-sized
    # bands (and the RNG draw sequence) are unchanged.
    rest_r = rows - half_r
    rest_c = cols - half_c
    f = numpy.random.uniform(v_forest - noise, v_forest + noise, tuple(half_r, half_c)).astype(numpy.float32)
    w = numpy.random.uniform(v_water - noise, v_water + noise, tuple(half_r, rest_c)).astype(numpy.float32)
    u = numpy.random.uniform(v_urban - noise, v_urban + noise, tuple(rest_r, half_c)).astype(numpy.float32)
    b = numpy.random.uniform(v_bare - noise, v_bare + noise, tuple(rest_r, rest_c)).astype(numpy.float32)
    numpy.block(list(list(f, w), list(u, b)))
}

# Write 4 bands to a new temporary GeoTIFF and return its path.
# Band order in the file: 1 = blue, 2 = green, 3 = red, 4 = NIR.
# The file is float32, EPSG:4326, georeferenced with the affine transform t.
# The caller is responsible for deleting the file.
write_geotiff = (blue, green, red, nir, rows, cols, t) => {
    # mkstemp atomically creates the file (mktemp only returns a name, which
    # is deprecated and race-prone); close our descriptor so rasterio can
    # reopen the path for writing.
    tmp = tempfile.mkstemp(suffix=".tif")
    os.close(tmp[0])
    path = tmp[1]
    with dst = rasterio.open(path, 'w', driver='GTiff', height=rows, width=cols, count=4, dtype='float32',
        crs='EPSG:4326', transform=t) {
        dst.write(blue, 1)
        dst.write(green, 2)
        dst.write(red, 3)
        dst.write(nir, 4)
    }
    path
}

# ── Generation of the two acquisitions ──────────────────────────────────
# NOTE: statement order matters here — each seed fixes the sequence of
# random draws consumed by the make_band calls that follow it.

print("⇒ Génération de deux acquisitions Sentinel-2")

rows = 256
cols = 256
# Affine transform for a rows x cols grid over lon 2.25–2.42, lat 48.80–48.90
# (from_bounds argument order: west, south, east, north, width, height)
t = rio_transform.from_bounds(2.25, 48.80, 2.42, 48.90, cols, rows)

# T1: dense forest (NIR 0.45 vs red 0.04), water, moderate urban, bare soil
numpy.random.seed(42)

blue_t1 = make_band(rows, cols, 0.05, 0.06, 0.12, 0.10, 0.02)
green_t1 = make_band(rows, cols, 0.08, 0.05, 0.11, 0.12, 0.02)
red_t1 = make_band(rows, cols, 0.04, 0.03, 0.14, 0.15, 0.02)
nir_t1 = make_band(rows, cols, 0.45, 0.03, 0.16, 0.25, 0.02)

# T2 (5 years later): partial deforestation.
# Forest quadrant: NIR drops (0.45 -> 0.22) and red rises (0.04 -> 0.12),
# which lowers its NDVI.
# NOTE(review): the urban quadrant keeps the same extent; only its
# reflectances shift slightly, so no spatial "urban expansion" is simulated.
# A different seed makes T2 noise independent of T1 noise.
numpy.random.seed(99)

blue_t2 = make_band(rows, cols, 0.08, 0.06, 0.13, 0.11, 0.02)
green_t2 = make_band(rows, cols, 0.09, 0.05, 0.12, 0.12, 0.02)
red_t2 = make_band(rows, cols, 0.12, 0.03, 0.15, 0.16, 0.02)
nir_t2 = make_band(rows, cols, 0.22, 0.03, 0.17, 0.23, 0.02)

path_t1 = write_geotiff(blue_t1, green_t1, red_t1, nir_t1, rows, cols, t)
path_t2 = write_geotiff(blue_t2, green_t2, red_t2, nir_t2, rows, cols, t)

print(f"  T1 : {path_t1}")
print(f"  T2 : {path_t2}")

# ── NDVI at both dates ──────────────────────────────────────────────────

print()
print("⇒ NDVI par date")

# Kept open until the cleanup section: res and transform are read from src1 later
src1 = rasterio.open(path_t1)
src2 = rasterio.open(path_t2)

# NDVI = (NIR - red) / (NIR + red); write_geotiff stores red in band 3, NIR in band 4
ndvi_t1 = normalized_diff(src1.read(4), src1.read(3))
ndvi_t2 = normalized_diff(src2.read(4), src2.read(3))

# Format "mean = M [min, max]", each statistic rounded to 3 decimals
ndvi_stats = (arr) => {
    avg = round(float(numpy.mean(arr)), 3)
    lo = round(float(numpy.min(arr)), 3)
    hi = round(float(numpy.max(arr)), 3)
    f"mean = {avg} [{lo}, {hi}]"
}

print(f"  T1 : {ndvi_stats(ndvi_t1)}")
print(f"  T2 : {ndvi_stats(ndvi_t2)}")

# ── Loss detection ──────────────────────────────────────────────────────

print()
print("⇒ Détection de perte de végétation")

# Temporal change: later minus earlier, so negative values mean NDVI decreased
delta = numpy.subtract(ndvi_t2, ndvi_t1)
delta_mean = round(float(numpy.mean(delta)), 3)
delta_min = round(float(numpy.min(delta)), 3)
delta_max = round(float(numpy.max(delta)), 3)
print(f"  Delta NDVI : mean = {delta_mean} [{delta_min}, {delta_max}]")

# A pixel counts as "loss" when its NDVI dropped by more than this amount
threshold = -0.15
loss_mask = numpy.less(delta, threshold)
loss_pixels = int(numpy.sum(loss_mask))
total_px = rows * cols
loss_share = round(loss_pixels * 100 / total_px, 1)

print(f"  Seuil : {threshold}")
print(f"  Pixels en perte : {loss_pixels} / {total_px} ({loss_share}%)")

# ── Lost-area estimate ──────────────────────────────────────────────────

print()
print("⇒ Estimation de surface perdue")

# Pixel size in degrees (the rasters are written in EPSG:4326)
res_x = src1.res[0]
res_y = src1.res[1]

# 1° of latitude ≈ 111 320 m; 1° of longitude ≈ 111 320 × cos(lat) m.
# The latitude is taken at the raster's own centre (rather than a hard-coded
# value) so the estimate follows the actual acquisition bounds.
lat_center = (src1.bounds.bottom + src1.bounds.top) / 2
math = import('math')
m_per_deg_lat = 111320.0
m_per_deg_lon = 111320.0 * math.cos(math.radians(lat_center))

pixel_width_m = res_x * m_per_deg_lon
pixel_height_m = res_y * m_per_deg_lat
pixel_area_m2 = pixel_width_m * pixel_height_m
pixel_area_ha = pixel_area_m2 / 10000

print(f"  Pixel : {round(pixel_width_m, 1)} x {round(pixel_height_m, 1)} m")
print(f"  Aire pixel : {round(pixel_area_m2, 1)} m2 ({round(pixel_area_ha, 4)} ha)")

# Round only for display: km² is derived from the unrounded hectares to
# avoid compounding two roundings.
loss_ha_raw = loss_pixels * pixel_area_ha
loss_ha = round(loss_ha_raw, 2)
loss_km2 = round(loss_ha_raw / 100, 4)

print(f"  Surface perdue : {loss_ha} ha ({loss_km2} km2)")

# ── Loss intensity ──────────────────────────────────────────────────────

print()
print("⇒ Intensité de la perte")

# Mean and strongest (most negative) NDVI change over the loss pixels.
# Guard the empty case: numpy.min on an empty selection raises ValueError and
# numpy.mean returns NaN with a warning, so report 0.0 when nothing was lost.
loss_values = delta[loss_mask]
mean_drop = if loss_pixels > 0 {
    round(float(numpy.mean(loss_values)), 3)
} else {
    0.0
}
max_drop = if loss_pixels > 0 {
    round(float(numpy.min(loss_values)), 3)
} else {
    0.0
}

print(f"  Drop moyen : {mean_drop}")
print(f"  Drop max   : {max_drop}")

# Breakdown by severity. Thresholds are positive magnitudes of the drop:
# a pixel belongs to class s when -max_thresh <= delta < -min_thresh.
# The first class starts at the loss threshold so the classes exactly
# partition loss_mask even if the threshold is changed.
severities = list(
    Severity("Modérée", -threshold, 0.30),
    Severity("Sévère", 0.30, 0.50),
    Severity("Critique", 0.50, 999.0),
)

severities.[(s) => {
    # max_thresh >= 999.0 is the "no upper bound" sentinel
    mask = if s.max_thresh >= 999.0 {
        numpy.less(delta, -s.min_thresh)
    } else {
        numpy.logical_and(numpy.less(delta, -s.min_thresh), numpy.greater_equal(delta, -s.max_thresh))
    }
    print(f"  {s.label} : {int(numpy.sum(mask))} px")
}]

# ── Export ──────────────────────────────────────────────────────────────

# Binary loss mask as a GeoTIFF (0 = stable, 1 = loss), georeferenced like
# the input rasters. mkstemp creates the file atomically (mktemp is
# deprecated and race-prone); the descriptor is closed so rasterio can reopen it.
tmp_mask = tempfile.mkstemp(suffix="_loss_mask.tif")
os.close(tmp_mask[0])
output_mask = tmp_mask[1]

mask_int = loss_mask.astype(numpy.uint8)
with dst = rasterio.open(output_mask, 'w', driver='GTiff', height=rows, width=cols, count=1, dtype='uint8',
    crs=src1.crs, transform=src1.transform) {
    dst.write(mask_int, 1)
}

# Continuous NDVI change as a float32 GeoTIFF
tmp_delta = tempfile.mkstemp(suffix="_delta_ndvi.tif")
os.close(tmp_delta[0])
output_delta = tmp_delta[1]

with dst = rasterio.open(output_delta, 'w', driver='GTiff', height=rows, width=cols, count=1, dtype='float32',
    crs=src1.crs, transform=src1.transform) {
    dst.write(delta.astype(numpy.float32), 1)
}

print()
print("⇒ Export GeoTIFF")
print(f"  Masque binaire : {output_mask}")
print(f"  Delta NDVI     : {output_delta}")

# ── matplotlib visualization (optional) ─────────────────────────────────
# Runs only when matplotlib is importable (detected below); otherwise the
# section is skipped with a message. Install with: uv pip install matplotlib

importlib_util = import('importlib.util')
has_matplotlib = importlib_util.find_spec('matplotlib') != None

if has_matplotlib {
    print()
    print("⇒ Visualisation")

    plt = import('matplotlib.pyplot')
    colors = import('matplotlib.colors')

    fig = plt.figure(figsize=tuple(14, 5))

    # NDVI at T1
    ax1 = fig.add_subplot(1, 3, 1)
    im1 = ax1.imshow(ndvi_t1, cmap="RdYlGn", vmin=-0.5, vmax=1.0)
    ax1.set_title("NDVI T1")
    plt.colorbar(im1, ax=ax1, shrink=0.8)

    # NDVI at T2 (same color scale as T1 so the panels are comparable)
    ax2 = fig.add_subplot(1, 3, 2)
    im2 = ax2.imshow(ndvi_t2, cmap="RdYlGn", vmin=-0.5, vmax=1.0)
    ax2.set_title("NDVI T2")
    plt.colorbar(im2, ax=ax2, shrink=0.8)

    # Loss mask: pin the value range to [0, 1] so 0 is always gray and 1 is
    # always red. Without it, a uniform mask (all loss or no loss) collapses
    # to a single value and is always drawn with the first ("stable") color.
    ax3 = fig.add_subplot(1, 3, 3)
    cmap = colors.ListedColormap(list("lightgray", "red"))
    ax3.imshow(mask_int, cmap=cmap, vmin=0, vmax=1, interpolation="nearest")
    ax3.set_title("Perte de végétation")

    plt.tight_layout()
    # The figure is kept after the run (it is not removed in the cleanup step)
    tmp_plot = tempfile.mkstemp(suffix="_change_detection.png")
    os.close(tmp_plot[0])
    plot_path = tmp_plot[1]
    plt.savefig(plot_path, dpi=150)
    plt.close()
    print(f"  Figure sauvegardée : {plot_path}")
} else {
    print("\n  matplotlib non disponible, visualisation ignorée")
    print("  (uv pip install matplotlib pour activer)")
}

# ── Cleanup ─────────────────────────────────────────────────────────────
# Close the raster handles before deleting the files they were opened from,
# then delete the intermediate inputs and the exported GeoTIFFs.
# NOTE: the PNG figure (if produced) is intentionally not removed here.

src1.close()
src2.close()
os.remove(path_t1)
os.remove(path_t2)
os.remove(output_mask)
os.remove(output_delta)
print("\n  Fichiers temporaires nettoyés")