Next version. Models + scripts updated. 2
518
notebooks/03_joint_route_generator.ipynb
Normal file
@@ -0,0 +1,518 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "27197e7d",
   "metadata": {},
   "source": [
    "# 03: Joint nanoGPT-style Route Generation\n",
    "\n",
    "## From Understanding to Generation\n",
    "\n",
    "Notebook 02 used a **transformer encoder** (BERT-style) to *understand* routes and predict their grade. This notebook uses a **transformer decoder** (GPT-style) to *generate* new routes.\n",
    "\n",
    "### The key difference: Encoder vs Decoder\n",
    "\n",
    "| Aspect | BERT-style (Encoder) | GPT-style (Decoder) |\n",
    "|---|---|---|\n",
    "| Attention | Bidirectional (sees all tokens) | Causal (sees only the current and earlier tokens) |\n",
    "| Training | Masked language modeling | Next-token prediction |\n",
    "| Use case | Classification, regression | Text generation |\n",
    "| Output | Single prediction per sequence | One prediction per position |\n",
    "\n",
    "### How GPT-style generation works\n",
    "\n",
    "The model is trained to predict the **next token** given all previous tokens:\n",
    "\n",
    "```text\n",
    "Input: <BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>\n",
    "Target: <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start>\n",
    "```\n",
    "\n",
    "At generation time, we:\n",
    "1. Start with a prompt like `<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>`\n",
    "2. Ask the model to predict the next token\n",
    "3. Sample from the predicted probability distribution\n",
    "4. Append the sampled token to the sequence\n",
    "5. Repeat until we generate `<EOS>` or hit a max length\n",
    "\n",
    "### Conditioning on board, angle, and grade\n",
    "\n",
    "The prompt tokens tell the model *what kind of route to generate*:\n",
    "- `<BOARD_TB2>`: Generate a route for the Tension Board 2\n",
    "- `<ANGLE_40>`: At 40 degrees\n",
    "- `<GRADE_V6>`: At V6 difficulty\n",
    "\n",
    "This is analogous to how ChatGPT uses a system prompt to condition its responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6590822",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import sys\n",
    "import json\n",
    "import math\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "ROOT = Path.cwd().resolve()\n",
    "if ROOT.name == \"notebooks\":\n",
    "    ROOT = ROOT.parent\n",
    "sys.path.insert(0, str(ROOT / \"src\"))\n",
    "\n",
    "from climbingboardgpt.config import load_board_configs\n",
    "from climbingboardgpt.datasets import RouteGPTDataset\n",
    "from climbingboardgpt.generation import generate_one\n",
    "from climbingboardgpt.models import JointRouteGPT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f09fdf54",
   "metadata": {},
   "outputs": [],
   "source": [
    "TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
    "df_routes = pd.read_csv(TOKENIZED / \"route_sequences.csv\")\n",
    "vocab = json.loads((TOKENIZED / \"token_vocab.json\").read_text(encoding=\"utf-8\"))\n",
    "stoi = {str(k): int(v) for k, v in vocab[\"stoi\"].items()}\n",
    "itos = {int(k): str(v) for k, v in vocab[\"itos\"].items()}\n",
    "\n",
    "pad_id = stoi[\"<PAD>\"]\n",
    "unk_id = stoi[\"<UNK>\"]\n",
    "\n",
    "print(f\"Vocabulary size: {len(stoi):,}\")\n",
    "print(f\"Total routes: {len(df_routes):,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe4b0faf",
   "metadata": {},
   "source": [
    "## Sequence encoding for causal language modeling\n",
    "\n",
    "### The autoregressive setup\n",
    "\n",
    "For GPT-style training, each route becomes a sequence where the model learns to predict each token given all previous tokens:\n",
    "\n",
    "```text\n",
    "Input: <BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start> <TB2_p369_middle>\n",
    "Target: <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start> <TB2_p369_middle> <TB2_p603_finish>\n",
    "```\n",
    "\n",
    "The target is the input shifted left by one position, so each target position holds the token that follows the corresponding input position. This is the standard causal language modeling setup.\n",
    "\n",
    "### Why include the grade in the training sequence?\n",
    "\n",
    "For the grade predictor (notebook 02), we excluded the grade because the model needed to predict it. But for the generator, we **include** the grade (`<GRADE_V6>`) in the training data so the model learns the relationship between grade and hold selection.\n",
    "\n",
    "At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade."
   ]
  },
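  {
   "cell_type": "markdown",
   "id": "next_token_objective",
   "metadata": {},
   "source": [
    "The setup above can be written compactly. This is a sketch in notation introduced here for illustration ($x_t$, $T$, and $\\theta$ do not appear in the code), and it assumes the loss is the usual mean cross-entropy over target positions. For a route tokenized as $x_1, \\dots, x_T$, the input is $x_1, \\dots, x_{T-1}$, the target is $x_2, \\dots, x_T$, and training minimizes\n",
    "\n",
    "$$\\mathcal{L}(\\theta) = -\\frac{1}{T-1} \\sum_{t=1}^{T-1} \\log p_\\theta(x_{t+1} \\mid x_1, \\dots, x_t)$$\n",
    "\n",
    "The causal mask is what lets all $T-1$ terms be computed in one forward pass: the prediction at position $t$ can depend only on $x_1, \\dots, x_t$. Whether padded positions are excluded from the average depends on how `JointRouteGPT` computes its loss, which is not shown in this notebook."
   ]
  },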
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ad61dbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(tokens):\n",
    "    \"\"\"Convert token strings to integer IDs.\"\"\"\n",
    "    return [stoi.get(token, unk_id) for token in tokens]\n",
    "\n",
    "# Use the \"with grade\" version for GPT training\n",
    "# The model needs to see the grade to learn grade-hold relationships\n",
    "df_routes[\"gpt_tokens\"] = df_routes[\"sequence_with_grade\"].fillna(\"\").str.split()\n",
    "df_routes[\"gpt_ids\"] = df_routes[\"gpt_tokens\"].apply(encode)\n",
    "df_routes[\"seq_len\"] = df_routes[\"gpt_ids\"].apply(len)\n",
    "max_len = int(df_routes[\"seq_len\"].max())\n",
    "block_size = max_len - 1  # Input length (one less than full sequence)\n",
    "\n",
    "# Create train/val splits\n",
    "train_df = df_routes[df_routes[\"split\"] == \"train\"].reset_index(drop=True)\n",
    "val_df = df_routes[df_routes[\"split\"] == \"val\"].reset_index(drop=True)\n",
    "\n",
    "# Create datasets and data loaders\n",
    "# RouteGPTDataset handles the input/target shift for causal modeling\n",
    "train_ds = RouteGPTDataset(train_df, max_len=max_len, pad_id=pad_id)\n",
    "val_ds = RouteGPTDataset(val_df, max_len=max_len, pad_id=pad_id)\n",
    "train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)\n",
    "val_loader = DataLoader(val_ds, batch_size=128, shuffle=False)\n",
    "\n",
    "print(f\"Max sequence length: {max_len}\")\n",
    "print(f\"Block size (input length): {block_size}\")\n",
    "print(f\"Training samples: {len(train_ds):,}\")\n",
    "print(f\"Validation samples: {len(val_ds):,}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66d98641",
   "metadata": {},
   "source": [
    "## The GPT Model Architecture\n",
    "\n",
    "### JointRouteGPT\n",
    "\n",
    "This is a **causal transformer decoder**, the same architecture used in GPT-2, GPT-3, etc., but much smaller:\n",
    "\n",
    "1. **Token embeddings**: Convert integer token IDs to dense vectors\n",
    "2. **Positional embeddings**: Learned position vectors (not sinusoidal)\n",
    "3. **Causal self-attention**: Each position can attend only to itself and earlier positions (via a causal mask)\n",
    "4. **Transformer layers**: Multiple layers of attention + feedforward\n",
    "5. **Language modeling head**: Projects hidden states to vocabulary logits\n",
    "\n",
    "### Key hyperparameters\n",
    "\n",
    "- `n_embd=128`: Embedding dimension (GPT-2 small uses 768)\n",
    "- `n_head=4`: Number of attention heads\n",
    "- `n_layer=4`: Number of transformer layers (GPT-2 small uses 12)\n",
    "- `dropout=0.10`: Dropout probability\n",
    "\n",
    "This is intentionally small: we're training on ~40K short sequences, not billions of long documents.\n",
    "\n",
    "### Weight tying\n",
    "\n",
    "The output projection layer shares weights with the token embedding layer (`self.lm_head.weight = self.token_emb.weight`). This is a common technique that:\n",
    "- Reduces parameter count\n",
    "- Acts as a regularizer\n",
    "- Is used in GPT-2 and many other language models"
   ]
  },
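  {
   "cell_type": "markdown",
   "id": "weight_tying_count",
   "metadata": {},
   "source": [
    "A quick count shows what weight tying saves. This is a sketch: $V$ (vocabulary size) and $d$ (embedding dimension) are notation introduced here, biases are ignored, and $V = 4000$ is a round illustrative value rather than the measured vocabulary size.\n",
    "\n",
    "$$\\underbrace{V d}_{\\text{token embedding}} + \\underbrace{d V}_{\\text{LM head}} = 2 V d \\quad \\longrightarrow \\quad V d \\ \\text{when tied}$$\n",
    "\n",
    "$$V d = 4000 \\times 128 = 512{,}000 \\ \\text{parameters saved}$$\n",
    "\n",
    "For a model this small, that saving can be a noticeable share of the total printed by the parameter count in the next cell."
   ]
  },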
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3eec6f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "model = JointRouteGPT(\n",
    "    vocab_size=len(stoi),\n",
    "    block_size=block_size,\n",
    "    n_embd=128,\n",
    "    n_head=4,\n",
    "    n_layer=4,\n",
    "    dropout=0.10,\n",
    "    pad_id=pad_id,\n",
    ").to(device)\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)\n",
    "\n",
    "print(f\"Device: {device}\")\n",
    "print(f\"Total parameters: {sum(p.numel() for p in model.parameters()):,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f999cf05",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch():\n",
    "    \"\"\"Train for one epoch.\"\"\"\n",
    "    model.train()\n",
    "    losses = []\n",
    "    n = 0\n",
    "    for batch in train_loader:\n",
    "        x = batch[\"input_ids\"].to(device)\n",
    "        y = batch[\"target_ids\"].to(device)\n",
    "\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "        _, loss = model(x, y)\n",
    "        loss.backward()\n",
    "\n",
    "        # Gradient clipping prevents exploding gradients\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
    "\n",
    "        optimizer.step()\n",
    "        losses.append(loss.item() * x.size(0))\n",
    "        n += x.size(0)\n",
    "    return sum(losses) / max(1, n)\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_loss(loader):\n",
    "    \"\"\"Evaluate loss on a data loader.\"\"\"\n",
    "    model.eval()\n",
    "    losses = []\n",
    "    n = 0\n",
    "    for batch in loader:\n",
    "        x = batch[\"input_ids\"].to(device)\n",
    "        y = batch[\"target_ids\"].to(device)\n",
    "        _, loss = model(x, y)\n",
    "        losses.append(loss.item() * x.size(0))\n",
    "        n += x.size(0)\n",
    "    return sum(losses) / max(1, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51fb8b6e",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "### What we're optimizing\n",
    "\n",
    "The model minimizes **cross-entropy loss**, the standard loss function for language modeling. At each position, the model outputs a probability distribution over the entire vocabulary, and the loss is the negative log-probability it assigned to the actual next token, so it measures how surprised the model is.\n",
    "\n",
    "### Perplexity\n",
    "\n",
    "We also track **perplexity**, which is `exp(loss)`. Perplexity answers the question: \"On average, how many equally likely tokens was the model choosing between at each step?\" Lower perplexity = better model.\n",
    "\n",
    "For reference:\n",
    "- A model that always predicts the right token with full confidence has perplexity = 1\n",
    "- A model that picks uniformly from a 1000-token vocab has perplexity = 1000\n",
    "- Strong language models on English text report perplexities of roughly 15-20 on some benchmarks, but those numbers depend on the dataset and tokenizer and are not directly comparable to ours\n",
    "\n",
    "Our vocabulary has roughly 4,000 or more tokens, so uniform guessing would give a perplexity of about that size. A perplexity far below that indicates the model is learning meaningful patterns."
   ]
  },
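  {
   "cell_type": "markdown",
   "id": "perplexity_worked",
   "metadata": {},
   "source": [
    "As a worked equation, assuming the reported loss is the mean natural-log cross-entropy per predicted token (the symbols $N$, $V$, $y_i$, and $p_\\theta$ are notation introduced here, not names from the code):\n",
    "\n",
    "$$\\mathcal{L} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log p_\\theta(y_i \\mid \\text{context}_i), \\qquad \\mathrm{PPL} = \\exp(\\mathcal{L})$$\n",
    "\n",
    "A model that assigns probability $1/V$ to every token has $\\mathcal{L} = \\log V$ and therefore $\\mathrm{PPL} = V$. With a hypothetical round vocabulary size of $V = 4000$:\n",
    "\n",
    "$$\\mathcal{L}_{\\text{uniform}} = \\ln 4000 \\approx 8.29, \\qquad \\mathrm{PPL}_{\\text{uniform}} = 4000$$\n",
    "\n",
    "For comparison, a loss of $3.0$ would correspond to $\\mathrm{PPL} = e^{3} \\approx 20.1$. These numbers are illustrative arithmetic, not results from this notebook."
   ]
  },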
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b38b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "history = []\n",
    "best_val_loss = float(\"inf\")\n",
    "best_state = None\n",
    "patience = 10\n",
    "stagnant = 0\n",
    "\n",
    "print(\"Starting GPT training...\\n\")\n",
    "\n",
    "for epoch in range(1, 21):\n",
    "    train_loss = train_epoch()\n",
    "    val_loss = eval_loss(val_loader)\n",
    "\n",
    "    # Track perplexity (exponentiated loss)\n",
    "    train_ppl = math.exp(min(train_loss, 20))\n",
    "    val_ppl = math.exp(min(val_loss, 20))\n",
    "\n",
    "    history.append({\n",
    "        \"epoch\": epoch,\n",
    "        \"train_loss\": train_loss,\n",
    "        \"val_loss\": val_loss,\n",
    "        \"train_perplexity\": train_ppl,\n",
    "        \"val_perplexity\": val_ppl,\n",
    "    })\n",
    "\n",
    "    # Early stopping\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
    "        stagnant = 0\n",
    "    else:\n",
    "        stagnant += 1\n",
    "\n",
    "    if epoch == 1 or epoch % 5 == 0:\n",
    "        print(f\"Epoch {epoch:3d} | \"\n",
    "              f\"Train Loss: {train_loss:.4f} | \"\n",
    "              f\"Val Loss: {val_loss:.4f} | \"\n",
    "              f\"Val PPL: {val_ppl:.1f}\")\n",
    "\n",
    "    if stagnant >= patience:\n",
    "        print(f\"\\nEarly stopping at epoch {epoch}\")\n",
    "        break\n",
    "\n",
    "# Load best model\n",
    "if best_state is not None:\n",
    "    model.load_state_dict(best_state)\n",
    "\n",
    "print(f\"\\nBest validation loss: {best_val_loss:.4f}\")\n",
    "print(f\"Best validation perplexity: {math.exp(min(best_val_loss, 20)):.1f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69926180",
   "metadata": {},
   "source": [
    "## Generating Routes\n",
    "\n",
    "### The generation process\n",
    "\n",
    "To generate a route, we:\n",
    "\n",
    "1. **Create a prompt**: `<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>`\n",
    "2. **Feed it to the model**: Get a probability distribution over the vocabulary for the next token\n",
    "3. **Sample a token**: Use temperature and top-k filtering to control randomness\n",
    "4. **Append and repeat**: Add the sampled token to the sequence and repeat until `<EOS>` or max length\n",
    "\n",
    "### Temperature and top-k sampling\n",
    "\n",
    "- **Temperature** (0.9 here): Controls randomness by rescaling the logits before sampling. Lower = more deterministic, higher = more random\n",
    "- **Top-k** (50 here): Only consider the k most likely tokens. This prevents the model from sampling very unlikely tokens.\n",
    "\n",
    "These are the same techniques used in language models like GPT-3 to control output diversity."
   ]
  },
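  {
   "cell_type": "markdown",
   "id": "sampling_math",
   "metadata": {},
   "source": [
    "In symbols, the standard formulation of these two controls looks as follows. This is a sketch that assumes `generate_one` follows the usual recipe (its implementation is not shown in this notebook); $z_i$ (logits), $\\tau$ (temperature), and $K$ (the indices of the $k$ largest logits) are notation introduced here.\n",
    "\n",
    "$$p_i = \\frac{\\exp(z_i / \\tau)}{\\sum_{j \\in K} \\exp(z_j / \\tau)} \\ \\text{ for } i \\in K, \\qquad p_i = 0 \\ \\text{ otherwise}$$\n",
    "\n",
    "A small worked example with three illustrative logits $z = (2, 1, 0)$ and all three kept:\n",
    "\n",
    "$$\\tau = 1: \\ p \\approx (0.665,\\ 0.245,\\ 0.090) \\qquad \\tau = 0.5: \\ p \\approx (0.867,\\ 0.117,\\ 0.016)$$\n",
    "\n",
    "Lowering $\\tau$ concentrates probability on the largest logit, and as $\\tau \\to 0$ sampling approaches a greedy argmax. Dividing by a positive $\\tau$ does not change the ranking of the logits, so the top-k set is the same whether the filter is applied before or after the temperature scaling."
   ]
  },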
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "029eb911",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate sample routes for both boards\n",
    "configs = load_board_configs([\"tb2\", \"kilter\"])\n",
    "configs_by_key = {config.board_key: config for config in configs}\n",
    "\n",
    "samples = []\n",
    "for board_key, config in configs_by_key.items():\n",
    "    for grouped_v in [3, 5, 7]:  # V3, V5, V7\n",
    "        sample = generate_one(\n",
    "            model=model,\n",
    "            stoi=stoi,\n",
    "            itos=itos,\n",
    "            device=device,\n",
    "            board_prefix=config.token_prefix,\n",
    "            angle=40,\n",
    "            grouped_v=grouped_v,\n",
    "            role_name_to_id=config.role_definitions,\n",
    "            temperature=0.9,\n",
    "            top_k=50,\n",
    "            max_new_tokens=40,\n",
    "        )\n",
    "        samples.append({\"board_key\": board_key, **sample})\n",
    "\n",
    "samples_df = pd.DataFrame(samples)\n",
    "print(\"Generated route samples:\")\n",
    "print(samples_df[[\"board_key\", \"requested_grouped_v\", \"basic_valid\", \"sequence\", \"frames\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "generate_more",
   "metadata": {},
   "source": [
    "## Generate More Routes for Evaluation\n",
    "\n",
    "Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "generate_bulk",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate routes across multiple angles and grades for evaluation\n",
    "all_samples = []\n",
    "\n",
    "for board_key, config in configs_by_key.items():\n",
    "    # Get common angles and grades for this board\n",
    "    board_df = df_routes[df_routes[\"board_key\"] == board_key]\n",
    "    common_angles = sorted(board_df[\"angle\"].astype(int).value_counts().head(5).index.tolist())\n",
    "    common_grades = sorted(board_df[\"grouped_v\"].astype(int).value_counts().head(8).index.tolist())\n",
    "\n",
    "    print(f\"\\nGenerating for {config.display_name}:\")\n",
    "    print(f\" Angles: {common_angles}\")\n",
    "    print(f\" Grades: V{min(common_grades)}-V{max(common_grades)}\")\n",
    "\n",
    "    for angle in common_angles:\n",
    "        for grade in common_grades:\n",
    "            for _ in range(5):  # 5 samples per condition\n",
    "                sample = generate_one(\n",
    "                    model=model,\n",
    "                    stoi=stoi,\n",
    "                    itos=itos,\n",
    "                    device=device,\n",
    "                    board_prefix=config.token_prefix,\n",
    "                    angle=int(angle),\n",
    "                    grouped_v=int(grade),\n",
    "                    role_name_to_id=config.role_definitions,\n",
    "                    temperature=0.9,\n",
    "                    top_k=50,\n",
    "                    max_new_tokens=40,\n",
    "                )\n",
    "                all_samples.append({\"board_key\": board_key, **sample})\n",
    "\n",
    "all_samples_df = pd.DataFrame(all_samples)\n",
    "print(f\"\\nTotal generated routes: {len(all_samples_df):,}\")\n",
    "print(\"\\nBasic validity by board:\")\n",
    "print(all_samples_df.groupby(\"board_key\")[\"basic_valid\"].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "save_artifacts",
   "metadata": {},
   "source": [
    "## Save Model and Generated Routes\n",
    "\n",
    "We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save_outputs",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Save model checkpoint\n",
    "MODEL_DIR = ROOT / \"models\"\n",
    "MODEL_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "checkpoint = {\n",
    "    \"model_state_dict\": model.state_dict(),\n",
    "    \"config\": {\n",
    "        \"vocab_size\": len(stoi),\n",
    "        \"block_size\": block_size,\n",
    "        \"n_embd\": 128,\n",
    "        \"n_head\": 4,\n",
    "        \"n_layer\": 4,\n",
    "        \"dropout\": 0.10,\n",
    "        \"pad_id\": pad_id,\n",
    "    },\n",
    "    \"stoi\": stoi,\n",
    "    \"itos\": {str(k): v for k, v in itos.items()},\n",
    "    \"best_val_loss\": best_val_loss,\n",
    "}\n",
    "model_path = MODEL_DIR / \"joint_route_gpt_generator.pth\"\n",
    "torch.save(checkpoint, model_path)\n",
    "print(f\"Saved model checkpoint to: {model_path}\")\n",
    "\n",
    "# Save training history\n",
    "GEN_DIR = ROOT / \"data\" / \"processed\" / \"generation\"\n",
    "GEN_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "pd.DataFrame(history).to_csv(GEN_DIR / \"training_history.csv\", index=False)\n",
    "print(f\"Saved training history to: {GEN_DIR / 'training_history.csv'}\")\n",
    "\n",
    "# Save generated routes (this is what notebook 04 needs)\n",
    "all_samples_df.to_csv(GEN_DIR / \"generated_routes.csv\", index=False)\n",
    "print(f\"Saved {len(all_samples_df)} generated routes to: {GEN_DIR / 'generated_routes.csv'}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}