#!/usr/bin/env python3
"""
Futures Paper Trading Bot
Strategy: 9/21 EMA crossover with RSI filter on 15-minute bars.
Instruments: ES=F (E-mini S&P 500), MES=F (Micro E-mini S&P 500)
All state tracked locally in JSON files.

Protections:
  - Trading-hours filter: only trade 09:30–16:00 and 18:00–23:00 ET
  - ATR volatility filter: skip entries when ATR is in the top 10% of recent values
  - 1% stop loss per position
  - Max one open futures position at a time
"""

import os
import json
import time
import logging
from datetime import datetime, timezone, time as dtime
from zoneinfo import ZoneInfo
from typing import Optional

import pandas as pd
import yfinance as yf
import requests
from dotenv import load_dotenv

# ── Configuration ─────────────────────────────────────────────────────────────
load_dotenv()  # must run before the os.getenv() calls below so values from a .env file are visible

# Discord webhook URL; an empty value makes notify_discord() return immediately.
DISCORD_WEBHOOK      = os.getenv("DISCORD_WEBHOOK_URL", "")
# JSON array of trade records; log_trade() appends one record per BUY/SELL.
TRADE_LOG_FILE       = "futures_trade_log.json"
# JSON object {"cash": ..., "positions": {ticker: {...}}} read by load_portfolio().
PORTFOLIO_FILE       = "futures_portfolio.json"
# Simulated cash for a fresh portfolio (used when PORTFOLIO_FILE is missing or unreadable).
STARTING_CASH        = 10_000.0
# Per-entry margin budget in USD: place_buy() buys as many whole contracts as fit in
# min(POSITION_SIZE_USD, cash). NOTE: the default 5000 is below the ES=F margin per
# contract (15_000 in FUTURES_SPECS), so ES=F entries size to 0 contracts and are
# skipped; only MES=F can be entered unless this (and STARTING_CASH) is raised.
POSITION_SIZE_USD    = float(os.getenv("FUTURES_POSITION_SIZE_USD", "5000"))

# Symbols scanned each loop, in this order (fetched via yfinance in fetch_bars()).
# Every entry must also have a FUTURES_SPECS entry.
WATCHLIST = ["ES=F", "MES=F"]

# Futures contract specs: multiplier converts price points to USD,
# margin is the simulated capital reserved per contract.
# place_buy() deducts margin from cash; place_sell() returns it plus P&L.
FUTURES_SPECS = {
    "ES=F":  {"multiplier": 50,  "margin_per_contract": 15_000},
    "MES=F": {"multiplier": 5,   "margin_per_contract":  1_500},
}

EMA_FAST         = 9    # fast EMA span (bars) -> "ema9" column
EMA_SLOW         = 21   # slow EMA span (bars) -> "ema21"; fetch_bars() requires EMA_SLOW + 5 bars
RSI_PERIOD       = 14   # RSI lookback -> "rsi14" column
RSI_BUY_MAX      = 70   # entries require RSI strictly below this
RSI_SELL_TRIGGER = 75   # an open position is closed when RSI rises above this
LOOP_SECONDS     = 900  # 15 minutes

# ── Risk parameters ───────────────────────────────────────────────────────────
ATR_PERIOD          = 14
# Threshold is the quantile of the atr14 column over the bars fetch_bars() returned (period="5d").
ATR_HIGH_PERCENTILE = 0.90   # skip entries when ATR > 90th-percentile of recent values
STOP_LOSS_PCT       = 0.01   # 1% below entry price
# NOTE: despite the name, evaluate() compares this with len(portfolio["positions"]),
# i.e. it caps open *positions*; a single position may hold several contracts.
MAX_OPEN_CONTRACTS  = 1      # at most one futures position at a time across all instruments

ET = ZoneInfo("America/New_York")

# Allowed trading windows (ET). Avoids thinly-traded 23:00–09:30 overnight hours.
# Each window is half-open: start inclusive, end exclusive (see in_trading_window()).
TRADE_WINDOWS = [
    (dtime(9, 30),  dtime(16, 0)),   # regular US equity session
    (dtime(18, 0),  dtime(23, 0)),   # evening futures session
]

# Root logger at INFO: the log.debug() diagnostics in evaluate() are hidden unless this is lowered.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s  %(levelname)-8s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__name__)


# ── Portfolio state ───────────────────────────────────────────────────────────

def load_portfolio() -> dict:
    """Load the simulated portfolio from PORTFOLIO_FILE, creating it if needed.

    Returns:
        The portfolio dict, shaped ``{"cash": ..., "positions": {...}}``.

    A missing file yields a fresh portfolio with STARTING_CASH, which is
    written to disk immediately. A file that cannot be read, is not valid
    JSON, or lacks the ``cash``/``positions`` keys is first moved aside to
    ``<PORTFOLIO_FILE>.corrupt-<UTC timestamp>``. Only then is a fresh
    portfolio written, so the previous cash balance and open positions can
    still be recovered by hand instead of being silently overwritten. If that
    move itself fails, the old best-effort behaviour (start fresh) is kept.
    """
    if os.path.exists(PORTFOLIO_FILE):
        try:
            with open(PORTFOLIO_FILE) as f:
                data = json.load(f)
        except (ValueError, OSError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            problem = str(exc)
        else:
            if isinstance(data, dict) and "cash" in data and "positions" in data:
                return data
            problem = "expected an object with 'cash' and 'positions' keys"

        stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
        backup = f"{PORTFOLIO_FILE}.corrupt-{stamp}"
        try:
            os.replace(PORTFOLIO_FILE, backup)
        except OSError as move_exc:
            log.warning(
                "Could not read %s (%s) or back it up (%s) — starting fresh.",
                PORTFOLIO_FILE, problem, move_exc,
            )
        else:
            log.warning(
                "Could not read %s (%s) — moved it to %s and starting fresh.",
                PORTFOLIO_FILE, problem, backup,
            )

    initial = {"cash": STARTING_CASH, "positions": {}}
    save_portfolio(initial)
    return initial


def save_portfolio(portfolio: dict, path: Optional[str] = None):
    """Persist *portfolio* as indented JSON, replacing the file atomically.

    The data is written to a sibling ``.tmp`` file, flushed to disk, and then
    moved over the target with ``os.replace``, which is atomic on POSIX. A
    crash or full disk mid-write therefore leaves the previous file intact
    rather than a truncated one that load_portfolio() would fail to parse.

    Args:
        portfolio: The portfolio dict to serialise.
        path: Destination file; defaults to PORTFOLIO_FILE.

    Raises:
        OSError / TypeError: propagated from the write or from ``json.dump``,
        after the temp file has been cleaned up.
    """
    target = PORTFOLIO_FILE if path is None else path
    tmp_path = f"{target}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(portfolio, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, target)
    finally:
        # The temp file only survives here if the write or the swap failed.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass


# ── Trade log ─────────────────────────────────────────────────────────────────

def log_trade(ticker: str, action: str, price: float, contracts: int,
              cash_after: float, reasoning: str, pnl: Optional[float] = None):
    """Append one trade record to TRADE_LOG_FILE and echo a summary to the log.

    TRADE_LOG_FILE holds a JSON array of records. If the existing file cannot
    be read, is not valid JSON, or is not an array, it is moved aside to
    ``<TRADE_LOG_FILE>.corrupt-<UTC timestamp>`` and a new log is started.
    Previously such a file was silently overwritten with a one-entry array,
    losing all earlier history. If that move also fails, the file is left
    untouched and this record goes only to the logger. The updated array is
    written to a temp file and swapped in with ``os.replace``, so an
    interrupted write cannot truncate the history.

    Args:
        ticker: Key into FUTURES_SPECS.
        action: Trade side label as passed by the callers ("BUY" / "SELL").
        price: Fill price in index points.
        contracts: Number of contracts traded.
        cash_after: Portfolio cash after the trade.
        reasoning: Human-readable description of the signal.
        pnl: Realised P&L in USD for exits; None for entries.
    """
    spec     = FUTURES_SPECS[ticker]
    notional = price * spec["multiplier"] * contracts
    entry = {
        "timestamp":        datetime.now(timezone.utc).isoformat(),
        "ticker":           ticker,
        "action":           action,
        "price":            round(price, 2),
        "contracts":        contracts,
        "multiplier":       spec["multiplier"],
        "notional_value":   round(notional, 2),
        "margin_used":      round(spec["margin_per_contract"] * contracts, 2),
        "cash_after":       round(cash_after, 2),
        "pnl":              round(pnl, 2) if pnl is not None else None,
        "reasoning":        reasoning,
    }

    records: list = []
    can_write = True
    if os.path.exists(TRADE_LOG_FILE):
        loaded = None
        try:
            with open(TRADE_LOG_FILE) as f:
                loaded = json.load(f)
        except (ValueError, OSError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            problem = str(exc)
        else:
            problem = "top-level JSON value is not an array"

        if isinstance(loaded, list):
            records = loaded
        else:
            stamp  = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
            backup = f"{TRADE_LOG_FILE}.corrupt-{stamp}"
            try:
                os.replace(TRADE_LOG_FILE, backup)
            except OSError as move_exc:
                # Never overwrite history we could not preserve.
                can_write = False
                log.error(
                    "Could not read %s (%s) or back it up (%s) — trade not written to file.",
                    TRADE_LOG_FILE, problem, move_exc,
                )
            else:
                log.warning(
                    "Could not read %s (%s) — moved it to %s and starting a new log.",
                    TRADE_LOG_FILE, problem, backup,
                )

    if can_write:
        records.append(entry)
        tmp_path = f"{TRADE_LOG_FILE}.tmp"
        try:
            with open(tmp_path, "w") as f:
                json.dump(records, f, indent=2)
            os.replace(tmp_path, TRADE_LOG_FILE)
        finally:
            # The temp file only survives here if the write or the swap failed.
            if os.path.exists(tmp_path):
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    log.info(
        "LOGGED  %-4s  %-6s  @ %.2f  contracts=%d  notional=$%.0f  cash=$%.2f",
        action, ticker, price, contracts, notional, cash_after,
    )


# ── Discord notifications ─────────────────────────────────────────────────────

def notify_discord(ticker: str, action: str, price: float, contracts: int,
                   cash_after: float, pnl: Optional[float], reasoning: str):
    """Post a trade alert as a Discord embed via the configured webhook.

    Returns immediately when DISCORD_WEBHOOK is empty. Delivery is
    best-effort: any exception from the HTTP request (including a non-2xx
    status via ``raise_for_status``) is logged as a warning and swallowed.

    Args:
        ticker: Key into FUTURES_SPECS.
        action: Trade side; "BUY" renders green, anything else red.
        price: Fill price in index points.
        contracts: Number of contracts traded.
        cash_after: Portfolio cash after the trade.
        pnl: Realised P&L in USD; the field is omitted when None.
        reasoning: Signal description shown as a full-width field.
    """
    if not DISCORD_WEBHOOK:
        return

    spec = FUTURES_SPECS[ticker]
    buying = action == "BUY"

    def field(name: str, value: str, inline: bool = True) -> dict:
        return {"name": name, "value": value, "inline": inline}

    fields = [
        field("Price", f"{price:,.2f}"),
        field("Contracts", str(contracts)),
        field("Notional Value", f"${price * spec['multiplier'] * contracts:,.0f}"),
        field("Margin Used", f"${spec['margin_per_contract'] * contracts:,.0f}"),
        field("Cash Remaining", f"${cash_after:,.2f}"),
    ]
    if pnl is not None:
        fields.append(field("P&L", f"${pnl:+,.2f}"))
    fields.append(field("Signal", reasoning, inline=False))

    embed = {
        "title":     f"{'🟢' if buying else '🔴'} {action}: {ticker}",
        "color":     0x2ECC71 if buying else 0xE74C3C,
        "fields":    fields,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }

    try:
        response = requests.post(DISCORD_WEBHOOK, json={"embeds": [embed]}, timeout=10)
        response.raise_for_status()
    except Exception as exc:
        log.warning("Discord notification failed: %s", exc)


# ── Technical indicators ──────────────────────────────────────────────────────

def calc_ema(series: pd.Series, period: int) -> pd.Series:
    """Return the exponential moving average of *series* over *period* bars.

    Uses the recursive (``adjust=False``) form with span *period*, i.e.
    smoothing factor ``2 / (period + 1)``, seeded from the first observation.
    """
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()


def calc_rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """Return the Relative Strength Index of *series* (0–100).

    Gains and losses are smoothed with an EWM of ``com = period - 1``
    (alpha = 1/period, Wilder's smoothing). The first element is NaN because
    it has no previous value. A window with no losses yields 100; a window
    with neither gains nor losses yields NaN.
    """
    change = series.diff()
    ups    = change.clip(lower=0)
    downs  = (-change).clip(lower=0)

    def wilder(values: pd.Series) -> pd.Series:
        return values.ewm(com=period - 1, adjust=False).mean()

    relative_strength = wilder(ups) / wilder(downs)
    return 100 - (100 / (1 + relative_strength))


def calc_atr(df: pd.DataFrame, period: int = ATR_PERIOD) -> pd.Series:
    """Wilder's Average True Range of an OHLC frame.

    True range per bar is the largest of high-low, |high - previous close|
    and |low - previous close|. On the first bar there is no previous close,
    so the NaN terms are skipped and the range is just high-low. The true
    range is then smoothed with an EWM of ``com = period - 1``
    (alpha = 1/period).

    Args:
        df: Frame with ``high``, ``low`` and ``close`` columns.
        period: Smoothing length in bars.
    """
    prior_close  = df["close"].shift(1)
    bar_range    = df["high"] - df["low"]
    high_vs_prev = (df["high"] - prior_close).abs()
    low_vs_prev  = (df["low"] - prior_close).abs()

    true_range = pd.concat([bar_range, high_vs_prev, low_vs_prev], axis=1).max(axis=1)
    return true_range.ewm(com=period - 1, adjust=False).mean()


# ── Trading-hours filter ──────────────────────────────────────────────────────

def in_trading_window(now: Optional[datetime] = None) -> bool:
    """Return True if *now* falls inside an allowed trading window.

    A time qualifies when it lies in one of TRADE_WINDOWS (start inclusive,
    end exclusive, Eastern Time) and is not in the weekly CME Globex break
    for equity-index futures. That break runs from Friday 17:00 ET until the
    Sunday 18:00 ET reopen. Checking TRADE_WINDOWS alone would admit
    Saturday, Sunday-daytime and Friday-evening scans, which only see stale
    bars from the last session. Exchange holidays are not modelled.

    Args:
        now: Moment to test; defaults to the current time. Aware datetimes
            are converted to ET. Naive ones are interpreted as local time by
            ``datetime.astimezone``, so pass an aware value to be unambiguous.
    """
    now_et = datetime.now(ET) if now is None else now.astimezone(ET)
    t = now_et.time()
    weekday = now_et.weekday()  # Monday=0 ... Sunday=6

    weekend_break = (
        weekday == 5                                # Saturday
        or (weekday == 4 and t >= dtime(17, 0))     # Friday after the weekly close
        or (weekday == 6 and t < dtime(18, 0))      # Sunday before the reopen
    )
    if weekend_break:
        return False
    return any(start <= t < end for start, end in TRADE_WINDOWS)


# ── Data fetching ─────────────────────────────────────────────────────────────

def fetch_bars(symbol: str) -> Optional[pd.DataFrame]:
    """Download 5 days of 15-minute bars for *symbol* and attach indicators.

    Returns:
        A DataFrame with lower-case ``high``/``low``/``close`` columns plus
        ``ema9``, ``ema21``, ``rsi14`` and ``atr14``. Returns None (after
        logging) when yfinance returns no rows, returns fewer than
        EMA_SLOW + 5 rows, or raises any exception.
    """
    try:
        needed = EMA_SLOW + 5
        history = yf.Ticker(symbol).history(period="5d", interval="15m", auto_adjust=True)
        if history.empty:
            log.warning("%s: no data returned by yfinance", symbol)
            return None

        # Keep only the price columns the indicators need (High -> high, ...).
        bars = history[["High", "Low", "Close"]].rename(columns=str.lower).copy()
        if len(bars) < needed:
            log.warning("%s: only %d bars (need %d+)", symbol, len(bars), needed)
            return None

        closes = bars["close"]
        return bars.assign(
            ema9=calc_ema(closes, EMA_FAST),
            ema21=calc_ema(closes, EMA_SLOW),
            rsi14=calc_rsi(closes, RSI_PERIOD),
            atr14=calc_atr(bars, ATR_PERIOD),
        )
    except Exception as exc:
        log.error("%s: failed to fetch bars — %s", symbol, exc)
        return None


# ── Simulated order execution ─────────────────────────────────────────────────

def place_buy(ticker: str, price: float, reasoning: str, portfolio: dict):
    """Open a simulated long position in *ticker* at *price*.

    The size is the number of whole contracts whose margin fits in
    ``min(POSITION_SIZE_USD, portfolio["cash"])``. That margin is deducted
    from cash, the position is recorded under ``portfolio["positions"]`` and
    the portfolio is saved. The trade is then written to the trade log and
    sent to Discord. If not even one contract fits, a warning is logged and
    nothing changes.
    """
    margin_per = FUTURES_SPECS[ticker]["margin_per_contract"]
    budget = min(POSITION_SIZE_USD, portfolio["cash"])
    contracts = int(budget / margin_per)

    if contracts < 1:
        log.warning(
            "%s: need $%.0f margin/contract but only $%.2f available in position budget — skipping",
            ticker, margin_per, budget,
        )
        return

    margin_total = contracts * margin_per
    portfolio["cash"] -= margin_total
    portfolio["positions"][ticker] = {
        "contracts":   contracts,
        "avg_price":   round(price, 2),
        "margin_held": round(margin_total, 2),
        "entry_time":  datetime.now(timezone.utc).isoformat(),
    }
    save_portfolio(portfolio)

    cash_left = portfolio["cash"]
    log.info(
        "BUY  %-6s  contracts=%d @ %.2f  margin=$%.0f  cash_left=$%.2f",
        ticker, contracts, price, margin_total, cash_left,
    )
    log_trade(ticker, "BUY", price, contracts, cash_left, reasoning)
    notify_discord(ticker, "BUY", price, contracts, cash_left, None, reasoning)


def place_sell(ticker: str, price: float, reasoning: str, portfolio: dict):
    """Close the simulated position in *ticker* at *price*, if one is open.

    Realised P&L is ``(price - avg_price) * multiplier * contracts``. The
    held margin plus that P&L is credited back to cash, the position is
    removed and the portfolio is saved. The trade is then logged and sent to
    Discord. Does nothing when no position exists for *ticker*.
    """
    position = portfolio["positions"].get(ticker)
    if not position:
        return

    multiplier = FUTURES_SPECS[ticker]["multiplier"]
    qty        = position["contracts"]
    entry_px   = position["avg_price"]
    margin     = position["margin_held"]

    pnl = (price - entry_px) * multiplier * qty
    portfolio["cash"] += margin + pnl
    del portfolio["positions"][ticker]
    save_portfolio(portfolio)

    cash_now = portfolio["cash"]
    log.info(
        "SELL %-6s  contracts=%d @ %.2f  pnl=$%+.2f  cash=$%.2f",
        ticker, qty, price, pnl, cash_now,
    )
    log_trade(ticker, "SELL", price, qty, cash_now, reasoning, pnl)
    notify_discord(ticker, "SELL", price, qty, cash_now, pnl, reasoning)


# ── Strategy evaluation ───────────────────────────────────────────────────────

def evaluate(symbol: str, portfolio: dict):
    """Run one strategy step for *symbol*: manage an open position or look for an entry.

    Outside the allowed trading windows (see in_trading_window) the bot opens
    no positions and ignores EMA/RSI exit signals. An open position is still
    checked against its stop loss there: the stop is the highest-priority
    protection and must not be suspended (e.g. overnight) while the position
    is held. When flat and outside the window, no data is fetched.

    Exits, in priority order:
      1. price <= avg_price * (1 - STOP_LOSS_PCT)   (checked at all times)
      2. RSI above RSI_SELL_TRIGGER                 (trading windows only)
      3. 9 EMA crossing below 21 EMA                (trading windows only)

    Entry (flat, inside a window): 9 EMA crosses above 21 EMA with RSI below
    RSI_BUY_MAX. The entry is skipped when MAX_OPEN_CONTRACTS positions are
    already open or when the latest ATR exceeds its ATR_HIGH_PERCENTILE
    quantile over the fetched bars.
    """
    window_open = in_trading_window()
    in_position = symbol in portfolio["positions"]

    # Flat and outside the window: no entry is allowed and there is nothing to
    # protect, so skip the API call entirely.
    if not window_open and not in_position:
        log.debug("%-6s  outside trading window (ET %s) — skipping",
                  symbol, datetime.now(ET).strftime("%H:%M"))
        return

    df = fetch_bars(symbol)
    if df is None or len(df) < 3:
        return

    # NOTE(review): iloc[-1] is presumably the still-forming 15-min bar, so these
    # readings (and crossovers) can change before the bar closes — confirm intent.
    price  = df["close"].iloc[-1]
    rsi_v  = df["rsi14"].iloc[-1]
    atr_v  = float(df["atr14"].iloc[-1])

    e9_now,  e9_prev  = df["ema9"].iloc[-1],  df["ema9"].iloc[-2]
    e21_now, e21_prev = df["ema21"].iloc[-1], df["ema21"].iloc[-2]

    # A cross = the fast EMA moved from at/below (at/above) the slow EMA to strictly above (below).
    bullish_cross = (e9_prev <= e21_prev) and (e9_now > e21_now)
    bearish_cross = (e9_prev >= e21_prev) and (e9_now < e21_now)

    if in_position:
        pos        = portfolio["positions"][symbol]
        stop_price = round(pos["avg_price"] * (1 - STOP_LOSS_PCT), 2)

        # ── Stop loss — highest-priority exit, enforced even outside the window ──
        if price <= stop_price:
            place_sell(
                symbol, price,
                f"Stop loss: {price:.2f} ≤ stop {stop_price:.2f} "
                f"({STOP_LOSS_PCT*100:.0f}% below entry {pos['avg_price']:.2f})",
                portfolio,
            )
            return

        # Signal-based exits only run inside the trading windows.
        if not window_open:
            log.debug(
                "%-6s  outside trading window (ET %s) — holding, stop %.2f not hit",
                symbol, datetime.now(ET).strftime("%H:%M"), stop_price,
            )
            return

        # ── Regular exits ─────────────────────────────────────────────────────
        reason = None
        if rsi_v > RSI_SELL_TRIGGER:
            reason = f"RSI {rsi_v:.1f} exceeded sell threshold {RSI_SELL_TRIGGER}"
        elif bearish_cross:
            reason = (
                f"9 EMA ({e9_now:.2f}) crossed below 21 EMA ({e21_now:.2f}); "
                f"RSI {rsi_v:.1f}"
            )

        if reason:
            place_sell(symbol, price, reason, portfolio)
        else:
            spec       = FUTURES_SPECS[symbol]
            unrealized = (price - pos["avg_price"]) * spec["multiplier"] * pos["contracts"]
            log.debug(
                "%-6s  holding  contracts=%d  avg=%.2f  now=%.2f  stop=%.2f  "
                "unrealized=$%+.2f  atr=%.2f  ema9=%.2f ema21=%.2f rsi=%.1f",
                symbol, pos["contracts"], pos["avg_price"], price, stop_price,
                unrealized, atr_v, e9_now, e21_now, rsi_v,
            )
        return

    # ── No crossover signal → nothing to do ──────────────────────────────────
    if not (bullish_cross and rsi_v < RSI_BUY_MAX):
        log.debug(
            "%-6s  no signal  ema9=%.2f ema21=%.2f rsi=%.1f bullish=%s",
            symbol, e9_now, e21_now, rsi_v, bullish_cross,
        )
        return

    # ── Entry guards ──────────────────────────────────────────────────────────

    # 1. Max open futures positions across all instruments
    if len(portfolio["positions"]) >= MAX_OPEN_CONTRACTS:
        log.info("%-6s  SKIP  already holding %d futures position(s) (max %d)",
                 symbol, len(portfolio["positions"]), MAX_OPEN_CONTRACTS)
        return

    # 2. ATR volatility filter (quantile over the fetched bars)
    atr_threshold = float(df["atr14"].dropna().quantile(ATR_HIGH_PERCENTILE))
    if atr_v > atr_threshold:
        log.info(
            "%-6s  SKIP  ATR %.2f > %.0fth-pct threshold %.2f — elevated volatility",
            symbol, atr_v, ATR_HIGH_PERCENTILE * 100, atr_threshold,
        )
        return

    place_buy(
        symbol, price,
        f"9 EMA ({e9_now:.2f}) crossed above 21 EMA ({e21_now:.2f}); "
        f"RSI {rsi_v:.1f} < {RSI_BUY_MAX}; ATR {atr_v:.2f} within normal range",
        portfolio,
    )


# ── Main loop ─────────────────────────────────────────────────────────────────

def portfolio_summary(portfolio: dict) -> str:
    """Return a one-line description of the open positions.

    Each position renders as ``<ticker>×<contracts>@<avg_price:.2f>``, in
    dict order, joined by ", ". An empty book yields "no open positions".
    """
    held = portfolio["positions"]
    if held:
        return ", ".join(
            f"{ticker}×{info['contracts']}@{info['avg_price']:.2f}"
            for ticker, info in held.items()
        )
    return "no open positions"


def main():
    """Run the bot: load state, log the configuration, then scan forever.

    Each cycle evaluates every WATCHLIST symbol in turn. An exception from
    one symbol is logged with a traceback and does not stop the others. The
    loop sleeps 2 s after each symbol and LOOP_SECONDS after each cycle.
    Ctrl-C (KeyboardInterrupt) ends the loop and logs the final portfolio.
    """
    portfolio = load_portfolio()
    instruments = ", ".join(WATCHLIST)

    def report(template: str) -> None:
        # template takes (cash, positions summary)
        log.info(template, portfolio["cash"], portfolio_summary(portfolio))

    log.info("═══ Futures Paper Trading Bot started ═══")
    report("Cash: $%.2f  |  Positions: %s")
    log.info("Instruments: %s", instruments)
    log.info(
        "Strategy: EMA %d/%d | RSI buy<%.0f sell>%.0f | position_size=$%.0f",
        EMA_FAST, EMA_SLOW, RSI_BUY_MAX, RSI_SELL_TRIGGER, POSITION_SIZE_USD,
    )
    for sym, spec in FUTURES_SPECS.items():
        log.info(
            "  %s  multiplier=%d  margin_per_contract=$%.0f",
            sym, spec["multiplier"], spec["margin_per_contract"],
        )
    windows_desc = " & ".join(
        f"{s.strftime('%H:%M')}–{e.strftime('%H:%M')} ET" for s, e in TRADE_WINDOWS
    )
    log.info(
        "Protections: stop=%.0f%% | max_contracts=%d | atr_filter=top_%.0fpct | windows=%s",
        STOP_LOSS_PCT * 100, MAX_OPEN_CONTRACTS, ATR_HIGH_PERCENTILE * 100, windows_desc,
    )

    try:
        while True:
            log.info("═══ Scanning futures (%s) ═══", instruments)
            report("Cash: $%.2f  |  Positions: %s")

            for symbol in WATCHLIST:
                try:
                    evaluate(symbol, portfolio)
                except Exception:
                    log.exception("Unhandled error evaluating %s", symbol)
                time.sleep(2.0)  # polite rate-limiting for yfinance

            log.info("Scan complete. Next scan in %d min.", LOOP_SECONDS // 60)
            time.sleep(LOOP_SECONDS)

    except KeyboardInterrupt:
        log.info("Bot stopped by user.")
        report("Final — Cash: $%.2f  |  Positions: %s")


# Start the trading loop only when run as a script; importing the module just defines things.
if __name__ == "__main__":
    main()
