examples/performance/jit_benchmark.py
#!/usr/bin/env python3
"""JIT compilation benchmark for Catnip.

Demonstrates the trace-based JIT compiler using Cranelift.
Compares interpreter vs JIT-compiled execution for hot loops and hot functions.

The JIT compiler supports:
- Loops: Int, float, boolean types, nested loops, conditional branches
- Functions (since v0.0.3): Recursive and non-recursive functions called frequently
- Typical speedups: 100-200x on numeric loops, ~1.1x on simple functions

Usage:
    python examples/performance/jit_benchmark.py
"""

import time


def benchmark_sum_loop():
    """Compare interpreter vs JIT for a simple sum loop."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    total = 0
    for i in range(1, 100001) {
        total = total + i
    }
    total
    """

    # Parse and compile to bytecode
    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)

    # Warm up (also triggers JIT compilation)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    # Get stats
    stats = vm.get_jit_stats()

    print("Sum(1..100000) Benchmark")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 5000050000)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        print(f"Speedup:     {interp_time / jit_time:.1f}x")
    print()
    print("JIT Stats:")
    for key, value in stats.items():
        print(f"  {key}: {value}")


def benchmark_float_loop():
    """Benchmark float accumulation."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    x = 0.0
    for i in range(1000000) {
        x = x + 1.5
    }
    x
    """

    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    print()
    print("Float Accumulation (1M iterations)")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 1500000.0)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        print(f"Speedup:     {interp_time / jit_time:.1f}x")


def benchmark_conditional_loop():
    """Benchmark loop with conditional branch (side exits)."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    count = 0
    for i in range(1000000) {
        if i > 500000 {
            count = count + 1
        }
    }
    count
    """

    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    print()
    print("Conditional Loop (1M iterations, 50% side exits)")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 499999)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        speedup = interp_time / jit_time
        print(f"Speedup:     {speedup:.2f}x")
        if speedup < 1:
            print("  Note: Side exits cause overhead when guards fail frequently")
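

# Added sketch: covers the "nested loops" case mentioned in the module docstring.
def benchmark_nested_loop():
    """Benchmark a nested loop.

    The module docstring lists nested loops among the JIT-supported shapes, but the
    original benchmarks do not exercise them. This sketch reuses the measurement
    pattern from benchmark_sum_loop(); the nested `for ... { ... }` syntax and the
    1000 x 1000 loop bounds are assumptions, not taken from the other examples.
    """
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    total = 0
    for i in range(1000) {
        for j in range(1000) {
            total = total + 1
        }
    }
    total
    """

    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)

    # Warm up (also triggers JIT compilation)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    print()
    print("Nested Loop (1000 x 1000 iterations)")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 1000000)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        print(f"Speedup:     {interp_time / jit_time:.1f}x")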


def benchmark_function_loop():
    """Benchmark loop inside a function (tests JIT for function scopes)."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    counter = (n) => {
        i = 0
        sum = 0
        while i < n {
            sum = sum + 1
            i = i + 1
        }
        sum
    }
    result = counter(1000000)
    result
    """

    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    print()
    print("Function Loop (1M iterations inside function)")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 1000000)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        print(f"Speedup:     {interp_time / jit_time:.1f}x")


def benchmark_hot_function():
    """Benchmark JIT compilation of frequently called functions (since v0.0.3)."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    square = (x) => { x * x }

    result = 0
    for i in range(10000) {
        result = square(i)
    }
    result
    """

    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    vm = VM()
    vm.set_context(c.context)
    vm.execute(bytecode, (), {}, None)

    # Without JIT
    vm.disable_jit()
    start = time.perf_counter()
    result_interp = vm.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT
    vm.enable_jit()
    start = time.perf_counter()
    result_jit = vm.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    print()
    print("Hot Function JIT (function called 10k times)")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 99980001)")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        speedup = interp_time / jit_time
        print(f"Speedup:     {speedup:.2f}x")
        print()
        print("Note: Modest speedup (~1.1x) due to boxing/unboxing overhead.")
        print("      Functions with loops inside benefit more from loop JIT.")


def benchmark_recursive_function():
    """Benchmark recursive function compilation."""
    from catnip import Catnip
    from catnip._rs import Compiler, VM

    code = """
    {
        factorial = (n) => {
            if n <= 1 {
                1
            } else {
                n * factorial(n - 1)
            }
        }

        result = 0
        i = 0
        while i < 120 {
            result = factorial(5)
            i = i + 1
        }
        result
    }
    """

    # Parse and compile
    c = Catnip(vm_mode='on')
    ast = c.parse(code)
    compiler = Compiler()
    bytecode = compiler.compile(ast)

    # Without JIT
    vm_nojit = VM()
    vm_nojit.set_context(c.context)
    vm_nojit.disable_jit()
    start = time.perf_counter()
    result_interp = vm_nojit.execute(bytecode, (), {}, None)
    interp_time = (time.perf_counter() - start) * 1000

    # With JIT (need fresh VM)
    vm_jit = VM()
    vm_jit.set_context(c.context)
    vm_jit.enable_jit()
    start = time.perf_counter()
    result_jit = vm_jit.execute(bytecode, (), {}, None)
    jit_time = (time.perf_counter() - start) * 1000

    stats = vm_jit.get_jit_stats()

    print()
    print("Recursive Function (factorial) Benchmark")
    print("=" * 40)
    print(f"Result: {result_jit} (expected: 120)")
    print("Calls: 120 invocations of factorial(5), 5 calls deep each = 600 recursive calls")
    print(f"Interpreter: {interp_time:.2f}ms")
    print(f"JIT:         {jit_time:.2f}ms")
    if jit_time > 0:
        print(f"Speedup:     {interp_time / jit_time:.1f}x")
    print()
    print("Note: Recursive calls compiled via CallSelf with NaN re-boxing.")
    print("      Depth > MAX_RECURSION_DEPTH (10000) triggers graceful fallback.")


def show_jit_info():
    """Display JIT configuration information."""
    from catnip._rs import VM

    vm = VM()
    print("JIT Configuration")
    print("=" * 40)
    print(f"JIT enabled: {vm.is_jit_enabled()}")
    print("Backend: Cranelift (x86-64)")
    print("Strategy: Trace-based compilation")
    print("Hot threshold: 100 iterations/calls")
    print("Supports:")
    print("  - Loops (for/while)")
    print("  - Functions (recursive and non-recursive, since v0.0.3)")
    print("Scope support: Module-level and function-local variables")
    print()


if __name__ == "__main__":
    show_jit_info()
    benchmark_sum_loop()
    benchmark_float_loop()
    benchmark_conditional_loop()
    benchmark_function_loop()
    benchmark_hot_function()
    benchmark_recursive_function()
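    # Added sketch; see benchmark_nested_loop() for the syntax assumptions it makes.
    benchmark_nested_loop()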