# codex/geospatial/sentinel2_deforestation.cat
# Détection de déforestation Sentinel-2 via STAC
# Pipeline : recherche STAC → lecture COG → NDVI → perte de végétation → surface
#
# Données réelles Sentinel-2 L2A (10m) via le catalogue STAC earth-search (AWS)
# Lecture partielle de Cloud-Optimized GeoTIFF (seule la zone d'intérêt est téléchargée)
#
# Nécessite un accès internet
# Si aucune scène n'est trouvée : ajuster les dates ou augmenter max_cloud
#
# Ref: https://en.wikipedia.org/wiki/Normalized_difference_vegetation_index
# Ref: https://stacspec.org/
#
# Installation :
#   uv pip install rasterio numpy pystac-client shapely pyproj matplotlib
#
# Exécuter :
#   catnip docs/codex/geospatial/sentinel2_deforestation.cat

rasterio = import("rasterio")
numpy = import("numpy")
rio_windows = import("rasterio.windows")
stac_client = import("pystac_client")
shapely_geo = import("shapely.geometry")
pyproj_mod = import("pyproj")
importlib_util = import("importlib.util")
tempfile = import("tempfile")
os = import("os")
sys = import("sys")
rio_warp = import("rasterio.warp")

# ── Configuration ───────────────────────────────────────────────────────

# Banner for the configuration section.
print("⇒ Configuration")

# STAC endpoint and collection queried for scenes.
stac_url = "https://earth-search.aws.element84.com/v1"
collection = "sentinel-2-l2a"
# Asset keys used to look up each band in item.assets.
band_red = "red"  # B04, 10m
band_nir = "nir"  # B08, 10m

# Demo AOI (fixed): bounding box given as (min x, min y, max x, max y);
# read_band interprets these bounds as EPSG:4326 lon/lat coordinates.
aoi_name = "Para"
aoi = shapely_geo.box(-55.8, -7.9, -55.6, -7.7)

# Date ranges ("start/end") for the reference scene (T1) and the comparison
# scene (T2).
date_before = "2024-07-01/2024-09-30"
date_after = "2025-07-01/2025-09-30"
# Scenes are kept only when eo:cloud_cover is strictly below this value.
max_cloud = 20
# A pixel counts as vegetation loss when NDVI(T2) - NDVI(T1) is below this.
threshold = -0.2
# When True, the exported GeoTIFFs are deleted at the end of the run.
cleanup_outputs = False

print("  T1  :", date_before, " T2 :", date_after)
print("  AOI :", aoi_name, aoi.bounds)

# ── Fonctions ───────────────────────────────────────────────────────────

# Normalized difference (a - b) / (a + b), computed element-wise.
# The denominator is floored at 1e-10, so a zero (or negative) sum never
# causes a division by zero.
normalized_diff = (a, b) => {
    diff_ab = numpy.subtract(a, b)
    sum_ab = numpy.maximum(numpy.add(a, b), 1e-10)
    numpy.divide(diff_ab, sum_ab)
}

# Keep the items whose "eo:cloud_cover" property is strictly below max_cc.
# An item without that property is treated as 100 % cloudy.
# Returns a new list; the input is not modified.
filter_by_cloud = (items, max_cc) => {
    selected = list()
    for scene in items {
        cover = scene.properties.get("eo:cloud_cover", 100)
        if cover < max_cc {
            selected.append(scene)
        }
    }
    selected
}

# Return the item with the lowest "eo:cloud_cover" (linear scan) and print
# the choice; items lacking the property count as 100 %.
# On ties the earliest item wins. For an empty list a notice is printed
# and None is returned.
pick_clearest = (items) => {
    if len(items) != 0 {
        clearest = items[0]
        lowest_cc = clearest.properties.get("eo:cloud_cover", 100)
        for candidate in items {
            candidate_cc = candidate.properties.get("eo:cloud_cover", 100)
            if candidate_cc < lowest_cc {
                clearest = candidate
                lowest_cc = candidate_cc
            }
        }
        print("    -> ", clearest.id, " (cloud:", round(lowest_cc, 1), "%)")
        clearest
    } else {
        print("    -> aucune scène disponible après filtre nuage")
        None
    }
}

# Create a temporary file safely (mkstemp creates it, avoiding a name race)
# and return its path. The descriptor returned by mkstemp is closed right
# away; callers reopen the file by path.
safe_temp_path = (suffix, prefix) => {
    handle_and_path = tempfile.mkstemp(suffix=suffix, prefix=prefix)
    os.close(handle_and_path[0])
    handle_and_path[1]
}

# Resample a source raster onto a reference grid (CRS + transform + shape).
# NaN cells of the source are replaced by a -9999 sentinel that is declared
# as the source nodata; the destination starts all-NaN with NaN declared as
# its nodata. Resampling is bilinear.
# Returns the float32 array laid out on the reference grid.
align_to_ref_grid = (src_arr, src_transform, src_crs, ref_shape, ref_transform, ref_crs) => {
    sentinel = -9999.0
    out_grid = numpy.full(ref_shape, numpy.nan, dtype=numpy.float32)
    filled_src = numpy.where(numpy.isnan(src_arr), sentinel, src_arr).astype(numpy.float32)
    rio_warp.reproject(
        source=filled_src,
        destination=out_grid,
        src_crs=src_crs,
        src_transform=src_transform,
        src_nodata=sentinel,
        dst_crs=ref_crs,
        dst_transform=ref_transform,
        dst_nodata=numpy.nan,
        resampling=rio_warp.Resampling.bilinear,
    )
    out_grid
}

# Read one band of a STAC item through a spatial window, so that only the
# pixels covering the AOI are fetched from the COG.
#
# item     : STAC item exposing assets[band_key].href
# band_key : asset key of the band to read
# aoi_geom : geometry whose .bounds are lon/lat (EPSG:4326) coordinates
#
# Returns dict(array=..., crs=..., transform=...): the float32 pixels with
# invalid cells set to NaN, the raster CRS, and the affine transform of the
# window that was read.
read_band = (item, band_key, aoi_geom) => {
    href = item.assets[band_key].href
    ds = rasterio.open(href)

    # Project the AOI bounding box (lon/lat) into the raster CRS.
    # transform_bounds densifies the box edges before projecting, so the
    # result encloses the whole AOI; projecting only the lower-left and
    # upper-right corners clips the edges wherever the projected box is skewed.
    lonlat = aoi_geom.bounds
    proj_bounds = rio_warp.transform_bounds("EPSG:4326", ds.crs, lonlat[0], lonlat[1], lonlat[2], lonlat[3])

    win = rio_windows.from_bounds(proj_bounds[0], proj_bounds[1], proj_bounds[2], proj_bounds[3], transform=ds.transform)
    arr = ds.read(1, window=win).astype(numpy.float32)
    valid_mask = ds.read_masks(1, window=win)

    # Capture the metadata needed by callers, then release the dataset.
    nodata = ds.nodata
    raster_crs = ds.crs
    win_t = rio_windows.transform(win, ds.transform)
    ds.close()

    # Nodata → NaN
    if nodata != None {
        arr = numpy.where(numpy.equal(arr, nodata), numpy.nan, arr)
    }
    # Internal raster mask (0 = invalid) → NaN
    arr = numpy.where(numpy.equal(valid_mask, 0), numpy.nan, arr)

    dict(array=arr, crs=raster_crs, transform=win_t)
}

# Print min, max and mean of an array on one line, ignoring NaN values.
# Each statistic is rounded to 3 decimals.
show_stats = (label, arr) => {
    lo = round(float(numpy.nanmin(arr)), 3)
    hi = round(float(numpy.nanmax(arr)), 3)
    avg = round(float(numpy.nanmean(arr)), 3)
    print("    ", label, ": [", lo, ",", hi, "]  mean =", avg)
}

# ── Recherche STAC ──────────────────────────────────────────────────────

print("\n⇒ Recherche STAC (" + stac_url + ")")

# Open the STAC API client once; search_and_filter reuses this catalog for
# both date ranges.
catalog = stac_client.Client.open(stac_url)

# Search the catalog for scenes of `collection` that intersect the AOI within
# date_range (at most 50 items), print the raw and filtered counts under the
# given label, and return the scenes that pass the cloud-cover filter.
search_and_filter = (label, date_range) => {
    print("  " + label + " : " + date_range)
    query = catalog.search(
        collections=list(collection),
        intersects=aoi,
        datetime=date_range,
        max_items=50,
    )
    found = list()
    for scene in query.items() {
        found.append(scene)
    }
    print("    Scènes brutes :", len(found))
    cloud_free = filter_by_cloud(found, max_cloud)
    print("    Cloud <", max_cloud, "% :", len(cloud_free))
    cloud_free
}

# Candidate scenes for each period, already filtered by cloud cover.
items1 = search_and_filter("T1", date_before)
items2 = search_and_filter("T2", date_after)

# Least cloudy scene of each period (None when no candidate remains).
best1 = pick_clearest(items1)
best2 = pick_clearest(items2)

# Both dates are required for a comparison: otherwise report the problem
# and stop with exit status 1.
if best1 == None or best2 == None {
    print("\n✗ Impossible de continuer: aucune scène exploitable pour T1 ou T2.")
    print("  Actions: élargir les dates et/ou augmenter max_cloud.")
    sys.exit(1)
}

# ── Lecture des bandes ──────────────────────────────────────────────────

print("\n⇒ Lecture COG (fenêtre partielle sur AOI)")

# Each read_band call returns dict(array, crs, transform) for one band,
# restricted to the AOI window.
print("  T1 :")
red1 = read_band(best1, band_red, aoi)
nir1 = read_band(best1, band_nir, aoi)
# Only the red band's shape is reported.
print("    ", red1["array"].shape, "pixels chargés")

print("  T2 :")
red2 = read_band(best2, band_red, aoi)
nir2 = read_band(best2, band_nir, aoi)
print("    ", red2["array"].shape, "pixels chargés")

# ── NDVI ────────────────────────────────────────────────────────────────

print("\n⇒ Calcul NDVI")

# Sentinel-2 DNs are reflectance ×10000, but NDVI is scale-invariant, so the
# raw values are used directly.
# NOTE(review): this holds for a multiplicative scale only; an additive
# offset in the DNs, if the assets carry one, would not cancel out — confirm.
ndvi1 = normalized_diff(nir1["array"], red1["array"])
ndvi2 = normalized_diff(nir2["array"], red2["array"])

show_stats("NDVI T1", ndvi1)
show_stats("NDVI T2", ndvi2)

# ── Détection de perte ──────────────────────────────────────────────────

print("\n⇒ Détection de perte")

# T2 must be resampled whenever its grid differs from T1's in shape, CRS or
# affine transform (the red-band metadata stands for each date's grid).
shape_differs = (ndvi2.shape != ndvi1.shape)
crs_differs = (red2["crs"] != red1["crs"])
transform_differs = (red2["transform"] != red1["transform"])
need_align = (shape_differs or crs_differs or transform_differs)

if need_align {
    print("  Alignement T2 -> grille T1 (reprojection/rééchantillonnage)")
    ndvi2_aligned = align_to_ref_grid(
        ndvi2,
        red2["transform"],
        red2["crs"],
        ndvi1.shape,
        red1["transform"],
        red1["crs"]
    )
} else {
    ndvi2_aligned = ndvi2
}

# NDVI change, T2 minus T1: negative values mean a drop in vegetation.
delta = numpy.subtract(ndvi2_aligned, ndvi1)
show_stats("Delta ", delta)

# NaN compares as False, so invalid pixels never count as loss.
loss = numpy.less(delta, threshold)
loss_px = int(numpy.nansum(loss))
has_value = numpy.logical_not(numpy.isnan(delta))
valid_px = int(numpy.sum(has_value))
# max(..., 1) keeps the percentage defined when no pixel is valid.
loss_pct = round(loss_px * 100 / max(valid_px, 1), 1)

print("  Seuil :", threshold)
print("  Perte :", loss_px, "/", valid_px, "valides (", loss_pct, "%)")

# ── Surface ─────────────────────────────────────────────────────────────

print("\n⇒ Surface perdue")

# Pixel size is taken from the T1 red-band transform (a = x step, e = y step).
# NOTE(review): the hectare conversion assumes the CRS units are metres
# (UTM, 10 m pixels for the 10 m bands) — confirm for other inputs.
grid = red1["transform"]
pixel_w = abs(grid.a)
pixel_h = abs(grid.e)
pixel_area_ha = (pixel_w * pixel_h) / 10000

print("  Résolution :", pixel_w, "x", pixel_h, "m")

# 1 km2 = 100 ha.
loss_ha = round(loss_px * pixel_area_ha, 2)
loss_km2 = round(loss_ha / 100, 4)
print("  Perte estimée :", loss_ha, "ha (", loss_km2, "km2 )")

# ── Sévérité ────────────────────────────────────────────────────────────

print("\n⇒ Sévérité")

# NDVI-drop classes with fixed bounds: [-0.35, -0.2), [-0.5, -0.35) and
# below -0.5. The bounds are literals here and do not follow `threshold`.
# NaN pixels fail every comparison and therefore fall in no class.
below_020 = numpy.less(delta, -0.2)
below_035 = numpy.less(delta, -0.35)
critical = numpy.less(delta, -0.5)
moderate = numpy.logical_and(below_020, numpy.greater_equal(delta, -0.35))
severe = numpy.logical_and(below_035, numpy.greater_equal(delta, -0.5))

moderate_px = int(numpy.nansum(moderate))
severe_px = int(numpy.nansum(severe))
critical_px = int(numpy.nansum(critical))

print("  Modérée  (0.20-0.35) :", moderate_px, "px")
print("  Sévère   (0.35-0.50) :", severe_px, "px")
print("  Critique (> 0.50)    :", critical_px, "px")

# Mean NDVI change over the loss pixels only; skipped when there are none.
if loss_px > 0 {
    mean_drop = round(float(numpy.nanmean(delta[loss])), 3)
    print("  Drop moyen :", mean_drop)
}

# ── Export ──────────────────────────────────────────────────────────────

# Both rasters are written on the T1 grid (CRS + transform of the T1 red band).
rows = delta.shape[0]
cols = delta.shape[1]
loss_uint8 = loss.astype(numpy.uint8)

# Binary mask: 1 = loss, 0 = no loss or no valid data (the two are not
# distinguished in this raster).
out_mask = safe_temp_path("_loss_mask.tif", "catnip_")
dst = rasterio.open(out_mask, "w", driver="GTiff", height=rows, width=cols, count=1, dtype="uint8",
    crs=red1["crs"], transform=red1["transform"])
dst.write(loss_uint8, 1)
dst.close()

# Delta NDVI (continuous values). NaN is declared as the nodata value so that
# readers mask the pixels that were invalid at either date instead of
# treating them as data.
out_delta = safe_temp_path("_delta_ndvi.tif", "catnip_")
dst = rasterio.open(out_delta, "w", driver="GTiff", height=rows, width=cols, count=1, dtype="float32",
    crs=red1["crs"], transform=red1["transform"], nodata=numpy.nan)
dst.write(delta.astype(numpy.float32), 1)
dst.close()

print("\n⇒ Export GeoTIFF")
print("  Masque :", out_mask)
print("  Delta  :", out_delta)

# ── Visualisation (matplotlib optionnel) ────────────────────────────────

# Optional 4-panel figure (NDVI T1, NDVI T2, delta, loss mask), produced only
# when matplotlib is installed. The figure is saved to a temporary PNG.
if importlib_util.find_spec("matplotlib") != None {
    print("\n⇒ Visualisation")

    plt = import("matplotlib.pyplot")
    mcolors = import("matplotlib.colors")

    fig = plt.figure(figsize=tuple(16, 4))

    ax1 = fig.add_subplot(1, 4, 1)
    im1 = ax1.imshow(ndvi1, cmap="RdYlGn", vmin=-0.2, vmax=0.9)
    ax1.set_title("NDVI T1")
    plt.colorbar(im1, ax=ax1, shrink=0.8)

    # Plot T2 on the T1 grid (ndvi2_aligned) so that every panel covers the
    # same pixels as the delta and loss panels; it equals ndvi2 when no
    # alignment was needed.
    ax2 = fig.add_subplot(1, 4, 2)
    im2 = ax2.imshow(ndvi2_aligned, cmap="RdYlGn", vmin=-0.2, vmax=0.9)
    ax2.set_title("NDVI T2")
    plt.colorbar(im2, ax=ax2, shrink=0.8)

    ax3 = fig.add_subplot(1, 4, 3)
    im3 = ax3.imshow(delta, cmap="RdBu", vmin=-0.5, vmax=0.5)
    ax3.set_title("Delta NDVI")
    plt.colorbar(im3, ax=ax3, shrink=0.8)

    # Two-colour map for the binary mask: 0 → light gray, 1 → red.
    ax4 = fig.add_subplot(1, 4, 4)
    ax4.imshow(loss_uint8, cmap=mcolors.ListedColormap(list("lightgray", "red")), interpolation="nearest")
    ax4.set_title("Perte détectée")

    plt.tight_layout()
    plot_path = safe_temp_path("_deforestation.png", "catnip_")
    plt.savefig(plot_path, dpi=150)
    plt.close()
    print("  Figure :", plot_path)
} else {
    print("\n  matplotlib absent, visualisation ignorée")
}

# TODO: cloud/shadow mask (SCL band, 20m → resample)
# TODO: cross-check against PRODES (rasterize reference polygons)
# TODO: supervised metrics (IoU, precision, recall vs PRODES)

# Optionally delete the two exported GeoTIFFs. The PNG figure, when one was
# written, is not removed here.
if cleanup_outputs {
    os.remove(out_mask)
    os.remove(out_delta)
    print("\n  Fichiers temporaires nettoyés")
} else {
    print("\n  Fichiers conservés (cleanup_outputs = False)")
}