#!/usr/bin/env python3
"""
W19D1: Vanilla CartPole - Understanding RL from Scratch
========================================================

Welcome! This script will teach you reinforcement learning step by step.
No black boxes - you'll understand every line of code.

We'll progress through 3 agents:
  1. Random Agent     - Does nothing smart (baseline)
  2. Hand-coded Agent - Uses simple rules (surprisingly good!)
  3. Q-Learning Agent - Actually LEARNS from experience

Run: python vanilla_starter.py
     (add --keep-venv to keep the .venv_vanilla environment instead of deleting it afterwards)

Learning Goals:
  - Understand the CartPole environment
  - See why random actions fail
  - Discover that simple rules can work
  - Watch an agent learn in real-time
  - Build intuition for hyperparameters
"""

# =============================================================================
# SECTION 0: SETUP (runs automatically)
# =============================================================================

import os
import sys
import subprocess
import shutil

# Directory of the private virtual environment; it sits next to this script.
VENV_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".venv_vanilla")
# Packages that setup_venv() installs into that environment with pip.
REQUIREMENTS = ["gymnasium", "numpy"]

def is_in_venv():
    """Report whether this process should be treated as already isolated.

    Returns True when any of the following holds:
      - ``sys.real_prefix`` exists,
      - ``sys.base_prefix`` exists and differs from ``sys.prefix``,
      - the ``VANILLA_VENV_ACTIVE`` environment variable equals "1".
    Otherwise returns False.
    """
    if hasattr(sys, "real_prefix"):
        return True
    if hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix:
        return True
    # Explicit marker, checked last so the prefix tests keep priority.
    return os.environ.get("VANILLA_VENV_ACTIVE") == "1"

def setup_venv():
    """Create the private virtual environment if needed and install packages.

    The environment lives at VENV_DIR. It is created with the current
    interpreter's ``venv`` module when the directory does not exist yet;
    pip is then upgraded and every entry of REQUIREMENTS is installed.

    Returns:
        Path to the Python interpreter inside the virtual environment.

    Raises:
        subprocess.CalledProcessError: if creating the environment or any
            pip command exits with a non-zero status.
    """
    print("=" * 60)
    print("Setting up virtual environment...")
    print("=" * 60)

    if not os.path.exists(VENV_DIR):
        print(f"Creating venv at {VENV_DIR}...")
        subprocess.run([sys.executable, "-m", "venv", VENV_DIR], check=True)

    # Executables live in "Scripts" on Windows and in "bin" elsewhere.
    bin_dir = "Scripts" if sys.platform == "win32" else "bin"
    python_path = os.path.join(VENV_DIR, bin_dir, "python")

    # Run pip as "python -m pip" instead of through the pip launcher: on
    # Windows the launcher cannot overwrite itself while it is running, so
    # "pip install --upgrade pip" fails there.
    pip_install = [python_path, "-m", "pip", "install", "--quiet"]

    print(f"Installing {', '.join(REQUIREMENTS)}...")
    subprocess.run(pip_install + ["--upgrade", "pip"], check=True)
    subprocess.run(pip_install + REQUIREMENTS, check=True)
    print("Ready!\n")

    return python_path

def cleanup_venv():
    """Delete the virtual environment directory (VENV_DIR) when it exists."""
    if not os.path.exists(VENV_DIR):
        return  # nothing was created, nothing to remove
    print("\nCleaning up virtual environment...")
    shutil.rmtree(VENV_DIR)

def run_in_venv():
    """Re-run this script inside the private virtual environment, then exit.

    Prepares the environment, launches the script with the environment's
    interpreter (forwarding the command-line arguments and setting
    VANILLA_VENV_ACTIVE=1 so the child does not bootstrap again), and exits
    with the child's return code.

    Unless ``--keep-venv`` is on the command line, the environment is
    removed afterwards. The removal is in a ``finally`` so that it also
    happens when setup fails or the run is interrupted (e.g. Ctrl+C);
    otherwise a stale or half-built environment would be left behind and
    reused by the next run.

    Raises:
        SystemExit: always, on normal completion, carrying the child's
            return code.
        subprocess.CalledProcessError: propagated from setup_venv().
    """
    keep_venv = "--keep-venv" in sys.argv
    try:
        python_path = setup_venv()
        env = os.environ.copy()
        env["VANILLA_VENV_ACTIVE"] = "1"
        args = [python_path, __file__] + sys.argv[1:]
        result = subprocess.run(args, env=env)
    finally:
        if not keep_venv:
            cleanup_venv()
    sys.exit(result.returncode)

# Bootstrap (runs whenever this module is loaded): if we are not already
# inside a virtual environment, relaunch the script in the private one.
# run_in_venv() ends with sys.exit(), so the third-party imports further down
# are only reached by a process for which is_in_venv() is true.
if not is_in_venv():
    run_in_venv()

# =============================================================================
# SECTION 1: IMPORTS
# =============================================================================

import time
import gymnasium as gym
import numpy as np
from collections import defaultdict

# =============================================================================
# SECTION 2: CONFIGURATION
# =============================================================================

STUDENT_NAME = "Your Name"  # <-- Change this! Printed in the banner by main().

# How many episodes to run for each agent
RANDOM_EPISODES = 5
HANDCODED_EPISODES = 3
QLEARNING_EPISODES = 5000  # More episodes to see real learning

# Q-Learning hyperparameters (try changing these!)
LEARNING_RATE = 0.4      # Step size of every Q-value update (0.01 - 1.0)
DISCOUNT_FACTOR = 0.999   # Weight on the best next-state Q-value (0.9 - 0.999)
EPSILON_START = 1.0      # Starting exploration rate (probability of a random action)
EPSILON_END = 0.01       # Minimum exploration rate (the decay never goes below this)
EPSILON_DECAY = 0.995     # Epsilon is multiplied by this after each episode; smaller value = exploit sooner

# =============================================================================
# SECTION 3: HELPER FUNCTIONS
# =============================================================================

def print_header(title):
    """Print a blank line, then *title* framed by two 60-character '=' rules."""
    rule = "=" * 60
    print(f"\n{rule}\n  {title}\n{rule}")

def print_subheader(title):
    """Print a blank line followed by *title* wrapped in '---' markers."""
    print("\n--- {} ---".format(title))

def print_state(state, step):
    """Print one line with the step number, an ASCII pole, its angle and the cart position.

    Args:
        state: Four values unpacked as (cart position, cart velocity,
            pole angle, pole angular velocity). The pole angle is passed
            through ``np.degrees`` before display, i.e. it is treated as
            radians. The two velocities are not shown.
        step: Integer step index, printed right-aligned in three columns.
    """
    x_pos, _x_vel, theta, _theta_vel = state
    degrees = np.degrees(theta)

    # The pole is one glyph inside a 9-character strip; slot 4 is the centre.
    # Every 2 degrees of tilt moves the glyph one slot, at most 4 slots.
    if abs(degrees) < 1:
        glyph, slot = "|", 4
    elif degrees > 0:
        glyph, slot = "/", 4 + min(int(degrees / 2), 4)
    else:
        glyph, slot = "\\", 4 - min(int(-degrees / 2), 4)

    strip = [" "] * 9
    strip[slot] = glyph
    picture = "".join(strip)

    print(f"  Step {step:3d}: [{picture}]  angle={degrees:+5.1f}°  cart_pos={x_pos:+5.2f}")

def explain_cartpole():
    """Show the CartPole primer, then block until the user presses ENTER."""
    print_header("UNDERSTANDING CARTPOLE")

    # The primer is kept as one literal so the ASCII drawing stays aligned.
    primer = """
    THE GOAL: Balance a pole on a moving cart!

                    |
                    |  <-- Pole (don't let it fall!)
                    |
               [ cart ]  <-- You control this
            ================= <-- Track

    OBSERVATIONS (what the agent sees):
      1. Cart Position  : Where is the cart? (-2.4 to +2.4)
      2. Cart Velocity  : How fast is it moving?
      3. Pole Angle     : How tilted is the pole? (-12° to +12°)
      4. Pole Velocity  : How fast is it falling?

    ACTIONS (what the agent can do):
      0 = Push cart LEFT  <--
      1 = Push cart RIGHT -->

    REWARDS:
      +1 for every step the pole stays up!

    GAME OVER when:
      - Pole tilts more than 12 degrees
      - Cart moves off the track (|position| > 2.4)
      - You survive 500 steps (you win!)

    Press ENTER to continue..."""
    print(primer)
    input()


# =============================================================================
# SECTION 4: AGENT 1 - RANDOM AGENT
# =============================================================================

def run_random_agent():
    """
    RANDOM AGENT
    ============
    Policy: choose action 0 or 1 uniformly at random on every step.

    Runs RANDOM_EPISODES episodes of CartPole-v1 (at most 500 steps each),
    printing the first few steps of every episode, and serves as the
    baseline the other agents are compared against.

    Returns:
        The mean total reward over the episodes.
    """
    print_header("AGENT 1: RANDOM AGENT")

    print("""
    Strategy: Pick LEFT or RIGHT randomly each step.

    Why do this? To establish a BASELINE.
    If your fancy algorithm can't beat random, something is wrong!

    Let's see how bad random is...
    """)
    time.sleep(1)

    env = gym.make("CartPole-v1")
    episode_returns = []

    for ep_index in range(RANDOM_EPISODES):
        print_subheader(f"Episode {ep_index + 1}/{RANDOM_EPISODES}")

        obs, _ = env.reset()
        ep_return = 0

        for t in range(500):
            # No strategy at all: a coin flip between 0 (left) and 1 (right).
            action = np.random.randint(0, 2)
            label = "RIGHT -->" if action != 0 else "LEFT <--"

            next_obs, reward, terminated, truncated, _ = env.step(action)
            ep_return += reward

            # Show the opening steps, plus the failing step of short episodes.
            show_step = t < 5 or (terminated and t < 20)
            if show_step:
                print_state(obs, t)
                print(f"           Action: {label}")
            elif t == 5:
                print("           ...")

            obs = next_obs

            if terminated or truncated:
                break

        episode_returns.append(ep_return)
        outcome = "SURVIVED!" if ep_return >= 500 else "FELL!"
        print(f"\n  Result: {outcome} Score: {ep_return:.0f}")

    env.close()

    mean_return = np.mean(episode_returns)
    print(f"\n  RANDOM AGENT AVERAGE: {mean_return:.1f}")
    print("  (This is our baseline to beat!)")

    return mean_return


# =============================================================================
# SECTION 5: AGENT 2 - HAND-CODED RULES
# =============================================================================

def run_handcoded_agent():
    """
    HAND-CODED AGENT
    ================
    Policy: push RIGHT when ``pole_angle + 0.1 * pole_velocity > 0``,
    otherwise push LEFT. Adding a fraction of the angular velocity makes the
    rule react to where the pole is heading, not only where it is now.

    Runs HANDCODED_EPISODES episodes of CartPole-v1 (at most 500 steps each)
    and prints a few sample steps per episode. The lesson expects this rule
    to score close to the 500-step maximum.

    Returns:
        The mean total reward over the episodes.
    """
    print_header("AGENT 2: HAND-CODED RULES")

    # Single source of truth for the rule: the same weight is shown in the
    # explanation below and used in the decision, so the two cannot drift apart.
    velocity_weight = 0.1

    print(f"""
    Strategy: Use physics intuition!

    Think about it:
      - If the pole is falling RIGHT... push RIGHT to catch it!
      - If the pole is falling LEFT... push LEFT to catch it!

    The code is just:

        if pole_angle + {velocity_weight} * pole_velocity > 0:
            action = RIGHT
        else:
            action = LEFT

    That's it! Let's see if it works...
    """)
    time.sleep(1)

    env = gym.make("CartPole-v1")
    scores = []

    for episode in range(HANDCODED_EPISODES):
        print_subheader(f"Episode {episode + 1}/{HANDCODED_EPISODES}")

        state, _ = env.reset()
        total_reward = 0

        for step in range(500):
            cart_pos, cart_vel, pole_angle, pole_vel = state

            # HAND-CODED RULE: If pole falling right, push right
            # We add velocity to predict where it's GOING to be
            if pole_angle + velocity_weight * pole_vel > 0:
                action = 1  # RIGHT
            else:
                action = 0  # LEFT

            action_name = "LEFT <--" if action == 0 else "RIGHT -->"

            next_state, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward

            # Show the first steps, the midpoint and the final possible step
            if step < 3 or step == 250 or step == 499:
                print_state(state, step)
                print(f"           Action: {action_name}")
            elif step == 3:
                print("           ... (running) ...")

            state = next_state

            if terminated or truncated:
                break

        scores.append(total_reward)
        status = "FELL!" if total_reward < 500 else "PERFECT!"
        print(f"\n  Result: {status} Score: {total_reward:.0f}")

    env.close()

    avg_score = np.mean(scores)
    print(f"\n  HAND-CODED AGENT AVERAGE: {avg_score:.1f}")

    if avg_score >= 450:
        print("""
  WOW! The hand-coded agent is nearly perfect!

  So why do we need machine learning?

  1. We had to THINK to create these rules
  2. CartPole is simple - real problems are harder
  3. Can a machine discover these rules on its own?

  Let's find out with Q-Learning...
        """)

    return avg_score


# =============================================================================
# SECTION 6: AGENT 3 - Q-LEARNING
# =============================================================================

def discretize_state(state, bins):
    """
    Map a continuous 4-value observation to a tuple of four bin indices.

    A Q-table needs a finite set of states, so each observation value is
    replaced by the index of the bucket it falls into.

    Args:
        state: Four values unpacked as (cart position, cart velocity,
            pole angle, pole angular velocity).
        bins: Mapping with the keys "cart_pos", "cart_vel", "pole_angle"
            and "pole_vel"; each value is the sequence of bin edges handed
            to ``np.digitize`` for the matching observation value.

    Returns:
        A 4-tuple of the indices returned by ``np.digitize``, in the same
        order as the observation values.
    """
    cart_pos, cart_vel, pole_angle, pole_vel = state

    # Pair each reading with the key of its edge array, keeping state order.
    readings = (
        ("cart_pos", cart_pos),
        ("cart_vel", cart_vel),
        ("pole_angle", pole_angle),
        ("pole_vel", pole_vel),
    )
    return tuple(np.digitize(value, bins[key]) for key, value in readings)


def run_qlearning_agent():
    """
    Q-LEARNING AGENT
    ================
    Train a tabular Q-learning agent on CartPole-v1 and report its progress.

    Observations are discretized with discretize_state(); the Q-table maps
    each discrete state to two action values (initially zero). Actions are
    chosen epsilon-greedily, and epsilon is multiplied by EPSILON_DECAY after
    every episode, never dropping below EPSILON_END. Each step applies

        Q(s,a) += LEARNING_RATE * (target - Q(s,a))

    where target is the reward alone when the episode terminated, and
    reward + DISCOUNT_FACTOR * max(Q(s')) otherwise.

    Runs QLEARNING_EPISODES episodes of at most 500 steps and prints a
    progress table followed by summary statistics.

    Returns:
        The mean score of the last 10 episodes, or 0.0 when no episode was
        run (QLEARNING_EPISODES <= 0).
    """
    print_header("AGENT 3: Q-LEARNING (Watch it Learn!)")

    print("""
    Strategy: Learn from experience!

    Q-Learning builds a "cheat sheet" called the Q-table:

        State           | Action LEFT | Action RIGHT
        ----------------+-------------+-------------
        Pole tilted R   |    -5.2     |    +12.8    <-- RIGHT is better!
        Pole tilted L   |   +11.4     |    -3.1     <-- LEFT is better!
        ...             |    ...      |    ...

    The algorithm:
        1. Try an action (explore randomly at first)
        2. See what reward you get
        3. Update the Q-table: "this state-action was good/bad"
        4. Repeat!

    Key hyperparameters:
        - Learning Rate ({lr}): How fast to update beliefs
        - Discount Factor ({df}): How much to value future rewards
        - Epsilon ({eps}): How often to explore vs exploit

    Watch the scores improve over time!
    """.format(lr=LEARNING_RATE, df=DISCOUNT_FACTOR, eps=EPSILON_START))
    time.sleep(2)

    # Bin edges used to discretize each observation value.
    # More edges = finer resolution, but more states to visit.
    bins = {
        "cart_pos": np.linspace(-2.4, 2.4, 12),
        "cart_vel": np.linspace(-3, 3, 12),
        "pole_angle": np.linspace(-0.21, 0.21, 24),  # Fine resolution for angle!
        "pole_vel": np.linspace(-3, 3, 12),
    }

    # Q-table: discrete state -> array of 2 action values, zero until updated.
    q_table = defaultdict(lambda: np.zeros(2))

    # Track statistics
    scores = []
    epsilon = EPSILON_START
    # Highest score among the episodes BEFORE the current one; kept as a
    # running value so the "first 100+" check needs no scan of the history.
    best_previous = 0

    env = gym.make("CartPole-v1")

    print_subheader("Training Progress")
    print(f"  {'Episode':>8} | {'Score':>6} | {'Avg(10)':>8} | {'Epsilon':>8} | Status")
    print("  " + "-" * 55)

    for episode in range(QLEARNING_EPISODES):
        state, _ = env.reset()
        discrete_state = discretize_state(state, bins)
        total_reward = 0

        for step in range(500):
            # EXPLORATION vs EXPLOITATION
            # With probability epsilon take a random action, otherwise the
            # action with the highest current Q-value for this state.
            if np.random.random() < epsilon:
                action = np.random.randint(0, 2)  # Explore
            else:
                action = np.argmax(q_table[discrete_state])  # Exploit

            # Take action, observe result
            next_state, reward, terminated, truncated, _ = env.step(action)
            next_discrete_state = discretize_state(next_state, bins)
            total_reward += reward

            # Q-LEARNING UPDATE
            # Q(s,a) = Q(s,a) + lr * (reward + gamma * max(Q(s')) - Q(s,a))
            old_value = q_table[discrete_state][action]

            if terminated:
                # Terminal state: no future rewards
                td_target = reward
            else:
                # Non-terminal (including truncation at the step limit):
                # reward + discounted best value of the next state
                td_target = reward + DISCOUNT_FACTOR * np.max(q_table[next_discrete_state])

            # Move the Q-value a fraction LEARNING_RATE toward the target
            new_value = old_value + LEARNING_RATE * (td_target - old_value)
            q_table[discrete_state][action] = new_value

            discrete_state = next_discrete_state

            if terminated or truncated:
                break

        # Decay exploration rate
        epsilon = max(EPSILON_END, epsilon * EPSILON_DECAY)
        scores.append(total_reward)

        # Rolling average; the slice already copes with fewer than 10 scores.
        avg_10 = np.mean(scores[-10:])

        if total_reward >= 450:
            status = "EXCELLENT!"
        elif total_reward >= 200:
            status = "Great!"
        elif total_reward >= 100:
            status = "Learning!"
        elif total_reward >= 50:
            status = "Improving"
        else:
            status = "Exploring"

        # True only for the first episode that reaches 100 or more.
        first_100_plus = total_reward >= 100 and best_previous < 100
        best_previous = max(best_previous, total_reward)

        # Print at key moments
        show_progress = (
            (episode + 1) % 25 == 0 or  # Every 25 episodes
            episode < 5 or               # First few
            total_reward >= 200 or       # Good scores
            first_100_plus               # First 100+
        )

        if show_progress:
            print(f"  {episode + 1:>8} | {total_reward:>6.0f} | {avg_10:>8.1f} | {epsilon:>8.3f} | {status}")

    env.close()

    # Final statistics
    print_subheader("Q-Learning Results")

    if not scores:
        # Without this guard max() raises on the empty list and the means are NaN.
        print("    No training episodes were run; set QLEARNING_EPISODES to a positive number.")
        return 0.0

    avg_first_10 = np.mean(scores[:10])
    avg_last_10 = np.mean(scores[-10:])
    best_score = max(scores)

    print(f"""
    First 10 episodes average: {avg_first_10:.1f}
    Last 10 episodes average:  {avg_last_10:.1f}
    Best score:                {best_score:.0f}

    States discovered: {len(q_table)} unique states
    """)

    improvement = avg_last_10 - avg_first_10
    if improvement > 50:
        print(f"    The agent IMPROVED by {improvement:.0f} points!")
        print("    It learned something! But can it do better?")
    else:
        print("    Hmm, not much improvement. Try adjusting hyperparameters!")
        print("    Hint: Try more episodes or different learning rate.")

    return avg_last_10


# =============================================================================
# SECTION 7: COMPARISON & NEXT STEPS
# =============================================================================

def show_comparison(random_score, handcoded_score, qlearning_score):
    """Print a bar chart of the three agents' scores and the key takeaways.

    Args:
        random_score: Average score of the random agent.
        handcoded_score: Average score of the hand-coded agent.
        qlearning_score: Average score reported for the Q-learning agent.

    Each bar is 40 characters wide and scaled so that a score of 500 fills it.
    """
    print_header("FINAL COMPARISON")

    print("""
    AGENT PERFORMANCE SUMMARY
    =========================""")

    bar_width = 40
    max_score = 500

    def render_bar(score):
        # Number of '#' cells is the score's share of max_score, truncated.
        cells = int((score / max_score) * bar_width)
        hashes = "#" * cells
        padding = " " * (bar_width - cells)
        return f"[{hashes}{padding}]"

    random_bar = render_bar(random_score)
    handcoded_bar = render_bar(handcoded_score)
    qlearning_bar = render_bar(qlearning_score)
    target_bar = render_bar(500)

    print(f"""
    Random:     {random_bar} {random_score:>6.1f}
    Hand-coded: {handcoded_bar} {handcoded_score:>6.1f}
    Q-Learning: {qlearning_bar} {qlearning_score:>6.1f}

    Target:     {target_bar} {500:>6.1f}
    """)

    insights = """
    KEY INSIGHTS
    ============

    1. RANDOM is terrible (as expected)
       - This is our baseline - anything should beat this

    2. HAND-CODED is perfect (but we had to think!)
       - Shows the problem IS solvable
       - But requires human insight

    3. Q-LEARNING actually learns!
       - Started from nothing
       - Discovered patterns through trial and error
       - Might not be perfect, but it LEARNED
    """
    print(insights)


def show_next_steps():
    """Print suggested experiments and a recap of the lesson."""
    print_header("NEXT STEPS")

    # Static text only; nothing here reads the current configuration values.
    suggestions = """
    EXPERIMENTS TO TRY
    ==================

    1. TUNE Q-LEARNING HYPERPARAMETERS (edit the CONFIG section):

       LEARNING_RATE    = 0.2   # Try: 0.05, 0.1, 0.2, 0.5
       DISCOUNT_FACTOR  = 0.99  # Try: 0.9, 0.95, 0.99, 0.999
       EPSILON_DECAY    = 0.99  # Try: 0.98, 0.99, 0.995 (slower = more exploration)
       QLEARNING_EPISODES = 500 # Try: 200, 500, 1000, 2000

    2. QUESTIONS TO EXPLORE:

       - What happens with LEARNING_RATE = 1.0? (too fast?)
       - What happens with DISCOUNT_FACTOR = 0.5? (short-sighted?)
       - What if we NEVER explore (EPSILON_START = 0)?
       - Can Q-Learning reach 500? How many episodes?

    3. ADVANCED: Try PPO (baseline_starter.py)

       PPO uses neural networks instead of Q-tables.
       It should learn faster and achieve higher scores!


    WHAT YOU LEARNED TODAY
    ======================

    - The CartPole environment (state, actions, rewards)
    - Why baselines matter (random agent)
    - That simple rules can work (hand-coded)
    - How Q-Learning discovers patterns (real ML!)
    - The explore vs exploit tradeoff (epsilon)
    - Why hyperparameters matter

    Great job! You now understand the fundamentals of RL!
    """
    print(suggestions)


# =============================================================================
# SECTION 8: MAIN
# =============================================================================

def main():
    """Run the whole lesson: banner, primer, the three agents, then the summary.

    Waits for ENTER between agents so the output can be read at the
    student's own pace.
    """
    rule = "=" * 60

    def pause(prompt):
        # Show the prompt, then block until the user presses ENTER.
        print(prompt)
        input()

    print("\n" + rule)
    print("  W19D1: VANILLA CARTPOLE - Learning RL from Scratch")
    print("  Student: " + STUDENT_NAME)
    print(rule)

    # Primer first, so the agents' output makes sense.
    explain_cartpole()

    baseline = run_random_agent()
    pause("\n  Press ENTER to continue to Hand-coded Agent...")

    rule_based = run_handcoded_agent()
    pause("\n  Press ENTER to continue to Q-Learning...")

    learned = run_qlearning_agent()

    # Side-by-side results, then ideas for further experiments.
    show_comparison(baseline, rule_based, learned)
    show_next_steps()

    print("\n" + rule)
    print("  Done! Now try changing the hyperparameters and re-run!")
    print(rule + "\n")


# Start the lesson only when this file is executed as a script.
if __name__ == "__main__":
    main()
