{
"cells": [
{
"cell_type": "markdown",
"id": "27197e7d",
"metadata": {},
"source": [
"# 03 — Joint nanoGPT-style Route Generation\n",
"\n",
"## From Understanding to Generation\n",
"\n",
"Notebook 02 used a **transformer encoder** (BERT-style) to *understand* routes and predict their grade. This notebook uses a **transformer decoder** (GPT-style) to *generate* new routes.\n",
"\n",
"### The key difference: Encoder vs Decoder\n",
"\n",
"| Aspect | BERT-style (Encoder) | GPT-style (Decoder) |\n",
"|---|---|---|\n",
"| Attention | Bidirectional (sees all tokens) | Causal (only sees past tokens) |\n",
"| Training | Masked language modeling | Next-token prediction |\n",
"| Use case | Classification, regression | Text generation |\n",
"| Output | Single prediction per sequence | One prediction per position |\n",
"\n",
"### How GPT-style generation works\n",
"\n",
"The model is trained to predict the **next token** given all previous tokens:\n",
"\n",
"```text\n",
"Input:  <BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>\n",
"Target: <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start>\n",
"```\n",
"\n",
"At generation time, we:\n",
"1. Start with a prompt like `<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>`\n",
"2. Ask the model to predict the next token\n",
"3. Sample from the predicted probability distribution\n",
"4. Append the sampled token to the sequence\n",
"5. Repeat until we generate `<EOS>` or hit a max length\n",
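"\n",
"The five steps above can be sketched in plain Python. This is only a minimal illustration under stated assumptions, not the sampler used later in this notebook: `toy_next_token_probs` is a hypothetical stand-in for the trained model, and its tiny vocabulary and probabilities are made up.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"# Hypothetical stand-in for the trained model: returns a probability for\n",
"# each candidate next token given the tokens generated so far.\n",
"def toy_next_token_probs(tokens):\n",
"    if len(tokens) < 6:\n",
"        return {'<TB2_p344_start>': 0.5, '<TB2_p369_middle>': 0.4, '<EOS>': 0.1}\n",
"    return {'<EOS>': 1.0}\n",
"\n",
"def generate(prompt, max_new_tokens=10, seed=0):\n",
"    rng = random.Random(seed)\n",
"    tokens = list(prompt)                      # step 1: start from the prompt\n",
"    for _ in range(max_new_tokens):            # step 5: repeat up to a max length\n",
"        probs = toy_next_token_probs(tokens)   # step 2: predict the next token\n",
"        choices, weights = zip(*probs.items())\n",
"        next_token = rng.choices(choices, weights=weights, k=1)[0]  # step 3: sample\n",
"        tokens.append(next_token)              # step 4: append the sampled token\n",
"        if next_token == '<EOS>':              # step 5: stop once <EOS> is sampled\n",
"            break\n",
"    return tokens\n",
"\n",
"prompt = ['<BOS>', '<BOARD_TB2>', '<ANGLE_40>', '<GRADE_V6>']\n",
"route = generate(prompt)\n",
"print(route)\n",
"```\n",
"\n",
"A real sampler would replace the toy function with a forward pass of the trained model followed by a softmax over its logits.\n",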
"\n",
"### Conditioning on board, angle, and grade\n",
"\n",
"The prompt tokens tell the model *what kind of route to generate*:\n",
"- `<BOARD_TB2>`: Generate a route for the Tension Board 2\n",
"- `<ANGLE_40>`: At 40 degrees\n",
"- `<GRADE_V6>`: At V6 difficulty\n",
"\n",
"This is analogous to how ChatGPT uses a system prompt to condition its responses.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6590822",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:14.101804Z",
"iopub.status.busy": "2026-06-07T19:12:14.101439Z",
"iopub.status.idle": "2026-06-07T19:12:16.162395Z",
"shell.execute_reply": "2026-06-07T19:12:16.161684Z"
}
},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import ast\n",
"import json\n",
"import math\n",
"import re\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Iterable\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"ROOT = Path.cwd().resolve()\n",
"if ROOT.name == \"notebooks\":\n",
"    ROOT = ROOT.parent"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f09fdf54",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:16.166017Z",
"iopub.status.busy": "2026-06-07T19:12:16.165651Z",
"iopub.status.idle": "2026-06-07T19:12:21.900885Z",
"shell.execute_reply": "2026-06-07T19:12:21.899985Z"
}
},
"outputs": [],
"source": [
"TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
"df_routes = pd.read_csv(TOKENIZED / \"route_sequences.csv\")\n",
"vocab = json.loads((TOKENIZED / \"token_vocab.json\").read_text(encoding=\"utf-8\"))\n",
"stoi = {str(k): int(v) for k, v in vocab[\"stoi\"].items()}\n",
"itos = {int(k): str(v) for k, v in vocab[\"itos\"].items()}\n",
"\n",
"pad_id = stoi[\"<PAD>\"]\n",
"unk_id = stoi[\"<UNK>\"]\n",
"\n",
"print(f\"Vocabulary size: {len(stoi):,}\")\n",
"print(f\"Total routes: {len(df_routes):,}\")"
]
},
{
"cell_type": "markdown",
"id": "4fcba532",
"metadata": {},
"source": [
"### Causal dataset helper"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40021fc1",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:21.904008Z",
"iopub.status.busy": "2026-06-07T19:12:21.903750Z",
"iopub.status.idle": "2026-06-07T19:12:21.910270Z",
"shell.execute_reply": "2026-06-07T19:12:21.909595Z"
}
},
"outputs": [],
"source": [
"# Pad route-token sequences and create shifted input/target pairs for causal modeling.\n",
"class RouteGPTDataset(Dataset):\n",
"    \"\"\"Dataset for causal next-token route generation.\n",
"\n",
"    The full sequence is padded once, then split into ``input_ids`` and\n",
"    ``target_ids`` shifted by one position for teacher-forced language-model\n",
"    training.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, df, max_len: int, pad_id: int):\n",
"        \"\"\"Store GPT token ID sequences from a tokenized route DataFrame.\"\"\"\n",
"        self.ids = df[\"gpt_ids\"].tolist()\n",
"        self.max_len = int(max_len)\n",
"        self.pad_id = int(pad_id)\n",
"\n",
"    def __len__(self) -> int:\n",
"        \"\"\"Return the number of route examples.\"\"\"\n",
"        return len(self.ids)\n",
"\n",
"    def __getitem__(self, idx: int):\n",
"        \"\"\"Return one padded causal-language-model training example.\"\"\"\n",
"        ids = list(self.ids[idx])[: self.max_len]\n",
"        if len(ids) < self.max_len:\n",
"            ids += [self.pad_id] * (self.max_len - len(ids))\n",
"\n",
"        return {\n",
"            \"input_ids\": torch.tensor(ids[:-1], dtype=torch.long),\n",
"            \"target_ids\": torch.tensor(ids[1:], dtype=torch.long),\n",
"        }"
]
},
{
"cell_type": "markdown",
"id": "fe4b0faf",
"metadata": {},
"source": [
"## Sequence encoding for causal language modeling\n",
"\n",
"### The autoregressive setup\n",
"\n",
"For GPT-style training, each route becomes a sequence where the model learns to predict each token given all previous tokens:\n",
"\n",
"```text\n",
"Input:  <BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start> <TB2_p369_middle>\n",
"Target: <BOARD_TB2> <ANGLE_40> <GRADE_V6> <TB2_p344_start> <TB2_p369_middle> <TB2_p603_finish>\n",
"```\n",
"\n",
"The input is shifted right by one position compared to the target. This is the standard causal language modeling setup.\n",
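"\n",
"As a rough sketch of that shift, the snippet below pads one short sequence and slices it into input and target lists, mirroring what `RouteGPTDataset.__getitem__` does. The token IDs, `pad_id = 0`, and `max_len = 8` are made-up values for illustration only.\n",
"\n",
"```python\n",
"# Hypothetical token IDs for one short route, plus an assumed pad ID of 0.\n",
"pad_id = 0\n",
"max_len = 8\n",
"ids = [1, 10, 20, 30, 344, 369, 603]\n",
"\n",
"# Right-pad the full sequence to max_len.\n",
"padded = ids + [pad_id] * (max_len - len(ids))\n",
"\n",
"input_ids = padded[:-1]    # everything except the last position\n",
"target_ids = padded[1:]    # everything except the first position\n",
"\n",
"# The target at position t is the token that follows input position t.\n",
"for t, (x, y) in enumerate(zip(input_ids, target_ids)):\n",
"    print(t, x, '->', y)\n",
"```\n",
"\n",
"Target positions that hold the pad ID contribute nothing to the loss, because the model's cross-entropy call uses `ignore_index=self.pad_id`.\n",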
"\n",
"### Why include the grade in the training sequence?\n",
"\n",
"For the grade predictor (notebook 02), we excluded the grade because the model needed to predict it. But for the generator, we **include** the grade (`<GRADE_V6>`) in the training data so the model learns the relationship between grade and hold selection.\n",
"\n",
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ad61dbd",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:21.913286Z",
"iopub.status.busy": "2026-06-07T19:12:21.912788Z",
"iopub.status.idle": "2026-06-07T19:12:25.369590Z",
"shell.execute_reply": "2026-06-07T19:12:25.368643Z"
}
},
"outputs": [],
"source": [
"def encode(tokens):\n",
"    \"\"\"Convert token strings to integer IDs.\"\"\"\n",
"    return [stoi.get(token, unk_id) for token in tokens]\n",
"\n",
"# Use the \"with grade\" version for GPT training\n",
"# The model needs to see the grade to learn grade-hold relationships\n",
"df_routes[\"gpt_tokens\"] = df_routes[\"sequence_with_grade\"].fillna(\"\").str.split()\n",
"df_routes[\"gpt_ids\"] = df_routes[\"gpt_tokens\"].apply(encode)\n",
"df_routes[\"seq_len\"] = df_routes[\"gpt_ids\"].apply(len)\n",
"max_len = int(df_routes[\"seq_len\"].max())\n",
"block_size = max_len - 1  # Input length (one less than full sequence)\n",
"\n",
"# Create train/val splits\n",
"train_df = df_routes[df_routes[\"split\"] == \"train\"].reset_index(drop=True)\n",
"val_df = df_routes[df_routes[\"split\"] == \"val\"].reset_index(drop=True)\n",
"\n",
"# Create datasets and data loaders\n",
"# RouteGPTDataset handles the input/target shift for causal modeling\n",
"train_ds = RouteGPTDataset(train_df, max_len=max_len, pad_id=pad_id)\n",
"val_ds = RouteGPTDataset(val_df, max_len=max_len, pad_id=pad_id)\n",
"train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)\n",
"val_loader = DataLoader(val_ds, batch_size=128, shuffle=False)\n",
"\n",
"print(f\"Max sequence length: {max_len}\")\n",
"print(f\"Block size (input length): {block_size}\")\n",
"print(f\"Training samples: {len(train_ds):,}\")\n",
"print(f\"Validation samples: {len(val_ds):,}\")"
]
},
{
"cell_type": "markdown",
"id": "552b9f69",
"metadata": {},
"source": [
"### GPT model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4085a314",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:25.373057Z",
"iopub.status.busy": "2026-06-07T19:12:25.372812Z",
"iopub.status.idle": "2026-06-07T19:12:25.388342Z",
"shell.execute_reply": "2026-06-07T19:12:25.387476Z"
}
},
"outputs": [],
"source": [
"# GPT-style causal transformer for route generation.\n",
"class JointRouteGPT(nn.Module):\n",
"    \"\"\"Tiny GPT-style causal transformer for board-conditioned route generation.\n",
"\n",
"    PyTorch's ``TransformerEncoder`` is used with a causal mask, which makes it\n",
"    behave like a decoder-only language model for short route sequences.\n",
"\n",
"    Why use ``TransformerEncoder`` rather than ``TransformerDecoder``?\n",
"    -------------------------------------------------------------------\n",
"    PyTorch's ``TransformerDecoderLayer`` expects two inputs: a decoder\n",
"    sequence and a separate encoder memory for cross-attention. For\n",
"    unconditional or prompt-conditioned generation there is no encoder,\n",
"    so ``TransformerDecoderLayer`` would always ignore the second input\n",
"    or require a dummy placeholder. Using ``TransformerEncoder`` with a\n",
"    causal mask avoids this mismatch, keeps the module list uniform,\n",
"    and produces identical behaviour for short autoregressive generation.\n",
"\n",
"    The trade-off is that ``TransformerEncoder`` does not natively prevent\n",
"    attention to future positions, so the causal mask must be constructed\n",
"    manually (see ``forward``). For the sequence lengths seen here\n",
"    (at most ~400 tokens) the overhead of the upper-triangular mask is\n",
"    negligible. ``enable_nested_tensor=False`` is set because PyTorch does\n",
"    not use the nested-tensor fast path when ``norm_first=True``; disabling\n",
"    it explicitly avoids the corresponding warning.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        vocab_size: int,\n",
"        block_size: int,\n",
"        n_embd: int = 128,\n",
"        n_head: int = 4,\n",
"        n_layer: int = 4,\n",
"        dropout: float = 0.10,\n",
"        pad_id: int = 0,\n",
"    ):\n",
"        \"\"\"Create the token/position embeddings, causal blocks, and LM head.\"\"\"\n",
"        super().__init__()\n",
"        self.vocab_size = vocab_size\n",
"        self.block_size = block_size\n",
"        self.pad_id = pad_id\n",
"\n",
"        self.token_emb = nn.Embedding(vocab_size, n_embd, padding_idx=pad_id)\n",
"        self.pos_emb = nn.Embedding(block_size, n_embd)\n",
"        self.drop = nn.Dropout(dropout)\n",
"\n",
"        layer = nn.TransformerEncoderLayer(\n",
"            d_model=n_embd,\n",
"            nhead=n_head,\n",
"            dim_feedforward=4 * n_embd,\n",
"            dropout=dropout,\n",
"            activation=\"gelu\",\n",
"            batch_first=True,\n",
"            norm_first=True,\n",
"        )\n",
"        self.blocks = nn.TransformerEncoder(\n",
"            layer,\n",
"            num_layers=n_layer,\n",
"            enable_nested_tensor=False,\n",
"        )\n",
"        self.ln_f = nn.LayerNorm(n_embd)\n",
"        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)\n",
"        self.lm_head.weight = self.token_emb.weight\n",
"\n",
"    def forward(\n",
"        self,\n",
"        idx: torch.Tensor,\n",
"        targets: torch.Tensor | None = None,\n",
"    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n",
"        \"\"\"Return next-token logits and, when targets are supplied, CE loss.\"\"\"\n",
"        _, seq_len = idx.shape\n",
"        if seq_len > self.block_size:\n",
"            idx = idx[:, -self.block_size :]\n",
"            seq_len = idx.shape[1]\n",
"\n",
"        positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)\n",
"        x = self.drop(self.token_emb(idx) + self.pos_emb(positions))\n",
"\n",
"        causal_mask = torch.triu(\n",
"            torch.ones(seq_len, seq_len, device=idx.device, dtype=torch.bool),\n",
"            diagonal=1,\n",
"        )\n",
"        # Padding masks suppress attention to right-padded context tokens while\n",
"        # the causal mask suppresses attention to future positions.\n",
"        key_padding_mask = idx.eq(self.pad_id)\n",
"\n",
"        h = self.blocks(\n",
"            x,\n",
"            mask=causal_mask,\n",
"            src_key_padding_mask=key_padding_mask,\n",
"        )\n",
"        h = self.ln_f(h)\n",
"        logits = self.lm_head(h)\n",
"\n",
"        loss = None\n",
"        if targets is not None:\n",
"            loss = F.cross_entropy(\n",
"                logits.reshape(-1, logits.size(-1)),\n",
"                targets.reshape(-1),\n",
"                ignore_index=self.pad_id,\n",
"            )\n",
"\n",
"        return logits, loss"
]
},
{
"cell_type": "markdown",
"id": "66d98641",
"metadata": {},
"source": [
"## The GPT Model Architecture\n",
"\n",
"### JointRouteGPT\n",
"\n",
"This is a **causal transformer decoder** — the same architecture used in GPT-2, GPT-3, etc., but much smaller:\n",
"\n",
"1. **Token embeddings**: Convert integer token IDs to dense vectors\n",
"2. **Positional embeddings**: Learned position vectors (not sinusoidal)\n",
"3. **Causal self-attention**: Each position can only attend to previous positions (via a causal mask)\n",
"4. **Transformer layers**: Multiple layers of attention + feedforward\n",
"5. **Language modeling head**: Projects hidden states to vocabulary logits\n",
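"\n",
"To make the causal mask in item 3 concrete, here is a small hand-rolled sketch in plain Python. It is not the PyTorch code path used by `JointRouteGPT`, which builds a boolean upper-triangular mask with `torch.triu(..., diagonal=1)`; the helper name `causal_attention_weights` and the raw scores are invented for illustration.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def causal_attention_weights(scores):\n",
"    # Softmax over each row, with future positions (column > row) masked out.\n",
"    n = len(scores)\n",
"    weights = []\n",
"    for i in range(n):\n",
"        # Keep scores for positions 0..i only; later positions get zero weight.\n",
"        visible = [math.exp(scores[i][j]) for j in range(i + 1)]\n",
"        total = sum(visible)\n",
"        weights.append([v / total for v in visible] + [0.0] * (n - i - 1))\n",
"    return weights\n",
"\n",
"# Made-up raw attention scores for a 4-token sequence.\n",
"scores = [[0.0, 1.0, 2.0, 3.0] for _ in range(4)]\n",
"weights = causal_attention_weights(scores)\n",
"for row in weights:\n",
"    print([round(w, 2) for w in row])\n",
"```\n",
"\n",
"Each printed row sums to 1 and is zero to the right of the diagonal, so the first token can only attend to itself.\n",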
"\n",
"### Key hyperparameters\n",
"\n",
"- `n_embd=128`: Embedding dimension (GPT-2 small uses 768)\n",
"- `n_head=4`: Number of attention heads\n",
"- `n_layer=4`: Number of transformer layers (GPT-2 small uses 12)\n",
"- `dropout=0.10`: Dropout probability\n",
"\n",
"This is intentionally small — we're training on a few hundred thousand short sequences, not billions of long documents.\n",
"\n",
"### Weight tying\n",
"\n",
"The output projection layer shares weights with the token embedding layer (`self.lm_head.weight = self.token_emb.weight`). This is a common technique that:\n",
"- Reduces parameter count\n",
"- Acts as a regularizer\n",
"- Is used in GPT-2 and many other language models\n",
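"\n",
"A back-of-the-envelope calculation shows how much weight tying saves. The vocabulary size of 4000 is an assumed round number for illustration; `n_embd = 128` matches the default above.\n",
"\n",
"```python\n",
"# Assumed sizes for illustration only; the real vocabulary size is printed\n",
"# when the vocabulary is loaded.\n",
"vocab_size = 4000\n",
"n_embd = 128\n",
"\n",
"embedding_params = vocab_size * n_embd   # token embedding matrix\n",
"lm_head_params = n_embd * vocab_size     # output projection (no bias)\n",
"\n",
"untied_total = embedding_params + lm_head_params\n",
"tied_total = embedding_params            # one shared matrix serves both roles\n",
"\n",
"print(f'untied: {untied_total:,}  tied: {tied_total:,}  saved: {untied_total - tied_total:,}')\n",
"```\n",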
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3eec6f35",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:25.391551Z",
"iopub.status.busy": "2026-06-07T19:12:25.391044Z",
"iopub.status.idle": "2026-06-07T19:12:27.304182Z",
"shell.execute_reply": "2026-06-07T19:12:27.303257Z"
}
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = JointRouteGPT(\n",
"    vocab_size=len(stoi),\n",
"    block_size=block_size,\n",
"    n_embd=128,\n",
"    n_head=4,\n",
"    n_layer=4,\n",
"    dropout=0.10,\n",
"    pad_id=pad_id,\n",
").to(device)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)\n",
"\n",
"print(f\"Device: {device}\")\n",
"print(f\"Total parameters: {sum(p.numel() for p in model.parameters()):,}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f999cf05",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:27.307681Z",
"iopub.status.busy": "2026-06-07T19:12:27.307034Z",
"iopub.status.idle": "2026-06-07T19:12:27.314453Z",
"shell.execute_reply": "2026-06-07T19:12:27.313491Z"
}
},
"outputs": [],
"source": [
"def train_epoch():\n",
"    \"\"\"Train for one epoch.\"\"\"\n",
"    model.train()\n",
"    losses = []\n",
"    n = 0\n",
"    for batch in train_loader:\n",
"        x = batch[\"input_ids\"].to(device)\n",
"        y = batch[\"target_ids\"].to(device)\n",
"\n",
"        optimizer.zero_grad(set_to_none=True)\n",
"        _, loss = model(x, y)\n",
"        loss.backward()\n",
"\n",
"        # Gradient clipping prevents exploding gradients\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
"        optimizer.step()\n",
"        losses.append(loss.item() * x.size(0))\n",
"        n += x.size(0)\n",
"    return sum(losses) / max(1, n)\n",
"\n",
"@torch.no_grad()\n",
"def eval_loss(loader):\n",
"    \"\"\"Evaluate loss on a data loader.\"\"\"\n",
"    model.eval()\n",
"    losses = []\n",
"    n = 0\n",
"    for batch in loader:\n",
"        x = batch[\"input_ids\"].to(device)\n",
"        y = batch[\"target_ids\"].to(device)\n",
"        _, loss = model(x, y)\n",
"        losses.append(loss.item() * x.size(0))\n",
"        n += x.size(0)\n",
"    return sum(losses) / max(1, n)"
]
},
{
"cell_type": "markdown",
"id": "51fb8b6e",
"metadata": {},
"source": [
"## Training\n",
"\n",
"### What we're optimizing\n",
"\n",
"The model minimizes **cross-entropy loss** — the standard loss function for language modeling. At each position, the model outputs a probability distribution over the entire vocabulary, and the loss measures how surprised it is by the actual next token.\n",
"\n",
"### Perplexity\n",
"\n",
"We also track **perplexity**, which is `exp(loss)`. Perplexity answers the question: \"On average, how many tokens was the model choosing between at each step?\" Lower perplexity = better model.\n",
"\n",
"For reference:\n",
"- A model that always predicts the right token has perplexity = 1\n",
"- A model that picks uniformly from a 1000-token vocab has perplexity = 1000\n",
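"- Strong language models on English benchmarks have reported perplexities of roughly 15-20, although the value depends heavily on the dataset and tokenization\n",
"\n",
"Our vocabulary has more than 4,000 tokens, so a validation perplexity far below that (the uniform-guessing baseline) indicates the model is learning meaningful patterns.\n",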
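"\n",
"The reference points above can be checked numerically. The sketch below computes perplexity from per-token probabilities; the helper name `perplexity` and the example probabilities are made up for illustration and are not taken from the trained model.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def perplexity(token_probs):\n",
"    # Mean negative log-likelihood (cross-entropy in nats), then exponentiate.\n",
"    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)\n",
"    return math.exp(nll)\n",
"\n",
"vocab_size = 1000\n",
"uniform = perplexity([1.0 / vocab_size] * 5)    # uniform guess at every step\n",
"confident = perplexity([0.9, 0.8, 0.95, 0.85])  # made-up probabilities\n",
"\n",
"print(round(uniform, 1), round(confident, 2))\n",
"```\n",
"\n",
"A uniform guess over 1000 tokens gives a perplexity of 1000, while assigning high probability to the correct tokens pushes it towards 1.\n",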
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70b38b02",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:27.317432Z",
"iopub.status.busy": "2026-06-07T19:12:27.317008Z",
"iopub.status.idle": "2026-06-07T23:43:38.199890Z",
"shell.execute_reply": "2026-06-07T23:43:38.198963Z"
}
},
"outputs": [],
"source": [
"history = []\n",
"best_val_loss = float(\"inf\")\n",
"best_state = None\n",
"patience = 10\n",
"stagnant = 0\n",
"\n",
"print(\"Starting GPT training...\\n\")\n",
"\n",
"for epoch in range(1, 21):\n",
"    train_loss = train_epoch()\n",
"    val_loss = eval_loss(val_loader)\n",
"\n",
"    # Track perplexity (exponentiated loss)\n",
"    train_ppl = math.exp(min(train_loss, 20))\n",
"    val_ppl = math.exp(min(val_loss, 20))\n",
"\n",
"    history.append({\n",
"        \"epoch\": epoch,\n",
"        \"train_loss\": train_loss,\n",
"        \"val_loss\": val_loss,\n",
"        \"train_perplexity\": train_ppl,\n",
"        \"val_perplexity\": val_ppl,\n",
"    })\n",
"\n",
"    # Early stopping\n",
"    if val_loss < best_val_loss:\n",
"        best_val_loss = val_loss\n",
"        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}\n",
"        stagnant = 0\n",
"    else:\n",
"        stagnant += 1\n",
"\n",
"    if epoch == 1 or epoch % 5 == 0:\n",
"        print(f\"Epoch {epoch:3d} | \"\n",
"              f\"Train Loss: {train_loss:.4f} | \"\n",
"              f\"Val Loss: {val_loss:.4f} | \"\n",
"              f\"Val PPL: {val_ppl:.1f}\")\n",
"\n",
"    if stagnant >= patience:\n",
"        print(f\"\\nEarly stopping at epoch {epoch}\")\n",
"        break\n",
"\n",
"# Load best model\n",
"if best_state is not None:\n",
"    model.load_state_dict(best_state)\n",
"\n",
"print(f\"\\nBest validation loss: {best_val_loss:.4f}\")\n",
"print(f\"Best validation perplexity: {math.exp(min(best_val_loss, 20)):.1f}\")"
]
},
{
"cell_type": "markdown",
"id": "20096c62",
"metadata": {},
"source": [
"### Board configuration helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d26d6d4",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.203421Z",
"iopub.status.busy": "2026-06-07T23:43:38.203126Z",
"iopub.status.idle": "2026-06-07T23:43:38.217533Z",
"shell.execute_reply": "2026-06-07T23:43:38.216653Z"
}
},
"outputs": [],
"source": [
"# Find the project root and load board configuration JSON files.\n",
"def find_project_root(start: str | Path | None = None) -> Path:\n",
"    \"\"\"Walk upward until the repository root markers are found.\n",
"\n",
"    The project root is the first directory, searching upward, that contains\n",
"    both ``pyproject.toml`` and ``configs``. If no such directory is found,\n",
"    the resolved starting directory is returned so callers still have a\n",
"    deterministic base path.\n",
"    \"\"\"\n",
"    current = Path(start).resolve() if start is not None else Path.cwd().resolve()\n",
"    for candidate in [current, *current.parents]:\n",
"        if (candidate / \"pyproject.toml\").exists() and (candidate / \"configs\").exists():\n",
"            return candidate\n",
"    return current\n",
"\n",
"@dataclass(frozen=True)\n",
"class BoardConfig:\n",
"    \"\"\"Configuration for a single climbing board.\n",
"\n",
"    This dataclass stores all board-specific settings needed for\n",
"    data loading, tokenization, and model training.\n",
"\n",
"    Attributes:\n",
"        board_key: Short identifier (e.g., \"tb2\", \"kilter\")\n",
"        display_name: Human-readable name (e.g., \"Tension Board 2 Mirror\")\n",
"        token_prefix: Namespace for hold tokens (e.g., \"TB2\", \"KILTER\")\n",
"        db_path: Path to the SQLite database\n",
"        layout_id: Which layout in the database to use\n",
"        max_angle: Filter out routes steeper than this (None = no filter)\n",
"        min_fa_date: Filter out routes first ascended before this date\n",
"        placement_y_max: Filter out placements above this Y coordinate\n",
"        include_mirror_placement_id: Whether to include mirror info (TB2 only)\n",
"        role_definitions: Maps semantic role names to numeric IDs\n",
"        boardlib_database_command: Command to download the database\n",
"        boardlib_images_command: Command to download board images\n",
"        notes: Additional notes about the configuration\n",
"    \"\"\"\n",
"    board_key: str\n",
"    display_name: str\n",
"    token_prefix: str\n",
"    db_path: Path\n",
"    layout_id: int\n",
"    max_angle: float | None\n",
"    min_fa_date: str | None\n",
"    placement_y_max: float | None\n",
"    include_mirror_placement_id: bool\n",
"    role_definitions: dict[str, int]\n",
"    boardlib_database_command: str | None = None\n",
"    boardlib_images_command: str | None = None\n",
"    notes: tuple[str, ...] = ()\n",
"\n",
"    @property\n",
"    def role_id_to_name(self) -> dict[int, str]:\n",
"        \"\"\"Reverse mapping from numeric role IDs to semantic role names.\n",
"\n",
"        Example: {5: 'start', 6: 'middle', 7: 'finish', 8: 'foot'} for TB2\n",
"        \"\"\"\n",
"        return {int(role_id): name for name, role_id in self.role_definitions.items()}\n",
"\n",
"    @property\n",
"    def board_token(self) -> str:\n",
"        \"\"\"The special token representing this board.\n",
"\n",
"        Example: \"<BOARD_TB2>\" or \"<BOARD_KILTER>\"\n",
"        \"\"\"\n",
"        return f\"<BOARD_{self.token_prefix}>\"\n",
"\n",
"    def resolve_db_path(self, project_root: Path | None = None) -> Path:\n",
"        \"\"\"Resolve the database path relative to the project root.\n",
"\n",
"        If db_path is absolute, return it as-is.\n",
"        Otherwise, resolve it relative to the project root.\n",
"        \"\"\"\n",
"        project_root = project_root or find_project_root()\n",
"        return self.db_path if self.db_path.is_absolute() else project_root / self.db_path\n",
"\n",
"def load_board_config(board_key: str, config_dir: str | Path | None = None) -> BoardConfig:\n",
"    \"\"\"Load a single board configuration from a JSON file.\n",
"\n",
"    Args:\n",
"        board_key: Board identifier (e.g., \"tb2\", \"kilter\")\n",
"        config_dir: Directory containing config JSON files\n",
"\n",
"    Returns:\n",
"        BoardConfig dataclass with all board settings\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: If the config file doesn't exist\n",
"    \"\"\"\n",
"    project_root = find_project_root()\n",
"    config_dir = Path(config_dir) if config_dir is not None else project_root / \"configs\"\n",
"    path = config_dir / f\"{board_key}.json\"\n",
"    if not path.exists():\n",
"        available = sorted(p.stem for p in config_dir.glob(\"*.json\"))\n",
"        raise FileNotFoundError(\n",
"            f\"Unknown board config '{board_key}'. Available: {available}\"\n",
"        )\n",
"\n",
"    payload = json.loads(path.read_text(encoding=\"utf-8\"))\n",
"    return BoardConfig(\n",
"        board_key=str(payload[\"board_key\"]),\n",
"        display_name=str(payload[\"display_name\"]),\n",
"        token_prefix=str(payload[\"token_prefix\"]),\n",
"        db_path=Path(payload[\"db_path\"]),\n",
"        layout_id=int(payload[\"layout_id\"]),\n",
"        max_angle=None if payload.get(\"max_angle\") is None else float(payload[\"max_angle\"]),\n",
"        min_fa_date=payload.get(\"min_fa_date\"),\n",
"        placement_y_max=None if payload.get(\"placement_y_max\") is None else float(payload[\"placement_y_max\"]),\n",
"        include_mirror_placement_id=bool(payload.get(\"include_mirror_placement_id\", False)),\n",
"        role_definitions={str(k): int(v) for k, v in payload[\"role_definitions\"].items()},\n",
"        boardlib_database_command=payload.get(\"boardlib_database_command\"),\n",
"        boardlib_images_command=payload.get(\"boardlib_images_command\"),\n",
"        notes=tuple(payload.get(\"notes\", [])),\n",
"    )\n",
"\n",
"def load_board_configs(board_keys: list[str] | tuple[str, ...]) -> list[BoardConfig]:\n",
"    \"\"\"Load multiple board configurations.\n",
"\n",
"    Args:\n",
"        board_keys: List of board identifiers\n",
"\n",
"    Returns:\n",
"        List of BoardConfig dataclasses\n",
"    \"\"\"\n",
"    return [load_board_config(board_key) for board_key in board_keys]"
]
},
{
"cell_type": "markdown",
"id": "94d352c6",
"metadata": {},
"source": [
"### Generation helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abdabe8e",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.220501Z",
"iopub.status.busy": "2026-06-07T23:43:38.220197Z",
"iopub.status.idle": "2026-06-07T23:43:38.241455Z",
"shell.execute_reply": "2026-06-07T23:43:38.240589Z"
}
},
"outputs": [],
"source": [
"# Parse generated hold tokens back into structured hold records.\n",
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
"\n",
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
"    \"\"\"Extract hold records from model tokens using the shared hold-token grammar.\"\"\"\n",
"    rows: list[dict[str, object]] = []\n",
"    for token in tokens:\n",
"        match = HOLD_TOKEN_PATTERN.match(str(token))\n",
"        if match is None:\n",
"            continue\n",
"        board_prefix = match.group(1)\n",
"        rows.append(\n",
"            {\n",
"                \"token\": str(token),\n",
"                \"board_token_prefix\": board_prefix,\n",
"                \"board_prefix\": board_prefix,\n",
"                \"placement_id\": int(match.group(2)),\n",
"                \"role\": match.group(3),\n",
"            }\n",
"        )\n",
"    return rows\n",
"\n",
"# Sample routes from the trained GPT model and convert them back to frames strings.\n",
"def top_k_filter(logits: torch.Tensor, k: int | None) -> torch.Tensor:\n",
"    \"\"\"Mask logits outside the top ``k`` choices for each batch row.\"\"\"\n",
"    if k is None or k <= 0 or k >= logits.size(-1):\n",
"        return logits\n",
"    values, _ = torch.topk(logits, k)\n",
"    cutoff = values[:, [-1]]\n",
"    return torch.where(logits < cutoff, torch.full_like(logits, -float(\"inf\")), logits)\n",
"\n",
"@torch.no_grad()\n",
"def sample_ids(\n",
"    model,\n",
"    prompt_ids: list[int],\n",
"    device: torch.device,\n",
"    max_new_tokens: int = 40,\n",
"    temperature: float = 0.9,\n",
"    top_k: int | None = 50,\n",
"    eos_id: int | None = None,\n",
"    forbidden_ids: Iterable[int] | None = None,\n",
") -> list[int]:\n",
"    \"\"\"Autoregressively sample token IDs from a trained route generator.\n",
"\n",
"    The returned list includes the prompt IDs and all sampled IDs up to either\n",
"    ``max_new_tokens`` or the first sampled ``eos_id``.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    sequence = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n",
"    forbidden_ids = set(forbidden_ids or [])\n",
"\n",
"    for _ in range(max_new_tokens):\n",
"        idx_cond = sequence[:, -model.block_size :]\n",
"        logits, _ = model(idx_cond)\n",
"        logits = logits[:, -1, :] / max(temperature, 1e-6)\n",
"\n",
"        # Special tokens like <PAD> and <CLS> are valid vocabulary entries but\n",
"        # should never be emitted in the middle of a generated climb.\n",
"        for token_id in forbidden_ids:\n",
"            logits[:, int(token_id)] = -float(\"inf\")\n",
"\n",
"        logits = top_k_filter(logits, top_k)\n",
"        probs = F.softmax(logits, dim=-1)\n",
"        next_id = torch.multinomial(probs, num_samples=1)\n",
"        sequence = torch.cat([sequence, next_id], dim=1)\n",
"\n",
"        if eos_id is not None and int(next_id.item()) == int(eos_id):\n",
"            break\n",
"\n",
"    return sequence[0].detach().cpu().tolist()\n",
"\n",
|
"def prompt_tokens(board_prefix: str, angle: int, grouped_v: int) -> list[str]:\n",
|
|
" \"\"\"Build the conditioning prefix used before sampling hold tokens.\"\"\"\n",
|
|
" return [\n",
|
|
" \"<BOS>\",\n",
|
|
" f\"<BOARD_{board_prefix}>\",\n",
|
|
" f\"<ANGLE_{int(angle)}>\",\n",
|
|
" f\"<GRADE_V{int(grouped_v)}>\",\n",
|
|
" ]\n",
|
|
"\n",
|
|
"def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
|
" \"\"\"Extract hold records from generated tokens.\"\"\"\n",
|
|
" return tokens_to_hold_records(tokens)\n",
|
|
"\n",
|
    "def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
    "    \"\"\"Summarize basic structural validity for generated token sequences.\"\"\"\n",
    "    records = hold_records(tokens)\n",
    "    placements = [record[\"placement_id\"] for record in records]\n",
    "    roles = [record[\"role\"] for record in records]\n",
    "    prefixes = [record[\"board_prefix\"] for record in records]\n",
    "\n",
    "    one_board_only = len(set(prefixes)) <= 1\n",
    "    matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)\n",
    "    no_duplicates = len(placements) == len(set(placements))\n",
    "    has_start = \"start\" in roles\n",
    "    has_finish = \"finish\" in roles\n",
    "    enough_holds = len(records) >= 3\n",
    "\n",
    "    return {\n",
    "        \"n_hold_tokens\": len(records),\n",
    "        \"n_unique_placements\": len(set(placements)),\n",
    "        \"has_duplicate_placements\": not no_duplicates,\n",
    "        \"one_board_only\": one_board_only,\n",
    "        \"matches_requested_board\": matches_requested_board,\n",
    "        \"has_start\": has_start,\n",
    "        \"has_middle\": \"middle\" in roles,\n",
    "        \"has_finish\": has_finish,\n",
    "        \"n_start\": roles.count(\"start\"),\n",
    "        \"n_middle\": roles.count(\"middle\"),\n",
    "        \"n_foot\": roles.count(\"foot\"),\n",
    "        \"n_finish\": roles.count(\"finish\"),\n",
    "        \"basic_valid\": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),\n",
    "    }\n",
    "\n",
    "def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:\n",
    "    \"\"\"Convert generated hold tokens back into a frames string.\n",
    "\n",
    "    Duplicate placements and unknown roles are skipped, matching the forgiving\n",
    "    cleanup used by the demo scripts and webapp. When ``board_prefix`` is given,\n",
    "    holds belonging to any other board are skipped as well.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    seen = set()\n",
    "    for record in hold_records(tokens):\n",
    "        if board_prefix is not None and str(record[\"board_prefix\"]) != board_prefix:\n",
    "            continue\n",
    "        placement_id = int(record[\"placement_id\"])\n",
    "        role = str(record[\"role\"])\n",
    "        if placement_id in seen or role not in role_name_to_id:\n",
    "            continue\n",
    "        seen.add(placement_id)\n",
    "        pieces.append(f\"p{placement_id}r{int(role_name_to_id[role])}\")\n",
    "    return \"\".join(pieces)\n",
    "\n",
    "def generate_one(\n",
    "    model,\n",
    "    stoi: dict[str, int],\n",
    "    itos: dict[int, str],\n",
    "    device: torch.device,\n",
    "    board_prefix: str,\n",
    "    angle: int,\n",
    "    grouped_v: int,\n",
    "    role_name_to_id: dict[str, int],\n",
    "    temperature: float = 0.9,\n",
    "    top_k: int | None = 50,\n",
    "    max_new_tokens: int = 40,\n",
    ") -> dict[str, object]:\n",
    "    \"\"\"Generate one route and return tokens, frames, request metadata, validity.\"\"\"\n",
    "    unk_id = stoi[\"<UNK>\"]\n",
    "    eos_id = stoi[\"<EOS>\"]\n",
    "    forbidden_ids = [\n",
    "        stoi[\"<PAD>\"],\n",
    "        stoi[\"<UNK>\"],\n",
    "        stoi[\"<BOS>\"],\n",
    "        stoi[\"<CLS>\"],\n",
    "        stoi[\"<MASK>\"],\n",
    "    ]\n",
    "\n",
    "    prompt = prompt_tokens(board_prefix, angle, grouped_v)\n",
    "    prompt_ids = [stoi.get(token, unk_id) for token in prompt]\n",
    "    token_ids = sample_ids(\n",
    "        model=model,\n",
    "        prompt_ids=prompt_ids,\n",
    "        device=device,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        temperature=temperature,\n",
    "        top_k=top_k,\n",
    "        eos_id=eos_id,\n",
    "        forbidden_ids=forbidden_ids,\n",
    "    )\n",
    "    tokens = [itos.get(int(idx), \"<UNK>\") for idx in token_ids]\n",
    "    validity = validity_summary(tokens, requested_board_prefix=board_prefix)\n",
    "\n",
    "    return {\n",
    "        \"requested_board_prefix\": board_prefix,\n",
    "        \"requested_angle\": int(angle),\n",
    "        \"requested_grouped_v\": int(grouped_v),\n",
    "        \"temperature\": float(temperature),\n",
    "        \"top_k\": None if top_k is None else int(top_k),\n",
    "        \"tokens\": tokens,\n",
    "        \"sequence\": \" \".join(tokens),\n",
    "        \"frames\": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),\n",
    "        **validity,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69926180",
   "metadata": {},
   "source": [
    "## Generating Routes\n",
    "\n",
    "### The generation process\n",
    "\n",
    "To generate a route, we:\n",
    "\n",
    "1. **Create a prompt**: `<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>`\n",
    "2. **Feed it to the model**: Get the logits over the vocabulary for the next token\n",
    "3. **Sample a token**: Scale the logits by the temperature, mask special tokens such as `<PAD>` and `<CLS>`, apply top-k filtering, then sample from the resulting probability distribution\n",
    "4. **Append and repeat**: Add the sampled token to the sequence and repeat until `<EOS>` is generated or `max_new_tokens` is reached\n",
    "\n",
    "### Temperature and top-k sampling\n",
    "\n",
    "- **Temperature** (default 0.9): Divides the logits before the softmax. Values below 1 sharpen the distribution (more deterministic); values above 1 flatten it (more random).\n",
    "- **Top-k** (default 50): Only the k most likely tokens remain candidates for sampling, which prevents the model from emitting very unlikely tokens.\n",
    "\n",
    "These are the same techniques used in language models like GPT-3 to control output diversity.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "029eb911",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-07T23:43:38.244254Z",
     "iopub.status.busy": "2026-06-07T23:43:38.244037Z",
     "iopub.status.idle": "2026-06-07T23:43:38.680983Z",
     "shell.execute_reply": "2026-06-07T23:43:38.679992Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate sample routes for both boards\n",
    "configs = load_board_configs([\"tb2\", \"kilter\"])\n",
    "configs_by_key = {config.board_key: config for config in configs}\n",
    "\n",
    "samples = []\n",
    "for board_key, config in configs_by_key.items():\n",
    "    for grouped_v in [3, 5, 7]:  # V3, V5, V7\n",
    "        sample = generate_one(\n",
    "            model=model,\n",
    "            stoi=stoi,\n",
    "            itos=itos,\n",
    "            device=device,\n",
    "            board_prefix=config.token_prefix,\n",
    "            angle=40,\n",
    "            grouped_v=grouped_v,\n",
    "            role_name_to_id=config.role_definitions,\n",
    "            temperature=0.9,\n",
    "            top_k=50,\n",
    "            max_new_tokens=40,\n",
    "        )\n",
    "        samples.append({\"board_key\": board_key, **sample})\n",
    "\n",
    "samples_df = pd.DataFrame(samples)\n",
    "print(\"Generated route samples:\")\n",
    "print(samples_df[[\"board_key\", \"requested_grouped_v\", \"basic_valid\", \"sequence\", \"frames\"]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "generate_more",
   "metadata": {},
   "source": [
    "## Generate More Routes for Evaluation\n",
    "\n",
    "Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "generate_bulk",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-07T23:43:38.684476Z",
     "iopub.status.busy": "2026-06-07T23:43:38.683935Z",
     "iopub.status.idle": "2026-06-07T23:43:53.779260Z",
     "shell.execute_reply": "2026-06-07T23:43:53.778391Z"
    }
   },
   "outputs": [],
   "source": [
    "# Generate routes across multiple angles and grades for evaluation\n",
    "all_samples = []\n",
    "\n",
    "for board_key, config in configs_by_key.items():\n",
    "    # Use the 5 most common angles and the 8 most common grades for this board\n",
    "    board_df = df_routes[df_routes[\"board_key\"] == board_key]\n",
    "    common_angles = sorted(board_df[\"angle\"].astype(int).value_counts().head(5).index.tolist())\n",
    "    common_grades = sorted(board_df[\"grouped_v\"].astype(int).value_counts().head(8).index.tolist())\n",
    "\n",
    "    print(f\"\\nGenerating for {config.display_name}:\")\n",
    "    print(f\"  Angles: {common_angles}\")\n",
    "    print(f\"  Grades: V{min(common_grades)}-V{max(common_grades)}\")\n",
    "\n",
    "    for angle in common_angles:\n",
    "        for grade in common_grades:\n",
    "            for _ in range(5):  # 5 samples per condition\n",
    "                sample = generate_one(\n",
    "                    model=model,\n",
    "                    stoi=stoi,\n",
    "                    itos=itos,\n",
    "                    device=device,\n",
    "                    board_prefix=config.token_prefix,\n",
    "                    angle=int(angle),\n",
    "                    grouped_v=int(grade),\n",
    "                    role_name_to_id=config.role_definitions,\n",
    "                    temperature=0.9,\n",
    "                    top_k=50,\n",
    "                    max_new_tokens=40,\n",
    "                )\n",
    "                all_samples.append({\"board_key\": board_key, **sample})\n",
    "\n",
    "all_samples_df = pd.DataFrame(all_samples)\n",
    "print(f\"\\nTotal generated routes: {len(all_samples_df):,}\")\n",
    "print(\"\\nBasic validity by board:\")\n",
    "print(all_samples_df.groupby(\"board_key\")[\"basic_valid\"].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "save_artifacts",
   "metadata": {},
   "source": [
    "## Save Model and Generated Routes\n",
    "\n",
    "We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation).\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "save_outputs",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2026-06-07T23:43:53.782874Z",
     "iopub.status.busy": "2026-06-07T23:43:53.782303Z",
     "iopub.status.idle": "2026-06-07T23:43:53.831685Z",
     "shell.execute_reply": "2026-06-07T23:43:53.830785Z"
    }
   },
   "outputs": [],
   "source": [
    "# Save model checkpoint\n",
    "MODEL_DIR = ROOT / \"models\"\n",
    "MODEL_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "checkpoint = {\n",
    "    \"model_state_dict\": model.state_dict(),\n",
    "    # Architecture values are hard-coded here; keep them in sync with the model built above.\n",
    "    \"config\": {\n",
    "        \"vocab_size\": len(stoi),\n",
    "        \"block_size\": block_size,\n",
    "        \"n_embd\": 128,\n",
    "        \"n_head\": 4,\n",
    "        \"n_layer\": 4,\n",
    "        \"dropout\": 0.10,\n",
    "        \"pad_id\": pad_id,\n",
    "    },\n",
    "    \"stoi\": stoi,\n",
    "    \"itos\": {str(k): v for k, v in itos.items()},\n",
    "    \"best_val_loss\": best_val_loss,\n",
    "}\n",
    "model_path = MODEL_DIR / \"joint_route_gpt_generator.pth\"\n",
    "torch.save(checkpoint, model_path)\n",
    "print(f\"Saved model checkpoint to: {model_path}\")\n",
    "\n",
    "# Output directory for generation artifacts\n",
    "GEN_DIR = ROOT / \"data\" / \"processed\" / \"generation\"\n",
    "GEN_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Save training history\n",
    "pd.DataFrame(history).to_csv(GEN_DIR / \"training_history.csv\", index=False)\n",
    "print(f\"Saved training history to: {GEN_DIR / 'training_history.csv'}\")\n",
    "\n",
    "# Save generated routes (this is what notebook 04 needs)\n",
    "all_samples_df.to_csv(GEN_DIR / \"generated_routes.csv\", index=False)\n",
    "print(f\"Saved {len(all_samples_df)} generated routes to: {GEN_DIR / 'generated_routes.csv'}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}