#!/usr/bin/env python3
"""
================================================================================
W17D2: SST-2 Sentiment Classification - INSTRUCTOR DEMO
================================================================================

This is a TEACHING script - it's intentionally verbose with lots of print
statements and explanations. Run this while explaining each section to students.

WHAT WE'RE DOING:
    Fine-tuning a pretrained transformer (DistilBERT) on sentiment classification.
    The model already "knows" English from pretraining - we're teaching it
    specifically to detect positive vs negative sentiment.

WHAT STUDENTS WILL LEARN:
    1. How the Hugging Face ecosystem works (datasets, tokenizers, models, Trainer)
    2. What happens during fine-tuning (updating weights on task-specific data)
    3. How to configure TrainingArguments (the settings they'll experiment with)
    4. How to evaluate and interpret results

HARDWARE: Optimized for Apple Silicon Macs (e.g. M4 Pro) with MPS acceleration (also works on CUDA/CPU)
DATASET:  SST-2 (Stanford Sentiment Treebank) - movie review sentiment
MODEL:    DistilBERT (66M params) - fast baseline, students can try larger models

EXPECTED RESULTS:
    - DistilBERT: ~89-91% accuracy in ~5 minutes
    - BERT-base:  ~92-93% accuracy in ~12 minutes
    - RoBERTa:    ~94-95% accuracy in ~12 minutes

Run: python demo_training.py
================================================================================
"""

import os
import sys
import time
from datetime import datetime

# =============================================================================
# STEP 0: DEPENDENCY CHECK
# =============================================================================
# Make sure all required packages are installed

import importlib.util

# Map each importable module name to the pip command that installs it.
# NOTE: the import name and the pip name differ for scikit-learn ("sklearn").
# `accelerate` is required on EVERY platform: transformers' Trainer /
# TrainingArguments raise an ImportError with PyTorch when it is missing.
REQUIRED_PACKAGES = {
    "torch": "pip install torch",
    "datasets": "pip install datasets",
    "transformers": "pip install transformers",
    "accelerate": "pip install accelerate",
    "sklearn": "pip install scikit-learn",
    "numpy": "pip install numpy",
}

# find_spec() only checks that a package is installed; it does not import it.
# That keeps the check fast and avoids importing the heavy libraries before
# the environment variables below are set.
missing = [
    f"  {package}: {install_cmd}"
    for package, install_cmd in REQUIRED_PACKAGES.items()
    if importlib.util.find_spec(package) is None
]

if missing:
    print("=" * 70)
    print("MISSING DEPENDENCIES")
    print("=" * 70)
    print("\nPlease install the following packages:\n")
    print("\n".join(missing))
    print("\nOr install all at once:")
    print("  pip install torch datasets transformers accelerate scikit-learn numpy")
    print("=" * 70)
    sys.exit(1)

# =============================================================================
# STEP 0b: ENVIRONMENT SETUP
# =============================================================================
# These environment variables prevent common warnings that confuse students

# Silence two noisy-but-harmless messages that confuse students:
# the tokenizers fork warning and transformers' "advisory" warnings.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "TRANSFORMERS_NO_ADVISORY_WARNINGS": "true",
    }
)

# Banner: title framed by two rules, the start timestamp, then a blank line.
print(
    "=" * 70,
    "W17D2: FINE-TUNING TRANSFORMERS - INSTRUCTOR DEMO",
    "=" * 70,
    f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    "",
    sep="\n",
)

# =============================================================================
# STEP 1: IMPORT LIBRARIES
# =============================================================================
print("[STEP 1] IMPORTING LIBRARIES")
print("-" * 70)
# The teaching text lists all four libraries but makes clear that the demo
# itself only relies on the first three (metrics come from sklearn).
print("""
TEACHING POINT: The Hugging Face Ecosystem
------------------------------------------
Hugging Face provides four key libraries (this demo uses the first three):

1. 'datasets'     - Load and process datasets (like SST-2)
2. 'transformers' - Models, tokenizers, and the Trainer API
3. 'tokenizers'   - Fast tokenization (used internally)
4. 'evaluate'     - Metrics computation (we'll use sklearn instead for clarity)

The beauty is these all work together seamlessly!
""")

# Imports are done one at a time, with a print before each, so students can
# see which (slow) library is loading.
print("Importing torch...")
import torch
print(f"   PyTorch version: {torch.__version__}")

print("Importing datasets...")
from datasets import load_dataset

print("Importing transformers...")
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

print("Importing numpy and sklearn...")
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

print("\n   All imports successful!")
print()

# =============================================================================
# STEP 2: DEVICE DETECTION
# =============================================================================
print("[STEP 2] DETECTING HARDWARE")
print("-" * 70)
print("""
TEACHING POINT: GPU Acceleration
--------------------------------
Deep learning is MUCH faster on GPUs. We check for:

1. MPS  - Apple Silicon (M1/M2/M3/M4 Macs) - what we're using today
2. CUDA - NVIDIA GPUs (data centers, gaming PCs)
3. CPU  - Fallback (works but slow)

The Trainer will automatically use the best available device.
""")

# Probe accelerators in priority order: Apple MPS, then NVIDIA CUDA, then CPU.
# DEVICE is reused later when the inference pipeline is built.
if torch.backends.mps.is_available():
    DEVICE = "mps"
    print("   DETECTED: MPS (Apple Silicon)")
    print("   This Mac has unified memory - GPU and CPU share RAM!")
elif torch.cuda.is_available():
    DEVICE = "cuda"
    # Query the first GPU once; .name is what get_device_name(0) reports.
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"   DETECTED: CUDA ({_gpu_props.name})")
    print(f"   GPU Memory: {_gpu_props.total_memory / 1e9:.1f} GB")
else:
    DEVICE = "cpu"
    print("   DETECTED: CPU only (training will be slower)")

# Trailing "\n" in the f-string stands in for a separate blank print().
print(f"\n   Using device: {DEVICE}\n")

# =============================================================================
# STEP 3: CONFIGURATION
# =============================================================================
print("[STEP 3] CONFIGURATION")
print("-" * 70)
print("""
TEACHING POINT: Hyperparameters
-------------------------------
These are the "knobs" students will tune in the competition!

The most impactful ones:
- learning_rate:  How big are our update steps? (try 1e-5 to 5e-5)
- batch_size:     How many examples per update? (try 8, 16, 32)
- num_epochs:     How many passes through data? (try 2-5)
- warmup_ratio:   What fraction to warm up LR? (try 0.0 to 0.2)

This config comes from training-args.html - students will generate their own!
""")

# Model choice - students can experiment with these
MODEL_NAME = "distilbert-base-uncased"

print(f"MODEL: {MODEL_NAME}")
# The suggested names must be valid Hugging Face Hub ids so students can paste
# them straight into MODEL_NAME (DeBERTa-v3 lives under the "microsoft/" org).
print("""
   Why DistilBERT?
   - It's a "distilled" version of BERT (40% smaller, 60% faster)
   - Still achieves 97% of BERT's performance
   - Perfect for demos and experimentation

   Students can try: bert-base-uncased, roberta-base, microsoft/deberta-v3-small
""")

# Training configuration - THIS IS WHAT training-args.html GENERATES
# Every key is passed straight to TrainingArguments(**TRAINING_CONFIG) in STEP 8.
TRAINING_CONFIG = {
    "output_dir": "./sst2_results",
    "eval_strategy": "epoch",           # Evaluate after each epoch
    "save_strategy": "epoch",           # Save checkpoint each epoch
    "learning_rate": 2e-5,              # 0.00002 - small for fine-tuning!
    "per_device_train_batch_size": 16,  # Examples per forward pass
    "per_device_eval_batch_size": 32,   # Can be larger (no gradients)
    "num_train_epochs": 3,              # Full passes through training data
    "weight_decay": 0.01,               # Decoupled weight decay (AdamW), not plain L2
    "warmup_ratio": 0.1,                # 10% of training for LR warmup
    "load_best_model_at_end": True,     # Keep the best checkpoint
    "metric_for_best_model": "accuracy",
    "logging_steps": 100,               # Log every 100 steps
    "fp16": False,                      # MPS doesn't support fp16
    "report_to": "none",                # No wandb/tensorboard for demo
    "seed": 42,                         # Reproducibility!
}

print("TRAINING ARGUMENTS:")
print("-" * 40)
for key, value in TRAINING_CONFIG.items():
    print(f"   {key}: {value}")
print()

# =============================================================================
# STEP 4: LOAD DATASET
# =============================================================================
print("[STEP 4] LOADING DATASET")
print("-" * 70)
print("""
TEACHING POINT: SST-2 Dataset
-----------------------------
SST-2 (Stanford Sentiment Treebank) is a classic NLP benchmark:

- Task: Binary sentiment classification (positive/negative)
- Source: Movie reviews from Rotten Tomatoes
- Part of GLUE benchmark (General Language Understanding Evaluation)

Why it's good for learning:
- Simple task everyone understands
- Fast to train (short sentences)
- Clear evaluation metric (accuracy)
""")

print("Loading SST-2 from Hugging Face Hub...")
dataset = load_dataset("glue", "sst2")

print(f"""
DATASET STRUCTURE:
   Training examples:   {len(dataset['train']):,}
   Validation examples: {len(dataset['validation']):,}

   Labels: 0 = negative, 1 = positive

   Note: SST-2 test set labels are hidden (it's a benchmark!)
         We use validation set for our evaluation.
""")

print("SAMPLE DATA:")
print("-" * 40)
# Slicing a Dataset yields a dict of column -> list, so the first five
# sentences and labels can be walked in lockstep.
preview = dataset['train'][:5]
for text, label in zip(preview['sentence'], preview['label']):
    sentiment = "NEGATIVE" if label != 1 else "POSITIVE"
    # Keep each sample on one display line (67 chars + "..." = 70)
    shown = text if len(text) <= 70 else text[:67] + "..."
    print(f"   [{sentiment:8}] {shown}")
print()

# =============================================================================
# STEP 5: LOAD TOKENIZER
# =============================================================================
print("[STEP 5] LOADING TOKENIZER")
print("-" * 70)
# The example IDs must be real: 1045 = "i", 2293 = "love", 2023 = "this",
# 3185 = "movie" in the bert-base-uncased vocabulary ("loved" is a different
# id). The subword example is printed live below instead of hard-coding a
# split that may not match the actual tokenizer.
print("""
TEACHING POINT: What is Tokenization?
-------------------------------------
Models don't understand text - they need numbers!

Tokenization converts: "I love this movie" -> [101, 1045, 2293, 2023, 3185, 102]
(101 = [CLS], 102 = [SEP]; IDs from the bert-base-uncased vocabulary,
 which distilbert-base-uncased shares)

Modern tokenizers use SUBWORD tokenization:
- Rare words are split into smaller pieces the tokenizer already knows
  (the demo below prints the real split for "unhappiness")
- In BERT-style WordPiece, "##" marks a piece that continues a word
- This handles rare words without needing a gigantic vocabulary

Each model has its OWN tokenizer - always load the matching one!
""")

print(f"Loading tokenizer for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print(f"""
TOKENIZER INFO:
   Vocabulary size: {tokenizer.vocab_size:,} tokens
   Model max length: {tokenizer.model_max_length:,} tokens
   Padding token: '{tokenizer.pad_token}' (id: {tokenizer.pad_token_id})
   Special tokens: [CLS]={tokenizer.cls_token_id}, [SEP]={tokenizer.sep_token_id}
""")

# Demonstrate tokenization
demo_sentence = "I absolutely loved this movie!"
tokens = tokenizer(demo_sentence)
decoded_tokens = tokenizer.convert_ids_to_tokens(tokens['input_ids'])

# A rarer word, to show real subword splitting rather than a made-up one
subword_demo_word = "unhappiness"
subword_pieces = tokenizer.tokenize(subword_demo_word)

print("TOKENIZATION DEMO:")
print("-" * 40)
print(f"   Input:     \"{demo_sentence}\"")
print(f"   Token IDs: {tokens['input_ids']}")
print(f"   Tokens:    {decoded_tokens}")
print(f"   Attention: {tokens['attention_mask']}")
print(f"   Subwords:  \"{subword_demo_word}\" -> {subword_pieces}")
print("""
   Note: [CLS] is added at start, [SEP] at end
         Attention mask: 1 = real token, 0 = padding
""")

# =============================================================================
# STEP 6: LOAD MODEL
# =============================================================================
print("[STEP 6] LOADING MODEL")
print("-" * 70)
print("""
TEACHING POINT: Transfer Learning
---------------------------------
We're NOT training from scratch! The model already learned:
- English grammar and vocabulary
- Word meanings and relationships
- General language understanding

From where? Pretraining on massive text (Wikipedia, books, web)

We're FINE-TUNING: Adding a classification head and updating weights
specifically for sentiment analysis. Much faster than training from scratch!
""")

# Build both label maps from one ordered list so they can never disagree:
# id2label = {0: "negative", 1: "positive"}, label2id is its inverse.
label_names = ["negative", "positive"]
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

print(f"Loading {MODEL_NAME} for sequence classification...")
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_names),
    id2label=id2label,
    label2id=label2id,
)

# Count parameters in a single pass: every tensor adds to the total, and
# those with requires_grad=True also count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f"""
MODEL ARCHITECTURE:
   Base model: {MODEL_NAME}
   Task: Sequence Classification (2 labels)

PARAMETERS:
   Total parameters:     {total_params:,}
   Trainable parameters: {trainable_params:,}

   That's {total_params/1e6:.1f} million parameters we're fine-tuning!

   Compare to GPT-3: 175 BILLION parameters
   Compare to GPT-4: ~1.8 TRILLION parameters (estimated)
""")

# Show model structure (hand-written summary of the DistilBERT layout)
print("MODEL STRUCTURE (simplified):", "-" * 40, sep="\n")
print("""   DistilBertForSequenceClassification(
     (distilbert): DistilBertModel(
       (embeddings): Word + Position embeddings
       (transformer): 6 transformer layers
     )
     (pre_classifier): Linear(768 -> 768)  <- NEW for fine-tuning
     (classifier): Linear(768 -> 2)        <- NEW for fine-tuning
     (dropout): Dropout(p=0.2)
   )
""")
print()

# =============================================================================
# STEP 7: TOKENIZE DATASET
# =============================================================================
# Section header: title line followed by a 70-char rule.
print("[STEP 7] TOKENIZING DATASET", "-" * 70, sep="\n")
print("""
TEACHING POINT: Batch Processing
--------------------------------
We need to tokenize ALL examples before training.
The .map() function applies our tokenization to every example efficiently.

We also remove columns the model doesn't need (sentence text, index)
and keep only: input_ids, attention_mask, label
""")

def tokenize_function(examples):
    """Turn a batch of raw SST-2 rows into model inputs.

    Used with ``Dataset.map(..., batched=True)``, so ``examples["sentence"]``
    holds a list of sentences. Returns the tokenizer's encoding for the whole
    batch (input_ids, attention_mask, ...), unpadded.
    """
    sentences = examples["sentence"]
    # Cap length at 128 tokens (SST-2 sentences are short). Padding is left
    # to DataCollatorWithPadding, which pads each batch to its own longest row.
    return tokenizer(sentences, truncation=True, max_length=128)

# Tokenize both splits with identical settings; only the raw text and index
# columns are dropped, so input_ids / attention_mask / label remain.
_tokenized_splits = {}
for split_name, set_label in (("train", "training"), ("validation", "validation")):
    print(f"Tokenizing {set_label} set...")
    _tokenized_splits[split_name] = dataset[split_name].map(
        tokenize_function,
        batched=True,
        remove_columns=["sentence", "idx"],
        desc=f"Tokenizing {split_name}",
    )
    if split_name == "train":
        print()
tokenized_train = _tokenized_splits["train"]
tokenized_val = _tokenized_splits["validation"]

print(f"""
TOKENIZED DATASET:
   Train columns: {tokenized_train.column_names}
   Train size:    {len(tokenized_train):,} examples
   Val columns:   {tokenized_val.column_names}
   Val size:      {len(tokenized_val):,} examples
""")

# Show a tokenized example
example = tokenized_train[0]
label_name = 'positive' if example['label'] == 1 else 'negative'
print("TOKENIZED EXAMPLE:", "-" * 40, sep="\n")
print(f"   input_ids length: {len(example['input_ids'])}")
print(f"   attention_mask:   {example['attention_mask'][:10]}... (1s for real tokens)")
print(f"   label:            {example['label']} ({label_name})")
print()

# =============================================================================
# STEP 8: SETUP TRAINING
# =============================================================================
# Section header: title line followed by a 70-char rule.
print("[STEP 8] SETTING UP TRAINER", "-" * 70, sep="\n")
print("""
TEACHING POINT: The Trainer API
-------------------------------
Hugging Face Trainer handles ALL the training boilerplate:
- Training loop
- Gradient computation and updates
- Evaluation
- Checkpointing
- Logging
- Mixed precision
- Multi-GPU (if available)

We just provide: model, args, datasets, metrics
""")

# The collator pads batches on the fly, which is why tokenize_function
# leaves sequences unpadded.
print("Creating data collator...")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
collator_explainer = "\n".join([
    "   DataCollatorWithPadding:",
    "   - Pads each batch to the longest sequence IN THAT BATCH",
    "   - More efficient than padding everything to max_length",
    "",
])
print(collator_explainer)

# Metrics function
def compute_metrics(eval_pred):
    """Compute accuracy and binary F1 for the Trainer's evaluation loop.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by the Trainer.
            ``predictions`` holds the model's raw logits with the class
            scores on the last axis. Some architectures return a tuple of
            outputs instead, in which case the logits are its first element.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` (F1 of the positive class,
        label 1). The Trainer prefixes the keys, e.g. ``eval_accuracy``.
    """
    predictions, labels = eval_pred
    # Models that emit extra outputs (hidden states, etc.) hand over a tuple
    # whose first entry is the logits; unwrap it so argmax sees an array.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # The highest-scoring class on the last axis is the predicted label id.
    predictions = np.argmax(predictions, axis=-1)

    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='binary')

    return {"accuracy": acc, "f1": f1}

print("Metrics function defined (accuracy + F1)")

# Create training arguments
print("\nCreating TrainingArguments...")
training_args = TrainingArguments(**TRAINING_CONFIG)

# Create trainer
print("Initializing Trainer...")
# NOTE(review): newer transformers releases (4.46+) rename Trainer's
# `tokenizer=` argument to `processing_class=` and warn on the old name;
# `tokenizer=` is kept so the script still runs on older versions.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Calculate training steps the way the Trainer does, so the plan printed below
# matches the progress bar:
# - the last partial batch is kept (dataloader_drop_last defaults to False),
#   so batches per epoch round UP, not down;
# - train_batch_size = per-device batch size x number of GPUs in use;
# - one optimizer step happens every gradient_accumulation_steps batches.
batch_size = training_args.train_batch_size
batches_per_epoch = -(-len(tokenized_train) // batch_size)  # ceiling division
steps_per_epoch = max(batches_per_epoch // training_args.gradient_accumulation_steps, 1)
total_steps = steps_per_epoch * TRAINING_CONFIG["num_train_epochs"]
# Let TrainingArguments apply its own rounding rule for warmup_ratio
warmup_steps = training_args.get_warmup_steps(total_steps)

print(f"""
TRAINING PLAN:
   Training examples:    {len(tokenized_train):,}
   Batch size:           {batch_size}
   Steps per epoch:      {steps_per_epoch:,}
   Number of epochs:     {TRAINING_CONFIG['num_train_epochs']}
   Total training steps: {total_steps:,}

   Learning rate:        {TRAINING_CONFIG['learning_rate']}
   Warmup steps:         {warmup_steps:,} ({TRAINING_CONFIG['warmup_ratio']*100:.0f}% of total)
""")
print()

# =============================================================================
# STEP 9: TRAIN!
# =============================================================================
print("[STEP 9] TRAINING")
print("=" * 70)
print("""
TEACHING POINT: What Happens During Training?
--------------------------------------------
Each step:
1. Load a batch of examples
2. Forward pass: compute predictions
3. Compute loss (cross-entropy for classification)
4. Backward pass: compute gradients
5. Update weights using optimizer (AdamW)
6. Repeat!

Watch the loss decrease and accuracy increase over epochs.
""")
print("=" * 70)
print()

# Pause so the instructor can finish explaining before the long-running step.
# Only prompt when a human is at the keyboard: input() on a piped or closed
# stdin raises EOFError, which would crash non-interactive runs.
if sys.stdin.isatty():
    try:
        input("Press ENTER to start training (or Ctrl+C to cancel)...")
    except (KeyboardInterrupt, EOFError):
        # Ctrl+C is the advertised way to cancel: exit without a traceback.
        print("\nTraining cancelled.")
        sys.exit(1)
print()

start_time = time.time()
print(f"Training started at {datetime.now().strftime('%H:%M:%S')}")
print("-" * 70)

# Actually train
train_result = trainer.train()

# Wall-clock duration, split into whole minutes and leftover seconds
training_time = time.time() - start_time
minutes, seconds = divmod(int(training_time), 60)

print("-" * 70)
print(f"Training completed in {minutes}m {seconds}s")
print()

# =============================================================================
# STEP 10: EVALUATE
# =============================================================================
print("[STEP 10] FINAL EVALUATION")
print("-" * 70)
# The validation set was used to pick the best checkpoint
# (load_best_model_at_end + metric_for_best_model), so it is not truly
# "unseen"; the teaching text says so explicitly.
print("""
TEACHING POINT: Evaluation
--------------------------
We evaluate on the VALIDATION set (data the model never trained on).
This tells us how well the model generalizes to new examples.

Caveat: load_best_model_at_end picked the best epoch using this same
validation set, so the score is slightly optimistic. A truly untouched
test set would give a fairer estimate.

Metrics:
- Accuracy: % of correct predictions
- F1 Score: Harmonic mean of precision and recall
""")

print("Running final evaluation...")
# trainer.predict() runs the evaluation loop and also returns the raw logits
# and labels. metric_key_prefix="eval" keeps the metric names the same as
# trainer.evaluate() would produce: eval_accuracy, eval_f1, eval_loss, ...
prediction_output = trainer.predict(tokenized_val, metric_key_prefix="eval")
eval_results = prediction_output.metrics

# Logits -> predicted label ids (unwrap tuple outputs some models return)
val_logits = prediction_output.predictions
if isinstance(val_logits, tuple):
    val_logits = val_logits[0]
val_preds = np.argmax(val_logits, axis=-1)

# sklearn confusion_matrix: rows = true label, columns = predicted label;
# with labels=[0, 1] it ravels to (tn, fp, fn, tp).
tn, fp, fn, tp = (
    int(count)
    for count in confusion_matrix(
        prediction_output.label_ids, val_preds, labels=[0, 1]
    ).ravel()
)

print(f"""
RESULTS:
   Accuracy: {eval_results['eval_accuracy']:.4f} ({eval_results['eval_accuracy']*100:.2f}%)
   F1 Score: {eval_results['eval_f1']:.4f}
   Eval Loss: {eval_results['eval_loss']:.4f}

CONFUSION MATRIX (rows = actual, columns = predicted):
                    pred NEGATIVE   pred POSITIVE
   actual NEGATIVE  {tn:13,}   {fp:13,}
   actual POSITIVE  {fn:13,}   {tp:13,}
""")

# =============================================================================
# STEP 11: SAVE MODEL
# =============================================================================
print("[STEP 11] SAVING MODEL")
print("-" * 70)

save_path = "./sst2_final_model"
print(f"Saving model to {save_path}...")
trainer.save_model(save_path)
# Trainer.save_model() may already write the tokenizer files; saving them
# explicitly keeps the folder self-contained for reloading either way.
tokenizer.save_pretrained(save_path)

# List what was actually written rather than a hard-coded guess: the exact
# files vary by model family (e.g. vocab.txt vs vocab.json/merges.txt) and by
# library version. Known files get a short explanation.
FILE_NOTES = {
    "config.json": "(model configuration)",
    "model.safetensors": "(model weights)",
    "pytorch_model.bin": "(model weights)",
    "tokenizer.json": "(tokenizer)",
    "vocab.txt": "(vocabulary)",
    "training_args.bin": "(the TrainingArguments used)",
}
saved_files = sorted(os.listdir(save_path))
tree_lines = []
for position, file_name in enumerate(saved_files):
    branch = "└──" if position == len(saved_files) - 1 else "├──"
    note = FILE_NOTES.get(file_name, "")
    tree_lines.append(f"   {branch} {file_name:<24} {note}".rstrip())
file_tree = "\n".join(tree_lines)

print(f"""
SAVED FILES:
   {save_path}/
{file_tree}

You can reload this model with:
   model = AutoModelForSequenceClassification.from_pretrained("{save_path}")
""")

# =============================================================================
# STEP 12: INFERENCE DEMO
# =============================================================================
print("[STEP 12] INFERENCE DEMO")
print("-" * 70)
print("""
TEACHING POINT: Using Your Model
--------------------------------
The 'pipeline' API makes inference easy:
- Load model
- Pass text
- Get predictions

This is what you'd use in production!
""")

from transformers import pipeline

print("Loading model as pipeline...")
# Run inference on the device detected in STEP 2 ("mps", "cuda" or "cpu").
# pipeline() accepts a device string directly. The previous `-1` meant
# "CPU", which silently disabled the Apple GPU.
classifier = pipeline(
    "sentiment-analysis",
    model=save_path,
    device=DEVICE,
)

test_sentences = [
    "This movie was absolutely fantastic! Best film I've seen all year.",
    "What a complete waste of time. Terrible acting and boring plot.",
    "It was okay, nothing special but watchable.",
    "The visuals were stunning but the story made no sense.",
    "I cried, I laughed, I fell in love with every character.",
    "Predictable and cliched, but my kids enjoyed it.",
]

print("\nTEST PREDICTIONS:")
print("-" * 70)
for sentence in test_sentences:
    result = classifier(sentence)[0]
    confidence = result['score']
    sentiment = result['label'].upper()

    # Visual indicator: a 20-wide bar of '+' (positive) or '-' (negative)
    # whose filled length is proportional to the model's confidence.
    bar_char = "+" if sentiment == "POSITIVE" else "-"
    indicator = f"[{bar_char * int(confidence * 20):20}]"

    print(f"{indicator} {sentiment:8} ({confidence:5.1%})")
    print(f"   \"{sentence[:60]}{'...' if len(sentence) > 60 else ''}\"")
    print()

# =============================================================================
# FINAL SUMMARY
# =============================================================================
# Headline numbers are computed once and reused in the summary below.
rule = "=" * 70
accuracy_pct = eval_results['eval_accuracy'] * 100

print(rule)
print("TRAINING COMPLETE - SUMMARY")
print(rule)
print(f"""
MODEL:          {MODEL_NAME}
TRAINING TIME:  {minutes}m {seconds}s
FINAL ACCURACY: {accuracy_pct:.2f}%
FINAL F1:       {eval_results['eval_f1']:.4f}

BASELINE TO BEAT: {accuracy_pct:.2f}%

NEXT STEPS FOR STUDENTS:
------------------------
1. Open strategy-selector.html   -> Decide: Full fine-tune or LoRA?
2. Open training-args.html       -> Generate optimized TrainingArguments
3. Open lora-calculator.html     -> (If using LoRA) Configure PEFT
4. Paste config into student_starter.ipynb
5. Train and beat this baseline!
6. Use evidence-builder.html to document your results

TIPS:
- Try learning_rate: 1e-5, 2e-5, 3e-5, 5e-5
- Try batch_size: 8, 16, 32
- Try num_epochs: 2, 3, 4, 5
- Try roberta-base for potentially higher accuracy
""")
print(rule)
