#!/usr/bin/env python3
"""
W17D2: SST-2 Fine-Tuning - Student Template
============================================

INSTRUCTIONS:
1. Use strategy-selector.html   → Decide: Full fine-tune or LoRA?
2. Use training-args.html       → Generate your TrainingArguments (paste in Section 2)
3. Use lora-calculator.html     → If using LoRA, generate config (paste in Section 3)
4. Run this script              → Train and get your accuracy score
5. Use evidence-builder.html    → Document your results

Run: python student_training.py
"""

# =============================================================================
# SECTION 0: SETUP (don't modify)
# =============================================================================

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import time
import torch
import numpy as np
from datetime import datetime
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)
from sklearn.metrics import accuracy_score, f1_score

# Pick the accelerator name to report: Apple-silicon MPS is checked first,
# then CUDA, falling back to CPU. In this file DEVICE is only used for the
# message below; it is not passed to the Trainer.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# =============================================================================
# SECTION 1: YOUR INFO
# =============================================================================

STUDENT_NAME: str = "Your Name"  # <-- Change this! Shown in the results banner and leaderboard row.

# =============================================================================
# SECTION 2: YOUR MODEL & TRAINING CONFIG
# Paste your config from training-args.html here
# =============================================================================

# Choose your model (options: distilbert-base-uncased, bert-base-uncased, roberta-base).
# If USE_LORA is True, also set LORA_CONFIG["target_modules"] in Section 3 to match this model.
MODEL_NAME: str = "distilbert-base-uncased"  # Hub id passed to from_pretrained() for tokenizer and model

# ┌─────────────────────────────────────────────────────────────────────────────┐
# │  PASTE YOUR TrainingArguments FROM training-args.html BELOW                 │
# └─────────────────────────────────────────────────────────────────────────────┘

# Hyperparameters for the Trainer. Section 4 also reads learning_rate,
# per_device_train_batch_size and num_train_epochs back for its banner.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="epoch",               # must match eval_strategy when load_best_model_at_end=True
    learning_rate=2e-5,                  # ← Tune this! Try: 1e-5, 2e-5, 3e-5, 5e-5
    per_device_train_batch_size=16,      # ← Tune this! Try: 8, 16, 32
    per_device_eval_batch_size=32,
    num_train_epochs=3,                  # ← Tune this! Try: 2, 3, 4, 5
    weight_decay=0.01,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",    # a key returned by compute_metrics()
    # Each epoch checkpoint also stores optimizer state, so without a cap
    # ./results grows every epoch. With load_best_model_at_end the best
    # checkpoint is always retained in addition to the most recent one.
    save_total_limit=2,
    logging_steps=100,
    fp16=False,                          # Keep False for MPS
    report_to="none",
)

# =============================================================================
# SECTION 3: LoRA CONFIG (Optional - only if using LoRA)
# Set USE_LORA = True and paste config from lora-calculator.html
# =============================================================================

USE_LORA: bool = False  # ← Set to True to train LoRA adapters (configured by LORA_CONFIG below) instead of all weights

# ┌─────────────────────────────────────────────────────────────────────────────┐
# │  PASTE YOUR LoraConfig FROM lora-calculator.html BELOW (if using LoRA)      │
# └─────────────────────────────────────────────────────────────────────────────┘

# Keyword arguments for peft.LoraConfig, read by main() when USE_LORA is True.
LORA_CONFIG = {
    "r": 16,
    "lora_alpha": 32,
    # target_modules must name attention projection layers that exist in
    # MODEL_NAME; keep exactly one of these lines uncommented.
    "target_modules": ["q_lin", "v_lin"],    # For DistilBERT
    # "target_modules": ["query", "value"],  # For BERT
    # "target_modules": ["query", "value"],  # For RoBERTa (same layer names as BERT)
    "lora_dropout": 0.05,
    "bias": "none",
}

# =============================================================================
# SECTION 4: TRAINING CODE (don't modify unless you know what you're doing)
# =============================================================================

def compute_metrics(eval_pred):
    """Compute accuracy and binary F1 for the Trainer's evaluation loop.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. The Trainer passes an
            ``EvalPrediction``, which unpacks like a tuple. ``predictions``
            holds per-class scores (logits). If it arrives as a tuple of
            arrays, its first element is taken as the logits.

    Returns:
        dict with ``"accuracy"`` and ``"f1"`` (label 1 is the positive
        class). The Trainer reports these with an ``eval_`` prefix.
    """
    predictions, labels = eval_pred
    # Some models emit extra outputs next to the logits; the Trainer then
    # hands over a tuple whose first entry is the logits array.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # Class scores sit on the last axis; argmax yields the predicted label id.
    predicted_labels = np.argmax(predictions, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predicted_labels),
        "f1": f1_score(labels, predicted_labels, average="binary"),
    }

def _print_run_header():
    """Print the run banner: student, model, strategy and key hyperparameters."""
    print("=" * 60)
    print(f"SST-2 Training - {STUDENT_NAME}")
    print("=" * 60)
    print(f"Model: {MODEL_NAME}")
    print(f"LoRA: {'Yes' if USE_LORA else 'No'}")
    print(f"Learning Rate: {training_args.learning_rate}")
    print(f"Batch Size: {training_args.per_device_train_batch_size}")
    print(f"Epochs: {training_args.num_train_epochs}")
    print("=" * 60)


def _load_model_and_tokenizer():
    """Load MODEL_NAME's tokenizer and a 2-label (negative/positive) classifier."""
    print(f"\n[2/5] Loading {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        id2label={0: "negative", 1: "positive"},
        label2id={"negative": 0, "positive": 1},
    )
    return tokenizer, model


def _prepare_model_for_training(model):
    """Wrap *model* with LoRA adapters if USE_LORA is set; otherwise report its parameter counts.

    Returns the model to train: the PEFT-wrapped model in the LoRA case,
    else *model* unchanged.
    """
    if not USE_LORA:
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"      Parameters: {trainable:,} trainable / {total:,} total")
        return model

    print("\n[2b] Applying LoRA...")
    # Imported lazily so a full fine-tune does not need `peft` installed.
    from peft import LoraConfig, get_peft_model, TaskType

    lora_config = LoraConfig(
        r=LORA_CONFIG["r"],
        lora_alpha=LORA_CONFIG["lora_alpha"],
        target_modules=LORA_CONFIG["target_modules"],
        lora_dropout=LORA_CONFIG["lora_dropout"],
        bias=LORA_CONFIG["bias"],
        # Optional key: extra (non-adapter) modules to train fully and save.
        # None (the LoraConfig default) when LORA_CONFIG omits it.
        # NOTE(review): DistilBERT's `pre_classifier` head layer may stay
        # frozen unless listed here; verify against PEFT's SEQ_CLS defaults.
        modules_to_save=LORA_CONFIG.get("modules_to_save"),
        task_type=TaskType.SEQ_CLS,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def _tokenize_splits(dataset, tokenizer, max_length=128):
    """Tokenize the train and validation splits, dropping the raw text and index columns.

    Sentences are truncated to *max_length* tokens but not padded here;
    padding is left to the DataCollatorWithPadding used by the trainer.
    """
    print("\n[3/5] Tokenizing...")

    def tokenize(examples):
        return tokenizer(examples["sentence"], truncation=True, max_length=max_length)

    tokenized_train = dataset["train"].map(tokenize, batched=True, remove_columns=["sentence", "idx"])
    tokenized_val = dataset["validation"].map(tokenize, batched=True, remove_columns=["sentence", "idx"])
    return tokenized_train, tokenized_val


def _build_trainer(model, tokenizer, train_dataset, eval_dataset):
    """Create a Trainer wired to training_args, compute_metrics and a padding collator."""
    import inspect

    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_dataset,
        "eval_dataset": eval_dataset,
        "data_collator": DataCollatorWithPadding(tokenizer=tokenizer),
        "compute_metrics": compute_metrics,
    }
    # Newer transformers releases (>= 4.46) renamed Trainer's `tokenizer`
    # argument to `processing_class` and deprecate the old name; use whichever
    # keyword the installed version accepts.
    trainer_params = inspect.signature(Trainer.__init__).parameters
    tokenizer_kwarg = "processing_class" if "processing_class" in trainer_params else "tokenizer"
    trainer_kwargs[tokenizer_kwarg] = tokenizer
    return Trainer(**trainer_kwargs)


def _print_results(results, training_time):
    """Print the final metrics summary and the leaderboard row to copy.

    Args:
        results: Metrics dict from ``trainer.evaluate()`` (``eval_``-prefixed keys).
        training_time: Training duration in seconds.
    """
    # Whole seconds split into minutes/seconds (same as floor-dividing the float).
    minutes, seconds = divmod(int(training_time), 60)
    elapsed = f"{minutes}m {seconds}s"
    accuracy_pct = results["eval_accuracy"] * 100

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    print(f"Student:       {STUDENT_NAME}")
    print(f"Model:         {MODEL_NAME}")
    print(f"Strategy:      {'LoRA' if USE_LORA else 'Full Fine-Tune'}")
    print(f"Training Time: {elapsed}")
    print("-" * 60)
    print(f"ACCURACY:      {accuracy_pct:.2f}%")
    print(f"F1 SCORE:      {results['eval_f1']:.4f}")
    print("=" * 60)

    # Leaderboard submission (markdown table row)
    strategy = f"LoRA r={LORA_CONFIG['r']}" if USE_LORA else "Full"
    submission = f"| {STUDENT_NAME} | {accuracy_pct:.2f}% | {MODEL_NAME} | {strategy} | {elapsed} |"

    print("\nLEADERBOARD SUBMISSION (copy this):")
    print("-" * 60)
    print(submission)
    print("-" * 60)


def main():
    """Fine-tune MODEL_NAME on GLUE SST-2 and report validation metrics.

    Steps: load the dataset, load tokenizer and model (optionally wrapped
    with LoRA), tokenize, train with the module-level ``training_args``,
    evaluate, print a leaderboard row, and save the model and tokenizer
    to ./my_model.

    Returns:
        The metrics dict from ``trainer.evaluate()`` (keys prefixed ``eval_``).
    """
    _print_run_header()

    print("\n[1/5] Loading SST-2 dataset...")
    dataset = load_dataset("glue", "sst2")
    print(f"      Train: {len(dataset['train']):,} | Val: {len(dataset['validation']):,}")

    tokenizer, model = _load_model_and_tokenizer()
    model = _prepare_model_for_training(model)
    tokenized_train, tokenized_val = _tokenize_splits(dataset, tokenizer)

    print("\n[4/5] Setting up trainer...")
    trainer = _build_trainer(model, tokenizer, tokenized_train, tokenized_val)

    print("\n[5/5] Training...")
    print("-" * 60)
    # perf_counter is monotonic, so system clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    trainer.train()
    training_time = time.perf_counter() - start_time

    print("\n" + "-" * 60)
    print("Evaluating...")
    # When training_args sets load_best_model_at_end=True (the template
    # default), this scores the best checkpoint rather than the last epoch.
    results = trainer.evaluate()

    _print_results(results, training_time)

    print("\nSaving model to ./my_model...")
    trainer.save_model("./my_model")
    tokenizer.save_pretrained("./my_model")
    print("Done!")

    return results

# =============================================================================
# SECTION 5: RUN
# =============================================================================

if __name__ == "__main__":
    # Kept at module level so the metrics stay inspectable (e.g. `python -i`).
    results = main()

    # Follow-up checklist printed after the leaderboard row.
    rule = "=" * 60
    next_steps = (
        "Copy your leaderboard submission above",
        "Open evidence-builder.html",
        "Fill in your training details",
        "Generate and save your METRICS.md",
    )
    print("\n" + rule)
    print("NEXT STEPS")
    print(rule)
    for number, step in enumerate(next_steps, start=1):
        print(f"{number}. {step}")
    print(rule)
