examples/embedding/03_etl_pipeline.py
"""
Exemple d'intégration de Catnip comme DSL pour pipelines ETL.
Montre comment :
1. Créer un DSL déclaratif pour transformation de données
2. Charger des données depuis CSV/dict
3. Appliquer des transformations via règles Catnip
4. Exporter vers JSON/CSV/dict
Use case : Pipelines de transformation de données avec logique métier en Catnip.
"""
import json
from io import StringIO
from catnip import Catnip, Context, pass_context
class ETLContext(Context):
"""
Contexte de transformation ETL.
Stocke les données en cours de transformation et l'historique des opérations.
"""
def __init__(self, data: list[dict], **kwargs):
super().__init__(**kwargs)
self._data = data
self._operations = []
# Expose les données dans le contexte
self.globals['data'] = data
self.globals['len'] = len
@property
def data(self) -> list[dict]:
return self._data
@property
def operations(self) -> list:
return self._operations
def update_data(self, new_data: list[dict]):
"""Mise à jour des données après transformation."""
self._data = new_data
self.globals['data'] = new_data
def log_operation(self, op_name: str, details: str = ''):
"""Enregistre une opération appliquée."""
self._operations.append({'operation': op_name, 'details': details})
class ETLDSL(Catnip):
"""
DSL Catnip pour pipelines ETL.
Syntaxe déclarative pour filtrer, transformer, agréger des données.
"""
@staticmethod
def _filter_rows(ctx, condition_field: str, operator: str, value):
"""Filtre les lignes selon une condition."""
filtered = []
op_map = {
'==': lambda a, b: a == b,
'!=': lambda a, b: a != b,
'>': lambda a, b: a > b,
'>=': lambda a, b: a >= b,
'<': lambda a, b: a < b,
'<=': lambda a, b: a <= b,
}
if operator not in op_map:
raise ValueError(f"Opérateur inconnu: {operator}")
for row in ctx.data:
if condition_field in row:
if op_map[operator](row[condition_field], value):
filtered.append(row)
ctx.update_data(filtered)
ctx.log_operation('filter', f"{condition_field} {operator} {value}")
return len(filtered)
@staticmethod
def _map_field(ctx, field: str, operation: str, value=None):
"""Applique une transformation à un champ."""
op_map = {
'multiply': lambda x: x * value,
'add': lambda x: x + value,
'upper': lambda x: x.upper() if isinstance(x, str) else x,
'lower': lambda x: x.lower() if isinstance(x, str) else x,
}
if operation not in op_map:
raise ValueError(f"Opération inconnue: {operation}")
for row in ctx.data:
if field in row:
row[field] = op_map[operation](row[field])
ctx.log_operation('map', f"Transform field '{field}' ({operation})")
return len(ctx.data)
@staticmethod
def _rename_field(ctx, old_name: str, new_name: str):
"""Renomme un champ."""
for row in ctx.data:
if old_name in row:
row[new_name] = row.pop(old_name)
ctx.log_operation('rename', f"{old_name} → {new_name}")
return len(ctx.data)
@staticmethod
def _add_field(ctx, field: str, formula: str, field1: str, field2: str = None):
"""Ajoute un nouveau champ calculé."""
formulas = {
'multiply': lambda r: r.get(field1, 0) * r.get(field2, 1),
'add': lambda r: r.get(field1, 0) + r.get(field2, 0),
'concat': lambda r: str(r.get(field1, '')) + str(r.get(field2, '')),
}
if formula not in formulas:
raise ValueError(f"Formule inconnue: {formula}")
for row in ctx.data:
row[field] = formulas[formula](row)
ctx.log_operation('add_field', f"Added '{field}' = {formula}({field1}, {field2})")
return len(ctx.data)
@staticmethod
def _drop_field(ctx, field: str):
"""Supprime un champ."""
for row in ctx.data:
row.pop(field, None)
ctx.log_operation('drop_field', f"Dropped '{field}'")
return len(ctx.data)
@staticmethod
def _sort_by(ctx, field: str, reverse: bool = False):
"""Trie les données par un champ."""
ctx._data = sorted(ctx.data, key=lambda r: r.get(field, ''), reverse=reverse)
ctx.globals['data'] = ctx._data
ctx.log_operation('sort', f"By '{field}' {'desc' if reverse else 'asc'}")
return len(ctx.data)
@staticmethod
def _limit(ctx, n: int):
"""Limite le nombre de lignes."""
ctx.update_data(ctx.data[:n])
ctx.log_operation('limit', f"First {n} rows")
return len(ctx.data)
@staticmethod
def _group_by(ctx, field: str, agg_field: str, agg_func: str):
"""Agrège les données par un champ."""
groups = {}
for row in ctx.data:
key = row.get(field)
if key not in groups:
groups[key] = []
if agg_field in row:
groups[key].append(row[agg_field])
agg_map = {
'sum': sum,
'count': len,
'avg': lambda vals: sum(vals) / len(vals) if vals else 0,
'min': lambda vals: min(vals) if vals else None,
'max': lambda vals: max(vals) if vals else None,
}
if agg_func not in agg_map:
raise ValueError(f"Fonction d'agrégation inconnue: {agg_func}")
result = [
{field: key, f"{agg_func}_{agg_field}": agg_map[agg_func](vals)}
for key, vals in groups.items()
]
ctx.update_data(result)
ctx.log_operation('group_by', f"{field}, {agg_func}({agg_field})")
return len(result)
# Fonctions DSL injectées
DSL_FUNCTIONS = dict(
filter_rows=pass_context(_filter_rows),
map_field=pass_context(_map_field),
rename_field=pass_context(_rename_field),
add_field=pass_context(_add_field),
drop_field=pass_context(_drop_field),
sort_by=pass_context(_sort_by),
limit=pass_context(_limit),
group_by=pass_context(_group_by),
)
def __init__(self, data: list[dict], **kwargs):
context = ETLContext(data)
super().__init__(context=context, **kwargs)
self.context.globals.update(self.DSL_FUNCTIONS)
def transform(self, pipeline_script: str) -> list[dict]:
"""
Exécute le pipeline de transformation.
Returns:
list[dict] - Données transformées
"""
self.parse(pipeline_script)
self.execute()
return self.context.data
def to_json(self, indent: int = 2) -> str:
"""Exporte les données en JSON."""
return json.dumps(self.context.data, indent=indent, ensure_ascii=False)
def to_csv(self) -> str:
"""Exporte les données en CSV."""
if not self.context.data:
return ''
output = StringIO()
fields = list(self.context.data[0].keys())
output.write(','.join(fields) + '\n')
for row in self.context.data:
values = [str(row.get(f, '')) for f in fields]
output.write(','.join(values) + '\n')
return output.getvalue()
# --- Démonstration ---
if __name__ == '__main__':
print("▸ Exemple 1 : Pipeline de nettoyage de données")
print()
# Données brutes (simulation CSV)
raw_data = [
{'id': 1, 'name': 'Alice', 'age': 28, 'salary': 50000, 'dept': 'Engineering'},
{'id': 2, 'name': 'Bob', 'age': 35, 'salary': 60000, 'dept': 'Sales'},
{'id': 3, 'name': 'Charlie', 'age': 22, 'salary': 45000, 'dept': 'Engineering'},
{'id': 4, 'name': 'Diana', 'age': 29, 'salary': 55000, 'dept': 'Marketing'},
{'id': 5, 'name': 'Eve', 'age': 31, 'salary': 58000, 'dept': 'Engineering'},
]
pipeline = """
# Filtrer les ingénieurs
filter_rows('dept', '==', 'Engineering')
# Augmenter les salaires de 10%
map_field('salary', 'multiply', 1.1)
# Renommer champ
rename_field('salary', 'annual_comp')
# Trier par compensation
sort_by('annual_comp', True)
"""
etl = ETLDSL(raw_data.copy())
result = etl.transform(pipeline)
print(f"Résultat : {len(result)} lignes après transformation")
print()
print(etl.to_json())
print()
print("Opérations appliquées :")
for op in etl.context.operations:
print(f" - {op['operation']}: {op['details']}")
print()
print("▸ Exemple 2 : Agrégation par département")
print()
data = [
{'name': 'Alice', 'dept': 'Engineering', 'salary': 50000},
{'name': 'Bob', 'dept': 'Sales', 'salary': 60000},
{'name': 'Charlie', 'dept': 'Engineering', 'salary': 45000},
{'name': 'Diana', 'dept': 'Sales', 'salary': 55000},
]
pipeline = """
group_by('dept', 'salary', 'avg')
sort_by('avg_salary', True)
"""
etl = ETLDSL(data)
result = etl.transform(pipeline)
print("Salaire moyen par département :")
print(etl.to_json())
print()
print("▸ Exemple 3 : Pipeline complet avec ajout de champs")
print()
data = [
{'product': 'Laptop', 'price': 1000, 'quantity': 5},
{'product': 'Mouse', 'price': 20, 'quantity': 50},
{'product': 'Keyboard', 'price': 80, 'quantity': 30},
]
pipeline = """
# Ajouter champ total (price * quantity)
add_field('total', 'multiply', 'price', 'quantity')
# Filtrer les totaux > 1000
filter_rows('total', '>', 1000)
# Supprimer le champ quantity
drop_field('quantity')
# Trier par total
sort_by('total', True)
"""
etl = ETLDSL(data)
result = etl.transform(pipeline)
print(f"Produits avec total > 1000 : {len(result)}")
print()
print(etl.to_json())
print()
print("Export CSV :")
print(etl.to_csv())