#!/usr/bin/env python3
"""
Backtester for EMA 9/21 crossover + RSI strategy — 3-variant comparison.
Mirrors trading_bot.py logic using 1-hour bars over ~2 years.

Variants
  1. Base          — EMA 9/21 crossover + RSI only
  2. SPY Trend     — Base + buy only when SPY > its 50-day MA
  3. 5% Stop Loss  — Base + 5% hard stop loss per trade
"""

import json
import time
from datetime import datetime, timedelta
from collections import defaultdict

import pandas as pd
import yfinance as yf

# ── Config ────────────────────────────────────────────────────────────────────
# Capital & sizing (USD)
STARTING_CASH     = 10_000.0
POSITION_SIZE_USD = 1_000.0   # per-entry budget, capped by cash; whole units only

# Indicator periods (bars)
EMA_FAST          = 9
EMA_SLOW          = 21
RSI_PERIOD        = 14

# RSI thresholds
RSI_BUY_MAX       = 70        # entries require RSI strictly below this
RSI_SELL_TRIGGER  = 75        # exit once RSI rises strictly above this

# Variant parameters
SPY_MA_PERIOD     = 50        # trading days
STOP_LOSS_PCT     = 0.05      # 5 %

RESULTS_FILE      = "backtest_results.json"

WATCHLIST = [
    "AAPL", "MSFT", "NVDA", "AMZN", "GOOGL", "META", "TSLA", "AMD",
    "SPY", "QQQ",
    "NBIS", "SOFI", "BE",
    "ES=F", "MES=F",          # futures symbols (yfinance "=F" tickers)
]

# Each variant switches on at most one extra rule over the base crossover logic.
VARIANTS = [
    dict(
        id=1, name="Base Strategy",
        desc="EMA 9/21 crossover + RSI, no extra filter",
        spy_filter=False, stop_loss=None,
    ),
    dict(
        id=2, name="SPY Trend Filter",
        desc=f"Base + buy only when SPY > {SPY_MA_PERIOD}-day MA",
        spy_filter=True, stop_loss=None,
    ),
    dict(
        id=3, name="5% Stop Loss",
        desc=f"Base + {STOP_LOSS_PCT*100:.0f}% hard stop loss per trade",
        spy_filter=False, stop_loss=STOP_LOSS_PCT,
    ),
]


# ── Indicators ────────────────────────────────────────────────────────────────

def calc_ema(series: pd.Series, period: int) -> pd.Series:
    """Return the exponential moving average of *series* over *period* bars.

    Uses pandas' ``span`` parameterisation (alpha = 2 / (period + 1)) with
    ``adjust=False``: the recursive form, seeded with the first value.
    """
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()


def calc_rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """Return the Relative Strength Index of *series* on a 0–100 scale.

    Up-moves and down-moves are each smoothed with an EMA of
    ``com = period - 1`` (alpha = 1 / period, Wilder-style smoothing).

    Edge cases:
      * the first element is NaN (there is no prior price to diff against);
      * only gains so far -> 100, only losses so far -> 0;
      * smoothed gain and loss both zero (flat prices) -> NaN.
    """
    change = series.diff()
    up     = change.clip(lower=0)
    down   = (-change).clip(lower=0)

    def wilder(s: pd.Series) -> pd.Series:
        return s.ewm(com=period - 1, adjust=False).mean()

    rs = wilder(up) / wilder(down)
    return 100 - (100 / (1 + rs))


# ── Data download ─────────────────────────────────────────────────────────────

def download_all(symbols: list) -> dict:
    """Download ~2 years of 1-hour bars per symbol and attach indicators.

    Args:
        symbols: ticker strings, downloaded one at a time via ``yf.download``.

    Returns:
        {symbol: DataFrame} with columns ``close``, ``ema9``, ``ema21`` and
        ``rsi14``; the index is sorted ascending with duplicate timestamps
        removed. Symbols with no data, no close column, too few bars, or a
        download/processing error are skipped with a one-line reason printed.
    """
    # 729 days keeps us inside yfinance's strict 730-day 1-hour window.
    end   = datetime.now()
    start = end - timedelta(days=729)
    data  = {}

    for sym in symbols:
        print(f"  {sym:<8}", end=" ", flush=True)
        try:
            df = yf.download(
                sym, start=start, end=end,
                interval="1h", auto_adjust=True, progress=False,
            )

            if df.empty:
                print("no data")
                continue

            # yfinance may return MultiIndex columns (field, ticker); keep the field level.
            if isinstance(df.columns, pd.MultiIndex):
                df.columns = df.columns.get_level_values(0)

            close_col = next(
                (c for c in df.columns if str(c).lower() == "close"), None
            )
            if close_col is None:
                print("no Close column")
                continue

            df = df[[close_col]].rename(columns={close_col: "close"})
            # One row per timestamp, in time order: indicators are computed in
            # row order and the backtest looks bars up by timestamp, so repeated
            # or out-of-order rows would corrupt the EMAs or break the lookup.
            df = df[~df.index.duplicated(keep="last")].sort_index()
            df = df.dropna(subset=["close"]).copy()

            if len(df) < EMA_SLOW + RSI_PERIOD + 5:
                print(f"too few bars ({len(df)})")
                continue

            df["ema9"]  = calc_ema(df["close"], EMA_FAST)
            df["ema21"] = calc_ema(df["close"], EMA_SLOW)
            df["rsi14"] = calc_rsi(df["close"], RSI_PERIOD)
            df = df.dropna()   # drops the leading NaN RSI row(s)

            data[sym] = df
            print(f"{len(df):5,} bars  {df.index[0].date()} → {df.index[-1].date()}")

        except Exception as exc:
            # Best-effort per symbol: report and continue with the next ticker.
            print(f"ERROR: {exc}")
        finally:
            time.sleep(0.5)  # polite rate-limit guard, applied even after failures

    return data


# ── SPY 50-day MA lookup ──────────────────────────────────────────────────────

def build_spy_ma50_lookup(spy_df: pd.DataFrame) -> dict:
    """
    {normalized_timestamp: bool} — True when SPY daily close > 50-day SMA.
    Uses daily closes derived from the hourly data already in memory.
    """
    dates       = spy_df.index.normalize()
    daily_close = spy_df["close"].groupby(dates).last()
    ma50        = daily_close.rolling(SPY_MA_PERIOD).mean()
    above       = (daily_close > ma50).dropna()
    return above.to_dict()   # pd.Timestamp(midnight, tz) → bool


# ── Backtesting engine ────────────────────────────────────────────────────────

def _close_position(
    sym: str,
    pos: dict,
    price: float,
    cash: float,
    timestamp: str,
    exit_reason: str,
    rsi: float | None,
) -> tuple:
    """Sell an entire position at ``price``; return ``(cash_after, SELL record)``."""
    cost_basis = pos["qty"] * pos["avg_cost"]
    proceeds   = pos["qty"] * price
    pnl        = proceeds - cost_basis
    pnl_pct    = pnl / cost_basis * 100
    cash      += proceeds
    return cash, {
        "timestamp":   timestamp,
        "symbol":      sym,
        "action":      "SELL",
        "price":       round(price, 4),
        "qty":         pos["qty"],
        "proceeds":    round(proceeds, 2),
        "pnl":         round(pnl, 2),
        "pnl_pct":     round(pnl_pct, 2),
        "cash_after":  round(cash, 2),
        "exit_reason": exit_reason,
        "rsi":         None if rsi is None else round(rsi, 2),
    }


def run_backtest(
    data: dict,
    spy_filter: bool = False,
    stop_loss: float | None = None,
    spy_ma50: dict | None = None,
) -> tuple:
    """Simulate the EMA/RSI strategy bar by bar across every symbol in ``data``.

    Each frame must carry ``close``, ``ema9``, ``ema21`` and ``rsi14`` columns
    (as produced by ``download_all``). Signals are evaluated on each bar's close
    and filled at that same close.

    Args:
        data:       symbol -> indicator DataFrame; symbols are processed in the
                    dict's order on each timestamp.
        spy_filter: when True (and ``spy_ma50`` is given), block new entries
                    unless ``spy_ma50`` maps the bar's calendar day to True.
        stop_loss:  fractional stop (e.g. 0.05). It is checked against bar
                    closes only, so the fill can land below the stop level.
        spy_ma50:   {midnight timestamp: bool} from ``build_spy_ma50_lookup``.

    Returns:
        ``(trades, final_cash, portfolio_history)``:
          * ``trades`` is the BUY/SELL log;
          * ``final_cash`` is cash after liquidating open positions at their
            last price;
          * ``portfolio_history`` holds one ``(timestamp, value)`` mark per
            distinct timestamp.

    Note: positions are sized as ``int(min(POSITION_SIZE_USD, cash) / price)``.
    Any symbol priced above that budget never gets an entry.
    """
    cash      = STARTING_CASH
    positions = {}            # sym -> {"qty", "avg_cost"}
    trades    = []
    portfolio_history = []

    # Extract each symbol's bars once: ts -> (close, ema9, ema21, rsi14).
    # This avoids a DataFrame.loc lookup per bar: it is slow, and it returns a
    # frame instead of a row if a timestamp is duplicated.
    bars = {
        sym: {
            ts: (float(c), float(e9), float(e21), float(r))
            for ts, c, e9, e21, r in zip(
                df.index, df["close"], df["ema9"], df["ema21"], df["rsi14"]
            )
        }
        for sym, df in data.items()
    }
    all_ts     = sorted(set().union(*(b.keys() for b in bars.values())))
    prev_bar   = {}           # sym -> previous bar tuple, for crossover detection
    last_price = {}           # sym -> most recent close seen

    for ts in all_ts:
        # SPY trend gate: one lookup per bar, shared across all symbols.
        if spy_filter and spy_ma50 is not None:
            spy_ok = bool(spy_ma50.get(ts.normalize(), False))
        else:
            spy_ok = True

        for sym, sym_bars in bars.items():
            bar = sym_bars.get(ts)
            if bar is None:
                continue
            price, e9_now, e21_now, rsi_v = bar
            last_price[sym] = price

            prev = prev_bar.get(sym)
            prev_bar[sym] = bar
            if prev is None:
                continue      # first bar for this symbol: nothing to cross from

            _, e9_prev, e21_prev, _ = prev
            bullish = (e9_prev <= e21_prev) and (e9_now > e21_now)
            bearish = (e9_prev >= e21_prev) and (e9_now < e21_now)

            if sym not in positions:
                if bullish and rsi_v < RSI_BUY_MAX and spy_ok:
                    avail = min(POSITION_SIZE_USD, cash)
                    qty   = int(avail / price)
                    if qty >= 1:
                        cost  = qty * price
                        cash -= cost
                        positions[sym] = {"qty": qty, "avg_cost": price}
                        trades.append({
                            "timestamp":  ts.isoformat(),
                            "symbol":     sym,
                            "action":     "BUY",
                            "price":      round(price, 4),
                            "qty":        qty,
                            "cost":       round(cost, 2),
                            "cash_after": round(cash, 2),
                            "rsi":        round(rsi_v, 2),
                        })
                continue

            # Stop loss is checked first and overrides the normal exits.
            if stop_loss is not None and price <= positions[sym]["avg_cost"] * (1 - stop_loss):
                exit_reason = "STOP_LOSS"
            elif rsi_v > RSI_SELL_TRIGGER:
                exit_reason = "RSI_OVERBOUGHT"
            elif bearish:
                exit_reason = "BEARISH_CROSS"
            else:
                exit_reason = None

            if exit_reason:
                cash, record = _close_position(
                    sym, positions.pop(sym), price, cash,
                    ts.isoformat(), exit_reason, rsi_v,
                )
                trades.append(record)

        pv = cash + sum(
            pos["qty"] * last_price[s]
            for s, pos in positions.items() if s in last_price
        )
        portfolio_history.append((ts, round(pv, 2)))

    # Force-close any open positions at their last known price.
    last_ts = all_ts[-1].isoformat() if all_ts else ""
    for sym, pos in list(positions.items()):
        price = last_price.get(sym)
        if price is None:
            continue
        cash, record = _close_position(
            sym, pos, price, cash, last_ts, "END_OF_BACKTEST", None,
        )
        trades.append(record)

    return trades, round(cash, 2), portfolio_history


# ── Metrics ───────────────────────────────────────────────────────────────────

def calc_metrics(trades: list, final_cash: float, portfolio_history: list) -> dict:
    closed   = [t for t in trades if t["action"] == "SELL" and "pnl" in t]
    n_buys   = sum(1 for t in trades if t["action"] == "BUY")
    n_closed = len(closed)
    winners  = [t for t in closed if t["pnl"] > 0]
    losers   = [t for t in closed if t["pnl"] <= 0]
    win_rate = len(winners) / n_closed * 100 if n_closed else 0.0

    best  = max(closed, key=lambda t: t["pnl"]) if closed else None
    worst = min(closed, key=lambda t: t["pnl"]) if closed else None

    total_return = (final_cash - STARTING_CASH) / STARTING_CASH * 100

    peak   = STARTING_CASH
    max_dd = 0.0
    for _, v in portfolio_history:
        if v > peak:
            peak = v
        dd = (peak - v) / peak * 100
        if dd > max_dd:
            max_dd = dd

    monthly: dict = defaultdict(lambda: {"start": None, "end": None})
    for ts_obj, v in portfolio_history:
        m = ts_obj.strftime("%Y-%m")
        if monthly[m]["start"] is None:
            monthly[m]["start"] = v
        monthly[m]["end"] = v

    monthly_pnl = {}
    for m, d in sorted(monthly.items()):
        s = d["start"] or 0.0
        e = d["end"]   or 0.0
        monthly_pnl[m] = {
            "start_value": round(s, 2),
            "end_value":   round(e, 2),
            "pnl":         round(e - s, 2),
            "return_pct":  round((e - s) / s * 100, 2) if s else 0.0,
        }

    sym_stats: dict = defaultdict(
        lambda: {"trades": 0, "wins": 0, "losses": 0, "total_pnl": 0.0}
    )
    exit_reasons: dict = defaultdict(int)
    for t in closed:
        st = sym_stats[t["symbol"]]
        st["trades"]    += 1
        st["total_pnl"] += t["pnl"]
        if t["pnl"] > 0:
            st["wins"] += 1
        else:
            st["losses"] += 1
        exit_reasons[t.get("exit_reason", "UNKNOWN")] += 1

    return {
        "summary": {
            "starting_cash":    STARTING_CASH,
            "final_cash":       final_cash,
            "total_return_pct": round(total_return, 2),
            "total_trades":     n_buys,
            "closed_trades":    n_closed,
            "win_rate_pct":     round(win_rate, 2),
            "winning_trades":   len(winners),
            "losing_trades":    len(losers),
            "best_trade":       best,
            "worst_trade":      worst,
            "max_drawdown_pct": round(max_dd, 2),
        },
        "exit_reasons": dict(exit_reasons),
        "monthly_pnl":  monthly_pnl,
        "symbol_stats": {k: dict(v) for k, v in sym_stats.items()},
        "all_trades":   trades,
    }


# ── Terminal output ───────────────────────────────────────────────────────────

def _trade_str(t: dict | None) -> str:
    if not t:
        return "—"
    return f"{t['symbol']} {t['pnl']:+,.0f} ({t['pnl_pct']:+.0f}%)"


def print_comparison(variants: list, results: list):
    """Print a side-by-side comparison of backtest results to stdout.

    ``results[i]`` must be the ``calc_metrics`` output for ``variants[i]``.
    Column headers come from ``variants``, and every cell is padded to a shared
    width. The main, monthly and per-symbol tables therefore stay aligned for
    any number of variants and any name length.
    """
    LW = 22   # label column width of the main table
    MW = 9    # label column width of the monthly table
    SW = 8    # label column width of the per-symbol table

    headers = [f"V{v['id']}: {v['name']}" for v in variants]
    # At least 18 (as before), wider when a variant header needs it.
    CW = max([18, *(len(h) for h in headers)])
    W  = max(84, 2 + LW + len(variants) * (CW + 2))

    def divider(label_w: int = LW, ch: str = "─"):
        print("  " + "  ".join([ch * label_w] + [ch * CW] * len(variants)))

    def row(label: str, vals: list, label_w: int = LW):
        cells = "  ".join(f"{str(v):<{CW}}" for v in vals)
        print(f"  {label:<{label_w}}  {cells}".rstrip())

    # ── Header ────────────────────────────────────────────────────────────────
    print("\n" + "═" * W)
    print("  STRATEGY COMPARISON  —  EMA 9/21 + RSI  |  1-Hour Bars  |  ~2 Years")
    print(f"  Universe: 13 stocks + ES=F + MES=F  |  ${STARTING_CASH:,.0f} starting capital")
    print("─" * W)
    for v in variants:
        print(f"  V{v['id']}: {v['name']:<18} — {v['desc']}")
    print("═" * W)
    row("", headers)
    divider()

    # ── Core metrics ──────────────────────────────────────────────────────────
    s = [r["summary"] for r in results]
    row("Final Portfolio",   [f"${x['final_cash']:,.2f}"         for x in s])
    row("Total Return",      [f"{x['total_return_pct']:+.2f}%"   for x in s])
    row("Max Drawdown",      [f"{x['max_drawdown_pct']:.2f}%"    for x in s])
    divider()
    row("Trades Entered",    [f"{x['total_trades']:,}"           for x in s])
    row("Trades Closed",     [f"{x['closed_trades']:,}"          for x in s])
    row("Win Rate",          [f"{x['win_rate_pct']:.1f}%"        for x in s])
    row("Winners / Losers",  [f"{x['winning_trades']}/{x['losing_trades']}" for x in s])
    divider()
    row("Best Trade",        [_trade_str(x["best_trade"])        for x in s])
    row("Worst Trade",       [_trade_str(x["worst_trade"])       for x in s])

    # ── Exit reason breakdown ─────────────────────────────────────────────────
    divider()
    for reason in sorted(set().union(*(r["exit_reasons"] for r in results))):
        row(f"  Exit: {reason}", [r["exit_reasons"].get(reason, 0) for r in results])

    print("═" * W)

    # ── Monthly P&L side-by-side ──────────────────────────────────────────────
    all_months = sorted(set().union(*(r["monthly_pnl"] for r in results)))

    print("\n  MONTHLY P&L  (P&L $ / Return %)")
    divider(MW)
    row("Month", headers, MW)
    divider(MW)
    for m in all_months:
        vals = []
        for r in results:
            d     = r["monthly_pnl"].get(m, {"pnl": 0, "return_pct": 0})
            arrow = "▲" if d["pnl"] >= 0 else "▼"
            vals.append(f"${d['pnl']:>+8,.0f}  {d['return_pct']:>+5.1f}%{arrow}")
        row(m, vals, MW)

    # ── Per-symbol comparison ─────────────────────────────────────────────────
    # Rank by the first variant's P&L. The inner alphabetical sort makes ties
    # deterministic, since Python's sort is stable even with reverse=True.
    ref = results[0]["symbol_stats"] if results else {}
    all_syms = sorted(
        sorted(set().union(*(r["symbol_stats"] for r in results))),
        key=lambda sym: ref.get(sym, {}).get("total_pnl", 0),
        reverse=True,
    )
    if all_syms:
        print("\n  PER-SYMBOL P&L  (total $ across all trades)")
        divider(SW)
        row("Symbol", headers, SW)
        divider(SW)
        for sym in all_syms:
            vals = []
            for r in results:
                st = r["symbol_stats"].get(sym)
                if st:
                    vals.append(
                        f"${st['total_pnl']:>+7,.0f}  "
                        f"{st['wins']}W/{st['losses']}L"
                    )
                else:
                    vals.append("no trades")
            row(sym, vals, SW)

    print(f"\n  Full results saved → {RESULTS_FILE}")
    print("═" * W + "\n")


# ── Entry point ───────────────────────────────────────────────────────────────

def main():
    """Run the pipeline: download data, backtest each variant, save JSON, print report.

    Results for every entry in VARIANTS are written to RESULTS_FILE and then
    printed as a side-by-side comparison.
    """
    W = 84
    print("═" * W)
    print(f"  BACKTESTER  —  EMA 9/21 + RSI  |  {len(VARIANTS)}-Variant Comparison")
    print(f"  Fetching ~2 years of 1-hour bars for {len(WATCHLIST)} symbols...")
    print("═" * W)

    data = download_all(WATCHLIST)
    if not data:
        print("No data downloaded. Exiting.")
        return

    # Precompute the SPY trend lookup used by the spy_filter variant(s).
    spy_ma50 = None
    if "SPY" in data:
        spy_ma50 = build_spy_ma50_lookup(data["SPY"])
        above    = sum(1 for v in spy_ma50.values() if v)
        total    = len(spy_ma50)
        if total:
            print(
                f"\n  SPY {SPY_MA_PERIOD}-day MA: above trend on "
                f"{above}/{total} days ({above/total*100:.0f}%)"
            )
    if spy_ma50 is None and any(v["spy_filter"] for v in VARIANTS):
        # run_backtest treats a missing lookup as "always OK", so say so explicitly.
        print("\n  WARNING: SPY data unavailable — SPY trend filter will be inactive.")

    print(f"\n  Running {len(VARIANTS)} variants over "
          f"{sum(len(df) for df in data.values()):,} total bars...\n")

    all_results = []
    for v in VARIANTS:
        print(f"  V{v['id']}: {v['name']}...", end=" ", flush=True)
        trades, final_cash, ph = run_backtest(
            data,
            spy_filter=v["spy_filter"],
            stop_loss=v["stop_loss"],
            spy_ma50=spy_ma50,
        )
        metrics = calc_metrics(trades, final_cash, ph)
        metrics["variant"] = v
        metrics["portfolio_history"] = [
            {"timestamp": ts.isoformat(), "value": val} for ts, val in ph
        ]
        all_results.append(metrics)

        s = metrics["summary"]
        print(
            f"${s['final_cash']:,.2f}  "
            f"({s['total_return_pct']:+.2f}%)  "
            f"win={s['win_rate_pct']:.1f}%  "
            f"dd={s['max_drawdown_pct']:.1f}%"
        )

    with open(RESULTS_FILE, "w", encoding="utf-8") as f:
        json.dump({"variants": VARIANTS, "results": all_results}, f, indent=2, default=str)

    print_comparison(VARIANTS, all_results)


if __name__ == "__main__":
    main()
