examples/embedding/01_dataframe_dsl.py
"""
Exemple d'intégration de Catnip comme DSL pour manipuler des DataFrames.

Montre comment :
1. Sous-classer Context pour ajouter un état (DataFrame)
2. Sous-classer Catnip pour injecter des fonctions DSL
3. Exécuter des scripts simples qui manipulent la DataFrame

Inspiré de l'intégration Catkin (générateur de sites statiques).
"""

import pandas as pd
from catnip import Catnip, Context, pass_context


class DataFrameContext(Context):
    """
    Contexte d'exécution enrichi avec une DataFrame.

    La DataFrame est accessible via `_` dans les scripts Catnip.
    """

    def __init__(self, df: pd.DataFrame, **kwargs):
        super().__init__(**kwargs)
        self._df = df
        # expose la DataFrame comme variable `_` dans le contexte
        self.globals['_'] = df

    @property
    def df(self) -> pd.DataFrame:
        return self._df

    def set_df(self, value: pd.DataFrame):
        self._df = value
        self.globals['_'] = value


class DataFrameDSL(Catnip):
    """
    DSL Catnip pour manipuler des DataFrames.

    Les fonctions DSL reçoivent le contexte via @pass_context
    et opèrent sur ctx.df.
    """

    # Fonctions DSL - toutes reçoivent ctx en premier argument via @pass_context

    @staticmethod
    def _sort(ctx, col=None, reverse=False):
        """Trie la DataFrame par colonne."""
        if col is None:
            col = ctx.df.columns[0]
        ctx.df.sort_values(by=col, ascending=not reverse, inplace=True)

    @staticmethod
    def _head(ctx, n=5):
        """Garde les n premières lignes."""
        ctx.set_df(ctx.df.head(n))

    @staticmethod
    def _tail(ctx, n=5):
        """Garde les n dernières lignes."""
        ctx.set_df(ctx.df.tail(n))

    @staticmethod
    def _filter(ctx, col, op, value):
        """Filtre la DataFrame selon une condition."""
        ops = {
            '==': lambda a, b: a == b,
            '!=': lambda a, b: a != b,
            '>': lambda a, b: a > b,
            '<': lambda a, b: a < b,
            '>=': lambda a, b: a >= b,
            '<=': lambda a, b: a <= b,
        }
        if op not in ops:
            raise ValueError(f"Opérateur inconnu: {op}")
        mask = ops[op](ctx.df[col], value)
        ctx.set_df(ctx.df[mask])

    @staticmethod
    def _select(ctx, *cols):
        """Sélectionne des colonnes."""
        ctx.set_df(ctx.df[list(cols)])

    @staticmethod
    def _drop(ctx, *cols):
        """Supprime des colonnes."""
        ctx.set_df(ctx.df.drop(columns=list(cols)))

    @staticmethod
    def _rename(ctx, old, new):
        """Renomme une colonne."""
        ctx.df.rename(columns={old: new}, inplace=True)

    @staticmethod
    def _groupby(ctx, col):
        """Retourne un résumé groupé (count par groupe)."""
        result = ctx.df.groupby(col).size().reset_index(name='count')
        ctx.set_df(result)

    @staticmethod
    def _show(ctx):
        """Affiche la DataFrame courante."""
        print(ctx.df.to_string(index=False))

    # Dictionnaire des fonctions DSL injectées dans le contexte Catnip
    DSL_FUNCTIONS = dict(
        sort=pass_context(_sort),
        head=pass_context(_head),
        tail=pass_context(_tail),
        filter=pass_context(_filter),
        select=pass_context(_select),
        drop=pass_context(_drop),
        rename=pass_context(_rename),
        groupby=pass_context(_groupby),
        show=pass_context(_show),
    )

    def __init__(self, df: pd.DataFrame, **kwargs):
        # crée le contexte enrichi avec la DataFrame
        context = DataFrameContext(df)
        super().__init__(context=context, **kwargs)
        # injecte les fonctions DSL
        self.context.globals.update(self.DSL_FUNCTIONS)

    def run(self, script: str) -> pd.DataFrame:
        """Exécute un script DSL et retourne la DataFrame résultante."""
        self.parse(script)
        self.execute()
        return self.context.df


# --- Démonstration ---

if __name__ == '__main__':
    # DataFrame d'exemple
    data = pd.DataFrame(
        {
            'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'],
            'age': [25, 30, 35, 28, 22],
            'city': ['Paris', 'Lyon', 'Paris', 'Marseille', 'Lyon'],
            'score': [85, 92, 78, 95, 88],
        }
    )

    print("⇒ DataFrame initiale")
    print(data.to_string(index=False))
    print()

    # Script DSL : filtre, trie et affiche
    script1 = """
    filter('age', '>=', 25)
    sort('score', True)
    show()
    """

    print("⇒ Script 1: filter('age', '>=', 25) + sort('score', True)")
    dsl = DataFrameDSL(data.copy())
    result = dsl.run(script1)
    print()

    # Script DSL : groupby
    script2 = """
    groupby('city')
    sort('count', True)
    show()
    """

    print("⇒ Script 2: groupby('city') + sort('count', True)")
    dsl = DataFrameDSL(data.copy())
    result = dsl.run(script2)
    print()

    # Script DSL : pipeline avec head
    script3 = """
    sort('age')
    head(3)
    select('name', 'age')
    show()
    """

    print("⇒ Script 3: sort('age') + head(3) + select('name', 'age')")
    dsl = DataFrameDSL(data.copy())
    result = dsl.run(script3)