#!/usr/bin/env catnip
# Détection de déforestation Sentinel-2 via STAC
# Pipeline : recherche STAC → lecture COG → NDVI → perte de végétation → surface
#
# Données réelles Sentinel-2 L2A (10m) via le catalogue STAC earth-search (AWS)
# Lecture partielle de Cloud-Optimized GeoTIFF (seule la zone d'intérêt est téléchargée)
#
# Nécessite un accès internet
# Si aucune scène n'est trouvée : ajuster les dates ou augmenter max_cloud
#
# Ref: https://en.wikipedia.org/wiki/Normalized_difference_vegetation_index
# Ref: https://stacspec.org/
#
# DEPS: pyproj pystac_client rasterio shapely

warnings = import('warnings')
import('builtins', 'RuntimeWarning')
warnings.filterwarnings("ignore", message="invalid value", category=RuntimeWarning)

rasterio = import('rasterio')
numpy = import('numpy')
rio_windows = import('rasterio.windows')
stac_client = import('pystac_client')
shapely_geo = import('shapely.geometry')
pyproj_mod = import('pyproj')
tempfile = import('tempfile')
os = import('os')
sys = import('sys')
rio_warp = import('rasterio.warp')

# ── Structures ─────────────────────────────────────────────────────────

# Window of one raster band as returned by read_band.
#   array     : float32 pixel values; nodata and internally masked pixels are NaN
#   crs       : CRS of the source raster
#   transform : affine transform of the window that was read (not of the full raster)
struct BandData {
    array; crs; transform;
}

# NDVI-loss severity class, bounds expressed as positive drop magnitudes.
#   label      : display name
#   min_thresh : a pixel belongs to the class when delta < -min_thresh ...
#   max_thresh : ... and delta >= -max_thresh; 999.0 is used as "no upper bound"
struct Severity {
    label; min_thresh; max_thresh;
}

# ── Configuration ───────────────────────────────────────────────────────

print("⇒ Configuration")

# STAC API endpoint (Earth Search on AWS) and the collection queried.
stac_url = "https://earth-search.aws.element84.com/v1"
collection = "sentinel-2-l2a"
# Asset keys of the two bands read from each STAC item.
band_red = "red"  # B04, 10m
band_nir = "nir"  # B08, 10m

# Demo area of interest (fixed): shapely box(minx, miny, maxx, maxy) in
# lon/lat degrees (treated as EPSG:4326 when reprojected onto the raster).
aoi_name = "Para"
aoi = shapely_geo.box(-55.8, -7.9, -55.6, -7.7)

# STAC datetime intervals ("start/end") for the reference (T1) and comparison (T2) periods.
date_before = "2024-07-01/2024-09-30"
date_after = "2025-07-01/2025-09-30"
# Scenes are kept only if their eo:cloud_cover (percent) is strictly below this value.
max_cloud = 20
# NDVI change (T2 - T1) below this value is counted as vegetation loss.
threshold = -0.2
# When True, the exported output files are deleted at the end of the run.
cleanup_outputs = False

print(f"  T1  : {date_before}  T2 : {date_after}")
print(f"  AOI : {aoi_name} {aoi.bounds}")

# ── Fonctions ───────────────────────────────────────────────────────────

# Element-wise normalized difference (A - B) / (A + B).
# Pixels whose denominator is zero, near zero (|A + B| <= 1e-10) or NaN yield NaN.
# Clamping the denominator with maximum(A + B, eps) is not used: it would turn
# any negative sum into a huge, sign-flipped ratio instead of flagging it.
normalized_diff = (a, b) => {
    num = numpy.subtract(a, b)
    den = numpy.add(a, b)
    usable = numpy.greater(numpy.abs(den), 1e-10)
    # Divide by 1.0 where the denominator is unusable (avoids divide-by-zero
    # warnings), then discard those pixels as NaN.
    numpy.where(usable, numpy.divide(num, numpy.where(usable, den, 1.0)), numpy.nan)
}

# Cloud cover (percent) of a STAC item.
# Falls back to 100 (treated as fully cloudy) when 'eo:cloud_cover' is missing
# or null. `.get` is required: plain subscripting raises KeyError on a missing
# key before `??` can apply the default.
cloud_cover = (item) => { item.properties.get('eo:cloud_cover') ?? 100 }

# Keep the items whose cloud cover is strictly below max_cc (percent),
# preserving their original order.
filter_by_cloud = (items, max_cc) => {
    fold(
        items,
        list(),
        (selected, candidate) => {
            if cloud_cover(candidate) < max_cc {
                selected.append(candidate)
            }
            selected
        }
    )
}

# Return the least cloudy item, keeping the earliest one on ties, and log its id.
# On an empty list, log a message and return None.
pick_clearest = (items) => {
    if len(items) > 0 {
        clearest = fold(
            items,
            items[0],
            (best_so_far, candidate) => {
                if cloud_cover(best_so_far) > cloud_cover(candidate) { candidate } else { best_so_far }
            }
        )
        print(f"    -> {clearest.id} (cloud: {round(cloud_cover(clearest), 1)}%)")
        clearest
    } else {
        print("    -> aucune scène disponible après filtre nuage")
        None
    }
}

# Atomically create an empty temporary file with tempfile.mkstemp (no
# check-then-create race) and return its path. The OS descriptor is closed at
# once so the caller can reopen the path with any library.
safe_temp_path = (suffix, prefix) => {
    handle_and_path = tempfile.mkstemp(suffix=suffix, prefix=prefix)
    os.close(handle_and_path[0])
    handle_and_path[1]
}

# Resample src_arr onto a reference grid (ref_crs, ref_transform, ref_shape)
# with bilinear interpolation.
# NaN in the source is replaced by -9999.0, which is declared as src_nodata.
# The float32 output is pre-filled with NaN, and NaN is also the dst_nodata value.
align_to_ref_grid = (src_arr, src_transform, src_crs, ref_shape, ref_transform, ref_crs) => {
    fill_value = -9999.0
    nan_cells = numpy.isnan(src_arr)
    filled = numpy.where(nan_cells, fill_value, src_arr).astype(numpy.float32)
    out_grid = numpy.full(ref_shape, numpy.nan, dtype=numpy.float32)
    rio_warp.reproject(
        source=filled,
        destination=out_grid,
        src_transform=src_transform,
        src_crs=src_crs,
        src_nodata=fill_value,
        dst_transform=ref_transform,
        dst_crs=ref_crs,
        dst_nodata=numpy.nan,
        resampling=rio_warp.Resampling.bilinear,
    )
    out_grid
}

# Read one band (STAC asset `band_key` of `item`) from its Cloud-Optimized
# GeoTIFF, restricted to the bounding box of aoi_geom, so that only the part of
# the file covering the AOI is fetched.
# Returns a BandData containing:
#   - the float32 array, with nodata and internally masked pixels set to NaN
#   - the raster CRS
#   - the affine transform of the window that was read
#
# NOTE(review): values are returned exactly as stored in the file; no scale or
# offset is applied. The STAC asset may advertise 'raster:bands' scale/offset
# (Sentinel-2 processing baseline >= 04.00 added a BOA offset). NDVI is
# scale-invariant but not offset-invariant, so confirm whether the offset must
# be removed here.
read_band = (item, band_key, aoi_geom) => {
    href = item.assets[band_key].href
    with ds = rasterio.open(href) {
        # AOI bounds (lon/lat) -> raster CRS (UTM). transform_bounds densifies the
        # box edges, so the result encloses the whole reprojected AOI. Projecting
        # only two corners is not enough: a lon/lat box is rotated in UTM.
        b = aoi_geom.bounds
        xy = rio_warp.transform_bounds("EPSG:4326", ds.crs, b[0], b[1], b[2], b[3])

        # Snap the window to whole pixels, so data and transform stay
        # consistent, and clip it to the raster extent, since the AOI may stick
        # out of the tile.
        raster_extent = rio_windows.Window(0, 0, ds.width, ds.height)
        win = rio_windows.from_bounds(xy[0], xy[1], xy[2], xy[3], transform=ds.transform)
        win = win.round_offsets().round_lengths().intersection(raster_extent)

        arr = ds.read(1, window=win, out_dtype=numpy.float32)
        valid_mask = ds.read_masks(1, window=win)

        nodata = ds.nodata
        raster_crs = ds.crs
        win_t = rio_windows.transform(win, ds.transform)

        # Declared nodata value -> NaN
        if nodata is not None {
            arr = numpy.where(numpy.equal(arr, nodata), numpy.nan, arr)
        }
        # Internal raster mask (0 = invalid) -> NaN
        arr = numpy.where(numpy.equal(valid_mask, 0), numpy.nan, arr)

        BandData(arr, raster_crs, win_t)
    }
}

# Print the NaN-ignoring min, max and mean of an array, each rounded to 3 decimals.
show_stats = (label, arr) => {
    lo = round(float(numpy.nanmin(arr)), 3)
    hi = round(float(numpy.nanmax(arr)), 3)
    avg = round(float(numpy.nanmean(arr)), 3)
    print(f"    {label} : [{lo}, {hi}]  mean = {avg}")
}

# ── STAC search ─────────────────────────────────────────────────────────

print()
print(f"⇒ Recherche STAC ({stac_url})")

# Client for the STAC API, queried by search_and_filter.
catalog = stac_client.Client.open(stac_url)

# Search the catalog for scenes of `collection` intersecting the AOI within
# date_range (at most 50 items). Log the raw count and the count below
# max_cloud, and return the filtered scenes.
search_and_filter = (label, date_range) => {
    print(f"  {label} : {date_range}")
    results = catalog.search(collections=list(collection), intersects=aoi, datetime=date_range, max_items=50)
    scenes = list()
    for scene in results.items() {
        scenes.append(scene)
    }
    print(f"    Scènes brutes : {len(scenes)}")
    kept = filter_by_cloud(scenes, max_cloud)
    print(f"    Cloud < {max_cloud}% : {len(kept)}")
    kept
}

# Cloud-filtered candidates for each period, then the clearest scene of each.
items1 = search_and_filter("T1", date_before)
items2 = search_and_filter("T2", date_after)

best1 = pick_clearest(items1)
best2 = pick_clearest(items2)

# Change detection needs both dates: stop with exit code 1 if either is missing.
if best1 is None or best2 is None {
    print("\n✗ Impossible de continuer: aucune scène exploitable pour T1 ou T2.")
    print("  Actions: élargir les dates et/ou augmenter max_cloud.")
    sys.exit(1)
}

# ── Band reading ────────────────────────────────────────────────────────

print()
print("⇒ Lecture COG (fenêtre partielle sur AOI)")

# Red and NIR for each date, both read through the same AOI window of the same
# item. They presumably share one pixel grid (both are 10 m bands); the NDVI
# step relies on that.
print("  T1 :")
red1 = read_band(best1, band_red, aoi)
nir1 = read_band(best1, band_nir, aoi)
print(f"    {red1.array.shape} pixels chargés")

print("  T2 :")
red2 = read_band(best2, band_red, aoi)
nir2 = read_band(best2, band_nir, aoi)
print(f"    {red2.array.shape} pixels chargés")

# ── NDVI ────────────────────────────────────────────────────────────────

print()
print("⇒ Calcul NDVI")

# NDVI = (NIR - Red) / (NIR + Red). Sentinel-2 DNs are reflectance x 10000;
# NDVI is invariant to that common scale factor, so raw values are used.
# NOTE(review): NDVI is NOT invariant to an additive offset. If these scenes
# carry a BOA offset (processing baseline >= 04.00, advertised in the asset's
# 'raster:bands' metadata), it should be removed before this step. Confirm.
ndvi1 = normalized_diff(nir1.array, red1.array)
ndvi2 = normalized_diff(nir2.array, red2.array)

show_stats("NDVI T1", ndvi1)
show_stats("NDVI T2", ndvi2)

# ── Loss detection ──────────────────────────────────────────────────────

print()
print("⇒ Détection de perte")

# The two grids differ when T1 and T2 windows differ in shape, CRS or
# transform (e.g. scenes from different tiles). In that case, T2 NDVI is
# resampled onto the T1 grid, and every later step works on the T1 grid.
need_align = (ndvi2.shape != ndvi1.shape or red2.crs != red1.crs or red2.transform != red1.transform)

if need_align {
    print("  Alignement T2 -> grille T1 (reprojection/rééchantillonnage)")
    ndvi2_aligned = align_to_ref_grid(ndvi2, red2.transform, red2.crs, ndvi1.shape, red1.transform, red1.crs)
} else {
    ndvi2_aligned = ndvi2
}

# Negative delta = NDVI decrease from T1 to T2.
delta = numpy.subtract(ndvi2_aligned, ndvi1)
show_stats("Delta ", delta)

# Loss = delta below `threshold`. NaN compares False, so invalid pixels are never loss.
loss = numpy.less(delta, threshold)
loss_px = int(numpy.nansum(loss))
# Pixels where delta is defined (valid at both dates).
valid_px = int(numpy.sum(numpy.logical_not(numpy.isnan(delta))))

print(f"  Seuil : {threshold}")
print(f"  Perte : {loss_px} / {valid_px} valides ({round(loss_px * 100 / max(valid_px, 1), 1)}%)")

# ── Lost area ───────────────────────────────────────────────────────────

print()
print("⇒ Surface perdue")

# Pixel size comes from the T1 window transform, the grid that delta lives on.
# Assumes a projected CRS in metres (UTM for Sentinel-2): 10 m bands give 10 x 10 m pixels.
pixel_w = abs(red1.transform.a)
pixel_h = abs(red1.transform.e)
# 1 ha = 10 000 m²
pixel_area_ha = (pixel_w * pixel_h) / 10000

print(f"  Résolution : {pixel_w} x {pixel_h} m")

# Area in hectares; 1 km² = 100 ha.
loss_ha = round(loss_px * pixel_area_ha, 2)
print(f"  Perte estimée : {loss_ha} ha ({round(loss_ha / 100, 4)} km2)")

# ── Severity ────────────────────────────────────────────────────────────

print()
print("⇒ Sévérité")

# Drop-magnitude classes; 999.0 means "no upper bound".
severities = list(
    Severity("Modérée", 0.20, 0.35),
    Severity("Sévère", 0.35, 0.50),
    Severity("Critique", 0.50, 999.0),
)

# Per class, count pixels with -max_thresh <= delta < -min_thresh. The
# open-ended class keeps only the lower-bound test.
severities.[(s) => {
    below_min = numpy.less(delta, -s.min_thresh)
    in_class = if s.max_thresh < 999.0 {
        numpy.logical_and(below_min, numpy.greater_equal(delta, -s.max_thresh))
    } else {
        below_min
    }
    print(f"  {s.label} : {int(numpy.nansum(in_class))} px")
}]

# Mean NDVI change over the loss pixels (only meaningful if there are any).
if loss_px > 0 {
    mean_drop = round(float(numpy.nanmean(delta[loss])), 3)
    print(f"  Drop moyen : {mean_drop}")
}

# ── Export ──────────────────────────────────────────────────────────────

# Both rasters are written on the T1 grid (the grid delta lives on).
rows = delta.shape[0]
cols = delta.shape[1]
loss_uint8 = loss.astype(numpy.uint8)

# Binary mask: 1 = loss, 0 = no loss or invalid pixel.
out_mask = safe_temp_path("_loss_mask.tif", "catnip_")
with dst = rasterio.open(out_mask, 'w', driver='GTiff', height=rows, width=cols, count=1, dtype='uint8', crs=red1.crs,
    transform=red1.transform) {
    dst.write(loss_uint8, 1)
}

# Continuous NDVI delta. Invalid or uncovered pixels are NaN, so NaN is
# declared as the band's nodata value. Otherwise GIS tools treat NaN as data.
out_delta = safe_temp_path("_delta_ndvi.tif", "catnip_")
with dst = rasterio.open(out_delta, 'w', driver='GTiff', height=rows, width=cols, count=1, dtype='float32', crs=red1.crs,
    transform=red1.transform, nodata=numpy.nan) {
    dst.write(delta.astype(numpy.float32), 1)
}

print()
print("⇒ Export GeoTIFF")
print(f"  Masque : {out_mask}")
print(f"  Delta  : {out_delta}")

# ── Visualization (optional, requires matplotlib) ───────────────────────
# All four panels are drawn on the T1 grid. The T2 panel therefore shows
# ndvi2_aligned, which is identical to ndvi2 when no alignment was needed, so
# it lines up pixel-for-pixel with the delta and loss panels.

if import('importlib.util').find_spec('matplotlib') is not None {
    print()
    print("⇒ Visualisation")

    plt = import('matplotlib.pyplot')
    mcolors = import('matplotlib.colors')

    fig = plt.figure(figsize=tuple(16, 4))

    ax1 = fig.add_subplot(1, 4, 1)
    im1 = ax1.imshow(ndvi1, cmap="RdYlGn", vmin=-0.2, vmax=0.9)
    ax1.set_title("NDVI T1")
    plt.colorbar(im1, ax=ax1, shrink=0.8)

    ax2 = fig.add_subplot(1, 4, 2)
    im2 = ax2.imshow(ndvi2_aligned, cmap="RdYlGn", vmin=-0.2, vmax=0.9)
    ax2.set_title("NDVI T2")
    plt.colorbar(im2, ax=ax2, shrink=0.8)

    ax3 = fig.add_subplot(1, 4, 3)
    im3 = ax3.imshow(delta, cmap="RdBu", vmin=-0.5, vmax=0.5)
    ax3.set_title("Delta NDVI")
    plt.colorbar(im3, ax=ax3, shrink=0.8)

    ax4 = fig.add_subplot(1, 4, 4)
    # Pin the 0/1 range. Without vmin/vmax, a constant mask is autoscaled to a
    # degenerate range and drawn entirely in the first colour, so an all-loss
    # mask would appear as "no loss".
    ax4.imshow(loss_uint8, cmap=mcolors.ListedColormap(list("lightgray", "red")), vmin=0, vmax=1, interpolation="nearest")
    ax4.set_title("Perte détectée")

    plt.tight_layout()
    plot_path = safe_temp_path("_deforestation.png", "catnip_")
    plt.savefig(plot_path, dpi=150)
    # Close this figure explicitly rather than whatever figure is "current".
    plt.close(fig)
    print(f"  Figure : {plot_path}")
} else {
    print("\n  matplotlib absent, visualisation ignorée")
}

# TODO : masque nuages/ombres (bande SCL, 20m → resample)
# TODO : croisement PRODES (rasterize polygones de référence)
# TODO : métriques supervisées (IoU, précision, rappel vs PRODES)

# Optionally delete every output written by the run. The PNG figure (plot_path)
# only exists when matplotlib was available; the same availability test as the
# visualization step decides whether it is removed.
if cleanup_outputs {
    os.remove(out_mask)
    os.remove(out_delta)
    if import('importlib.util').find_spec('matplotlib') is not None {
        os.remove(plot_path)
    }
    print("\n  Fichiers temporaires nettoyés")
} else {
    print("\n  Fichiers conservés (cleanup_outputs = False)")
}