#!/usr/bin/env python3
"""
================================================================================
W20D2: MY CARTPOLE SOLUTION
================================================================================

Student: James Wilson
Date: 2026-02-10
Algorithm: Double Dueling Deep Q-Network (DQN) - OPTIMIZED

================================================================================
DESIGN DECISIONS
================================================================================

1. Algorithm Choice: Double Dueling DQN
   Why:
   - Double DQN reduces Q-value overestimation
   - Dueling architecture improves state value estimation
   - Combined approach is stable and data-efficient for CartPole

2. Network Architecture:
   - Input: 4 state values
   - Hidden layers: 2 x 64 (ReLU) - smaller but sufficient for CartPole
   - Separate Value and Advantage streams

3. Key Enhancements:
   - Experience Replay (20k buffer)
   - Soft Target Network Updates (Polyak averaging, tau=0.005)
   - Huber loss for stability
   - Slower epsilon decay for thorough exploration

================================================================================
OPTIMIZATIONS APPLIED
================================================================================
- Learning rate: 5e-4 (more stable than 1e-3)
- Batch size: 128 (smoother gradients)
- Soft target updates after every training step (tau=0.005)
- Epsilon decay: 0.995 (slower for better exploration)
- Terminal penalty: -1.0 (less aggressive)
- Gradient clipping: 1.0 (prevents instability)
- Train every 4 steps (reduces sample correlation)
- Network: 64 hidden units (faster, sufficient for CartPole)

================================================================================
"""

import os, sys, subprocess

# =============================================================================
# PYTHON VERSION + VENV BOOTSTRAP
# =============================================================================

def ensure_venv():
    """
    Re-run this script inside a project-local virtual environment.

    If the current interpreter is already inside a virtual environment this
    returns immediately. Otherwise it creates ``.venv`` next to this file
    (first run only), installs torch, numpy, gymnasium[classic-control] and
    matplotlib into it, re-launches the script with the venv's interpreter
    and exits with the child's return code.

    Homebrew Python 3.12, 3.11 and 3.10 are tried in that order (intended to
    avoid PyTorch incompatibility with Python 3.13); the interpreter running
    this script is the last-resort fallback.

    Raises:
        SystemExit: always, when a re-launch happened (carrying the child's
            exit code), or with code 1 when no base interpreter was found.
        subprocess.CalledProcessError: if creating the venv or installing a
            dependency fails. The partially built venv is removed first.
    """
    if sys.prefix != sys.base_prefix:
        return  # already inside venv

    root = os.path.dirname(os.path.abspath(__file__))
    venv_dir = os.path.join(root, ".venv")

    candidates = [
        "/opt/homebrew/opt/python@3.12/bin/python3",
        "/opt/homebrew/opt/python@3.11/bin/python3",
        "/opt/homebrew/opt/python@3.10/bin/python3",
        sys.executable,
    ]

    # sys.executable may be empty or None, so test truthiness before the path check.
    base_python = next((p for p in candidates if p and os.path.exists(p)), None)

    if base_python is None:
        print("\nERROR: No compatible Python (3.10-3.12) found.")
        print("Install one with:")
        print("  brew install python@3.12\n")
        sys.exit(1)

    if sys.platform == "win32":
        venv_python = os.path.join(venv_dir, "Scripts", "python.exe")
    else:
        venv_python = os.path.join(venv_dir, "bin", "python")

    if not os.path.exists(venv_python):
        import shutil

        try:
            print(f"Creating virtual environment with {base_python}...")
            subprocess.check_call([base_python, "-m", "venv", venv_dir])

            print("Installing dependencies...")
            subprocess.check_call([venv_python, "-m", "pip", "install", "--upgrade", "pip"])
            subprocess.check_call([
                venv_python, "-m", "pip", "install",
                "torch",
                "numpy",
                "gymnasium[classic-control]",
                "matplotlib",
            ])
        except BaseException:
            # The existence of venv_python is the only "already set up" marker,
            # so a half-built venv (failed or interrupted install) must not
            # survive: later runs would skip installation and fail on import.
            shutil.rmtree(venv_dir, ignore_errors=True)
            raise

    print("Re-launching inside virtual environment...\n")
    # Propagate the child's exit status instead of raising CalledProcessError,
    # which would print a second traceback and always exit with code 1.
    sys.exit(subprocess.call([venv_python] + sys.argv))


# Bootstrap the virtual environment before the third-party imports below run;
# those packages may not be installed for the interpreter that launched this
# script. Skipped when the file is imported as a module.
if __name__ == "__main__":
    ensure_venv()

# =============================================================================
# IMPORTS
# =============================================================================

try:
    import gymnasium as gym
except ImportError:
    import gym

import json
import random
from collections import deque, namedtuple

import numpy as np
import torch
import torch.nn as nn

# Use the CUDA device when PyTorch reports one is available; otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =============================================================================
# DUELING NETWORK (Optimized: 64 hidden units)
# =============================================================================

class DuelingQNetwork(nn.Module):
    """
    Dueling Q-network: a shared feature trunk feeding separate value and
    advantage heads that are recombined into per-action Q-values.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Number of discrete actions (width of the output).
        hidden: Width of the two trunk layers; each head uses ``hidden // 2``.
    """

    def __init__(self, state_dim=4, action_dim=2, hidden=64):
        super().__init__()
        # Attribute names and layer order define the state_dict keys that
        # saved checkpoints rely on; keep them stable.
        self.feature = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.value = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, 1),
        )
        self.advantage = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Linear(hidden // 2, action_dim),
        )

    def forward(self, x):
        """
        Compute Q-values as ``V(s) + (A(s, a) - mean_a A(s, a))``.

        Args:
            x: Float tensor whose last dimension is ``state_dim``; either a
                batch ``(..., state_dim)`` or a single state ``(state_dim,)``.

        Returns:
            Tensor of Q-values with the last dimension replaced by ``action_dim``.
        """
        features = self.feature(x)
        value = self.value(features)
        advantage = self.advantage(features)
        # Average over the last (action) axis rather than a hard-coded dim=1,
        # so unbatched 1-D states work as well as 2-D batches.
        return value + (advantage - advantage.mean(dim=-1, keepdim=True))

# =============================================================================
# REPLAY BUFFER
# =============================================================================

# One environment step as stored in the replay buffer.
Transition = namedtuple("Transition", "state action reward next_state done")


class ReplayBuffer:
    """
    Fixed-capacity FIFO store of Transition records.

    Once ``capacity`` records are held, each new push discards the oldest one.
    """

    def __init__(self, capacity=20_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        """Append one Transition built from the positional fields in ``args``."""
        record = Transition(*args)
        self.buffer.append(record)

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns a single Transition whose fields are tuples of length
        ``batch_size`` (column-wise view of the sampled records).
        """
        picked = random.sample(self.buffer, batch_size)
        columns = tuple(zip(*picked))
        return Transition(*columns)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# =============================================================================
# REQUIRED AGENT CLASS
# =============================================================================

class MyAgent:
    """
    Greedy CartPole agent backed by a DuelingQNetwork in eval mode.

    On construction it looks for ``results/dueling_double_dqn.pt`` next to
    this file and loads those weights when the file exists; otherwise the
    network keeps its freshly initialised weights.
    """
    def __init__(self):
        network = DuelingQNetwork().to(DEVICE)
        network.eval()
        self.net = network

        here = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(here, "results", "dueling_double_dqn.pt")
        if not os.path.exists(model_path):
            return

        weights = torch.load(model_path, map_location=DEVICE, weights_only=True)
        self.net.load_state_dict(weights)

    def select_action(self, state):
        """Return the index (as int) of the highest-Q action for one observation."""
        with torch.no_grad():
            obs = torch.tensor(state, dtype=torch.float32, device=DEVICE)
            q_values = self.net(obs.unsqueeze(0))
            best = q_values.argmax(dim=1)
        return int(best.item())

# =============================================================================
# TRAINING (OPTIMIZED HYPERPARAMETERS)
# =============================================================================

def _select_action(policy_net, state, eps, action_space):
    """Return an epsilon-greedy action: random with probability eps, else argmax-Q."""
    if random.random() < eps:
        return action_space.sample()
    s = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
    with torch.no_grad():
        q = policy_net(s)
    return int(torch.argmax(q, dim=1).item())


def _optimize_step(policy_net, target_net, optimizer, batch, gamma, max_grad_norm):
    """Take one Double-DQN gradient step on policy_net from a sampled batch."""
    states = torch.tensor(np.array(batch.state), dtype=torch.float32, device=DEVICE)
    actions = torch.tensor(batch.action, dtype=torch.int64, device=DEVICE).unsqueeze(1)
    rewards = torch.tensor(batch.reward, dtype=torch.float32, device=DEVICE).unsqueeze(1)
    next_states = torch.tensor(np.array(batch.next_state), dtype=torch.float32, device=DEVICE)
    terminals = torch.tensor(batch.done, dtype=torch.float32, device=DEVICE).unsqueeze(1)

    # Current Q values for the actions that were actually taken
    q_sa = policy_net(states).gather(1, actions)

    # Double DQN: policy net selects the next action, target net evaluates it
    with torch.no_grad():
        next_actions = policy_net(next_states).argmax(dim=1, keepdim=True)
        q_next = target_net(next_states).gather(1, next_actions)
        target = rewards + (1 - terminals) * gamma * q_next

    # Huber loss (smooth L1)
    loss = torch.nn.functional.smooth_l1_loss(q_sa, target)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_grad_norm)
    optimizer.step()


def _soft_update(target_net, policy_net, tau):
    """Blend policy_net's parameters into target_net in place (Polyak averaging)."""
    with torch.no_grad():
        for target_param, policy_param in zip(target_net.parameters(), policy_net.parameters()):
            target_param.mul_(1 - tau)
            target_param.add_(tau * policy_param)


def train(episodes=500):
    """
    Train a Double Dueling DQN on CartPole-v1 and save the resulting weights.

    Args:
        episodes: Maximum number of training episodes. Training stops earlier
            once the mean score over the last 100 episodes reaches 475.

    Returns:
        (scores, agent): the list of per-episode returns (unshaped
        environment reward) and a MyAgent whose network holds the trained
        policy weights.

    Side effects:
        Creates ``results/`` next to this file, writes
        ``results/dueling_double_dqn.pt`` and prints progress to stdout.
    """
    env = gym.make("CartPole-v1")
    try:
        results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
        os.makedirs(results_dir, exist_ok=True)

        policy_net = DuelingQNetwork().to(DEVICE)
        target_net = DuelingQNetwork().to(DEVICE)
        target_net.load_state_dict(policy_net.state_dict())
        target_net.eval()

        # OPTIMIZED HYPERPARAMETERS
        optimizer = torch.optim.Adam(policy_net.parameters(), lr=5e-4)  # Lower LR for stability
        buffer = ReplayBuffer(capacity=20_000)

        gamma = 0.99
        batch_size = 128          # Larger batch for stable gradients
        start_learning = 500      # Warm-up: no gradient steps before this many env steps
        tau = 0.005               # Soft update rate
        train_freq = 4            # Train every N steps (reduce correlation)

        eps = 1.0
        eps_min = 0.01            # Lower minimum for better exploitation
        eps_decay = 0.995         # Slower decay for thorough exploration

        max_grad_norm = 1.0       # Tighter gradient clipping

        steps = 0
        scores = []

        print("=" * 60)
        print("TRAINING STARTED - Double Dueling DQN (Optimized)")
        print("=" * 60)
        print(f"LR: 5e-4 | Batch: {batch_size} | Tau: {tau} | Eps decay: {eps_decay}")
        print("=" * 60)

        for ep in range(1, episodes + 1):
            state, _ = env.reset(seed=ep)
            done = False
            total = 0.0

            while not done:
                steps += 1

                action = _select_action(policy_net, state, eps, env.action_space)

                next_state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated

                # Lighter terminal penalty
                modified_reward = reward if not terminated else -1.0
                # Store `terminated`, not `done`: a time-limit truncation is not a
                # real terminal state, so the TD target must keep bootstrapping
                # from next_state there. Only a true failure zeroes the bootstrap.
                buffer.push(state, action, modified_reward, next_state, terminated)

                state = next_state
                total += reward  # Track actual reward

                # Train every N steps (reduces sample correlation)
                if len(buffer) >= batch_size and steps >= start_learning and steps % train_freq == 0:
                    batch = buffer.sample(batch_size)
                    _optimize_step(policy_net, target_net, optimizer, batch, gamma, max_grad_norm)
                    _soft_update(target_net, policy_net, tau)

            scores.append(total)
            eps = max(eps_min, eps * eps_decay)

            if ep % 10 == 0:
                avg20 = np.mean(scores[-20:])
                avg50 = np.mean(scores[-50:]) if len(scores) >= 50 else avg20
                print(f"Episode {ep:4d} | Score: {total:5.0f} | Avg20: {avg20:6.1f} | Avg50: {avg50:6.1f} | Eps: {eps:.3f}")

            # Check if solved (avg >= 475 over 100 episodes)
            if len(scores) >= 100 and np.mean(scores[-100:]) >= 475:
                print(f"\n{'='*60}")
                print(f"SOLVED at episode {ep}! Avg(100): {np.mean(scores[-100:]):.1f}")
                print(f"{'='*60}")
                break

        # Save model
        model_path = os.path.join(results_dir, "dueling_double_dqn.pt")
        torch.save(policy_net.state_dict(), model_path)
        print(f"\nModel saved to {model_path}")
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()

    agent = MyAgent()
    agent.net.load_state_dict(policy_net.state_dict())
    return scores, agent

# =============================================================================
# DEMO
# =============================================================================

def demo(agent, episodes=5):
    """
    Play CartPole-v1 with on-screen rendering and print each episode's score.

    Args:
        agent: Object exposing ``select_action(state)`` that returns an action.
        episodes: Number of episodes to play.
    """
    env = gym.make("CartPole-v1", render_mode="human")
    try:
        for ep in range(episodes):
            state, _ = env.reset()
            done = False
            total = 0
            while not done:
                action = agent.select_action(state)
                state, reward, terminated, truncated, _ = env.step(action)
                done = terminated or truncated
                total += reward
            print(f"Demo Episode {ep+1}: Score = {total}")
    finally:
        # Always close the environment, even when an episode raises or the
        # user interrupts the demo.
        env.close()

# =============================================================================
# MAIN
# =============================================================================

if __name__ == "__main__":
    scores, agent = train()

    # Headline statistics: mean over the last 100 episodes once that many
    # exist, otherwise the mean over everything recorded.
    if len(scores) >= 100:
        final_avg = np.mean(scores[-100:])
    else:
        final_avg = np.mean(scores)
    best_score = max(scores)

    # Hyperparameters recorded alongside the results.
    hyperparameters = {
        "learning_rate": 5e-4,
        "batch_size": 128,
        "gamma": 0.99,
        "tau": 0.005,
        "epsilon_decay": 0.995,
        "buffer_size": 20000,
        "train_freq": 4,
        "hidden_size": 64
    }
    summary = {
        "algorithm": "Double Dueling DQN (Optimized)",
        "final_avg": float(final_avg),
        "best_score": float(best_score),
        "episodes": len(scores),
        "hyperparameters": hyperparameters,
    }

    # Write the summary to results/my_results.json next to this file.
    results_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results")
    results_path = os.path.join(results_dir, "my_results.json")
    with open(results_path, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nFinal Avg(last 100): {final_avg:.1f}")
    print(f"Best Score: {best_score:.0f}")

    # Uncomment to watch the trained agent:
    # demo(agent)
