Update notebook results and README stats
This commit is contained in:
@@ -43,40 +43,58 @@
|
||||
"- `<ANGLE_40>`: At 40 degrees\n",
|
||||
"- `<GRADE_V6>`: At V6 difficulty\n",
|
||||
"\n",
|
||||
"This is analogous to how ChatGPT uses a system prompt to condition its responses."
|
||||
"This is analogous to how ChatGPT uses a system prompt to condition its responses.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b6590822",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:14.101804Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:14.101439Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:16.162395Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:16.161684Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"import sys\n",
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import ast\n",
|
||||
"import json\n",
|
||||
"import math\n",
|
||||
"import re\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Iterable\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.utils.data import DataLoader, Dataset\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if ROOT.name == \"notebooks\":\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"sys.path.insert(0, str(ROOT / \"src\"))\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.config import load_board_configs\n",
|
||||
"from climbingboardgpt.datasets import RouteGPTDataset\n",
|
||||
"from climbingboardgpt.generation import generate_one\n",
|
||||
"from climbingboardgpt.models import JointRouteGPT"
|
||||
" ROOT = ROOT.parent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f09fdf54",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:16.166017Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:16.165651Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:21.900885Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:21.899985Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
|
||||
@@ -92,6 +110,59 @@
|
||||
"print(f\"Total routes: {len(df_routes):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4fcba532",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Causal dataset helper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "40021fc1",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:21.904008Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:21.903750Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:21.910270Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:21.909595Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Pad route-token sequences and create shifted input/target pairs for causal modeling.\n",
|
||||
"class RouteGPTDataset(Dataset):\n",
|
||||
" \"\"\"Dataset for causal next-token route generation.\n",
|
||||
"\n",
|
||||
" The full sequence is padded once, then split into ``input_ids`` and\n",
|
||||
" ``target_ids`` shifted by one position for teacher-forced language-model\n",
|
||||
" training.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, df, max_len: int, pad_id: int):\n",
|
||||
" \"\"\"Store GPT token ID sequences from a tokenized route DataFrame.\"\"\"\n",
|
||||
" self.ids = df[\"gpt_ids\"].tolist()\n",
|
||||
" self.max_len = int(max_len)\n",
|
||||
" self.pad_id = int(pad_id)\n",
|
||||
"\n",
|
||||
" def __len__(self) -> int:\n",
|
||||
" \"\"\"Return the number of route examples.\"\"\"\n",
|
||||
" return len(self.ids)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx: int):\n",
|
||||
" \"\"\"Return one padded causal-language-model training example.\"\"\"\n",
|
||||
" ids = list(self.ids[idx])[: self.max_len]\n",
|
||||
" if len(ids) < self.max_len:\n",
|
||||
" ids += [self.pad_id] * (self.max_len - len(ids))\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"input_ids\": torch.tensor(ids[:-1], dtype=torch.long),\n",
|
||||
" \"target_ids\": torch.tensor(ids[1:], dtype=torch.long),\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fe4b0faf",
|
||||
@@ -114,14 +185,22 @@
|
||||
"\n",
|
||||
"For the grade predictor (notebook 02), we excluded the grade because the model needed to predict it. But for the generator, we **include** the grade (`<GRADE_V6>`) in the training data so the model learns the relationship between grade and hold selection.\n",
|
||||
"\n",
|
||||
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade."
|
||||
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ad61dbd",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:21.913286Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:21.912788Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:25.369590Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:25.368643Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def encode(tokens):\n",
|
||||
@@ -153,6 +232,132 @@
|
||||
"print(f\"Validation samples: {len(val_ds):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "552b9f69",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### GPT model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4085a314",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:25.373057Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:25.372812Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:25.388342Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:25.387476Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GPT-style causal transformer for route generation.\n",
|
||||
"class JointRouteGPT(nn.Module):\n",
|
||||
" \"\"\"Tiny GPT-style causal transformer for board-conditioned route generation.\n",
|
||||
"\n",
|
||||
" PyTorch's ``TransformerEncoder`` is used with a causal mask, which makes it\n",
|
||||
" behave like a decoder-only language model for short route sequences.\n",
|
||||
"\n",
|
||||
" Why use ``TransformerEncoder`` rather than ``TransformerDecoder``?\n",
|
||||
" -------------------------------------------------------------------\n",
|
||||
" PyTorch's ``TransformerDecoderLayer`` expects two inputs: a decoder\n",
|
||||
" sequence and a separate encoder memory for cross-attention. For\n",
|
||||
" unconditional or prompt-conditioned generation there is no encoder,\n",
|
||||
" so ``TransformerDecoderLayer`` would need a dummy ``memory`` tensor for\n",
|
||||
" its cross-attention sublayer. Using ``TransformerEncoder`` with a\n",
|
||||
" causal mask avoids this mismatch, keeps the module list uniform,\n",
|
||||
" and produces identical behaviour for short autoregressive generation.\n",
|
||||
"\n",
|
||||
" The trade-off is that ``TransformerEncoder`` does not natively prevent\n",
|
||||
" attention to future positions — the causal mask must be constructed\n",
|
||||
" manually (see ``forward``). For the sequence lengths seen here\n",
|
||||
" (at most ~400 tokens) the overhead of the upper-triangular mask is\n",
|
||||
" negligible, and ``enable_nested_tensor=False`` is set because the\n",
|
||||
" nested-tensor fast path is unavailable when ``norm_first=True``.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" vocab_size: int,\n",
|
||||
" block_size: int,\n",
|
||||
" n_embd: int = 128,\n",
|
||||
" n_head: int = 4,\n",
|
||||
" n_layer: int = 4,\n",
|
||||
" dropout: float = 0.10,\n",
|
||||
" pad_id: int = 0,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Create the token/position embeddings, causal blocks, and LM head.\"\"\"\n",
|
||||
" super().__init__()\n",
|
||||
" self.vocab_size = vocab_size\n",
|
||||
" self.block_size = block_size\n",
|
||||
" self.pad_id = pad_id\n",
|
||||
"\n",
|
||||
" self.token_emb = nn.Embedding(vocab_size, n_embd, padding_idx=pad_id)\n",
|
||||
" self.pos_emb = nn.Embedding(block_size, n_embd)\n",
|
||||
" self.drop = nn.Dropout(dropout)\n",
|
||||
"\n",
|
||||
" layer = nn.TransformerEncoderLayer(\n",
|
||||
" d_model=n_embd,\n",
|
||||
" nhead=n_head,\n",
|
||||
" dim_feedforward=4 * n_embd,\n",
|
||||
" dropout=dropout,\n",
|
||||
" activation=\"gelu\",\n",
|
||||
" batch_first=True,\n",
|
||||
" norm_first=True,\n",
|
||||
" )\n",
|
||||
" self.blocks = nn.TransformerEncoder(\n",
|
||||
" layer,\n",
|
||||
" num_layers=n_layer,\n",
|
||||
" enable_nested_tensor=False,\n",
|
||||
" )\n",
|
||||
" self.ln_f = nn.LayerNorm(n_embd)\n",
|
||||
" self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)\n",
|
||||
" self.lm_head.weight = self.token_emb.weight\n",
|
||||
"\n",
|
||||
" def forward(\n",
|
||||
" self,\n",
|
||||
" idx: torch.Tensor,\n",
|
||||
" targets: torch.Tensor | None = None,\n",
|
||||
" ) -> tuple[torch.Tensor, torch.Tensor | None]:\n",
|
||||
" \"\"\"Return next-token logits and, when targets are supplied, CE loss.\"\"\"\n",
|
||||
" _, seq_len = idx.shape\n",
|
||||
" if seq_len > self.block_size:\n",
|
||||
" idx = idx[:, -self.block_size :]\n",
|
||||
" seq_len = idx.shape[1]\n",
|
||||
"\n",
|
||||
" positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)\n",
|
||||
" x = self.drop(self.token_emb(idx) + self.pos_emb(positions))\n",
|
||||
"\n",
|
||||
" causal_mask = torch.triu(\n",
|
||||
" torch.ones(seq_len, seq_len, device=idx.device, dtype=torch.bool),\n",
|
||||
" diagonal=1,\n",
|
||||
" )\n",
|
||||
" # Padding masks suppress attention to right-padded context tokens while\n",
|
||||
" # the causal mask suppresses attention to future positions.\n",
|
||||
" key_padding_mask = idx.eq(self.pad_id)\n",
|
||||
"\n",
|
||||
" h = self.blocks(\n",
|
||||
" x,\n",
|
||||
" mask=causal_mask,\n",
|
||||
" src_key_padding_mask=key_padding_mask,\n",
|
||||
" )\n",
|
||||
" h = self.ln_f(h)\n",
|
||||
" logits = self.lm_head(h)\n",
|
||||
"\n",
|
||||
" loss = None\n",
|
||||
" if targets is not None:\n",
|
||||
" loss = F.cross_entropy(\n",
|
||||
" logits.reshape(-1, logits.size(-1)),\n",
|
||||
" targets.reshape(-1),\n",
|
||||
" ignore_index=self.pad_id,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return logits, loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "66d98641",
|
||||
@@ -177,21 +382,29 @@
|
||||
"- `n_layer=4`: Number of transformer layers (GPT-2 small uses 12)\n",
|
||||
"- `dropout=0.10`: Dropout probability\n",
|
||||
"\n",
|
||||
"This is intentionally small — we're training on ~40K short sequences, not billions of long documents.\n",
|
||||
"This is intentionally small — we're training on a few hundred thousand short sequences, not billions of long documents.\n",
|
||||
"\n",
|
||||
"### Weight tying\n",
|
||||
"\n",
|
||||
"The output projection layer shares weights with the token embedding layer (`self.lm_head.weight = self.token_emb.weight`). This is a common technique that:\n",
|
||||
"- Reduces parameter count\n",
|
||||
"- Acts as a regularizer\n",
|
||||
"- Is used in GPT-2 and many other language models"
|
||||
"- Is used in GPT-2 and many other language models\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3eec6f35",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:25.391551Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:25.391044Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:27.304182Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:27.303257Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
@@ -216,7 +429,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f999cf05",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:27.307681Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:27.307034Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:27.314453Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:27.313491Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_epoch():\n",
|
||||
@@ -275,14 +495,22 @@
|
||||
"- A model that picks uniformly from a 1000-token vocab has perplexity = 1000\n",
|
||||
"- Good language models on English text achieve perplexity ~15-20\n",
|
||||
"\n",
|
||||
"Our vocabulary is ~4000+ tokens, so a perplexity significantly below that indicates the model is learning meaningful patterns."
|
||||
"Our vocabulary has more than 4,000 tokens, so a perplexity far below that figure indicates the model is learning meaningful patterns.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "70b38b02",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:27.317432Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:27.317008Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.199890Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.198963Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history = []\n",
|
||||
@@ -335,6 +563,361 @@
|
||||
"print(f\"Best validation perplexity: {math.exp(min(best_val_loss, 20)):.1f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "20096c62",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Board configuration helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d26d6d4",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.203421Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.203126Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.217533Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.216653Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Find the project root and load board configuration JSON files.\n",
|
||||
"def find_project_root(start: str | Path | None = None) -> Path:\n",
|
||||
" \"\"\"Walk upward until the repository root markers are found.\n",
|
||||
"\n",
|
||||
" The project root is identified by both ``pyproject.toml`` and ``configs``.\n",
|
||||
" If no directory on the path contains both markers, the resolved starting\n",
|
||||
" directory is returned so callers still have a deterministic base path.\n",
|
||||
" \"\"\"\n",
|
||||
" current = Path(start).resolve() if start is not None else Path.cwd().resolve()\n",
|
||||
" for candidate in [current, *current.parents]:\n",
|
||||
" if (candidate / \"pyproject.toml\").exists() and (candidate / \"configs\").exists():\n",
|
||||
" return candidate\n",
|
||||
" return current\n",
|
||||
"\n",
|
||||
"@dataclass(frozen=True)\n",
|
||||
"class BoardConfig:\n",
|
||||
" \"\"\"Configuration for a single climbing board.\n",
|
||||
" \n",
|
||||
" This dataclass stores all board-specific settings needed for\n",
|
||||
" data loading, tokenization, and model training.\n",
|
||||
" \n",
|
||||
" Attributes:\n",
|
||||
" board_key: Short identifier (e.g., \"tb2\", \"kilter\")\n",
|
||||
" display_name: Human-readable name (e.g., \"Tension Board 2 Mirror\")\n",
|
||||
" token_prefix: Namespace for hold tokens (e.g., \"TB2\", \"KILTER\")\n",
|
||||
" db_path: Path to the SQLite database\n",
|
||||
" layout_id: Which layout in the database to use\n",
|
||||
" max_angle: Filter out routes steeper than this (None = no filter)\n",
|
||||
" min_fa_date: Filter out routes first ascended before this date\n",
|
||||
" placement_y_max: Filter out placements above this Y coordinate\n",
|
||||
" include_mirror_placement_id: Whether to include mirror info (TB2 only)\n",
|
||||
" role_definitions: Maps semantic role names to numeric IDs\n",
|
||||
" boardlib_database_command: Command to download the database\n",
|
||||
" boardlib_images_command: Command to download board images\n",
|
||||
" notes: Additional notes about the configuration\n",
|
||||
" \"\"\"\n",
|
||||
" board_key: str\n",
|
||||
" display_name: str\n",
|
||||
" token_prefix: str\n",
|
||||
" db_path: Path\n",
|
||||
" layout_id: int\n",
|
||||
" max_angle: float | None\n",
|
||||
" min_fa_date: str | None\n",
|
||||
" placement_y_max: float | None\n",
|
||||
" include_mirror_placement_id: bool\n",
|
||||
" role_definitions: dict[str, int]\n",
|
||||
" boardlib_database_command: str | None = None\n",
|
||||
" boardlib_images_command: str | None = None\n",
|
||||
" notes: tuple[str, ...] = ()\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def role_id_to_name(self) -> dict[int, str]:\n",
|
||||
" \"\"\"Reverse mapping from numeric role IDs to semantic role names.\n",
|
||||
" \n",
|
||||
" Example: {5: 'start', 6: 'middle', 7: 'finish', 8: 'foot'} for TB2\n",
|
||||
" \"\"\"\n",
|
||||
" return {int(role_id): name for name, role_id in self.role_definitions.items()}\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def board_token(self) -> str:\n",
|
||||
" \"\"\"The special token representing this board.\n",
|
||||
" \n",
|
||||
" Example: \"<BOARD_TB2>\" or \"<BOARD_KILTER>\"\n",
|
||||
" \"\"\"\n",
|
||||
" return f\"<BOARD_{self.token_prefix}>\"\n",
|
||||
"\n",
|
||||
" def resolve_db_path(self, project_root: Path | None = None) -> Path:\n",
|
||||
" \"\"\"Resolve the database path relative to the project root.\n",
|
||||
" \n",
|
||||
" If db_path is absolute, return it as-is.\n",
|
||||
" Otherwise, resolve it relative to the project root.\n",
|
||||
" \"\"\"\n",
|
||||
" project_root = project_root or find_project_root()\n",
|
||||
" return self.db_path if self.db_path.is_absolute() else project_root / self.db_path\n",
|
||||
"\n",
|
||||
"def load_board_config(board_key: str, config_dir: str | Path | None = None) -> BoardConfig:\n",
|
||||
" \"\"\"Load a single board configuration from a JSON file.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" board_key: Board identifier (e.g., \"tb2\", \"kilter\")\n",
|
||||
" config_dir: Directory containing config JSON files\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" BoardConfig dataclass with all board settings\n",
|
||||
" \n",
|
||||
" Raises:\n",
|
||||
" FileNotFoundError: If the config file doesn't exist\n",
|
||||
" \"\"\"\n",
|
||||
" project_root = find_project_root()\n",
|
||||
" config_dir = Path(config_dir) if config_dir is not None else project_root / \"configs\"\n",
|
||||
" path = config_dir / f\"{board_key}.json\"\n",
|
||||
" if not path.exists():\n",
|
||||
" available = sorted(p.stem for p in config_dir.glob(\"*.json\"))\n",
|
||||
" raise FileNotFoundError(\n",
|
||||
" f\"Unknown board config '{board_key}'. Available: {available}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" payload = json.loads(path.read_text(encoding=\"utf-8\"))\n",
|
||||
" return BoardConfig(\n",
|
||||
" board_key=str(payload[\"board_key\"]),\n",
|
||||
" display_name=str(payload[\"display_name\"]),\n",
|
||||
" token_prefix=str(payload[\"token_prefix\"]),\n",
|
||||
" db_path=Path(payload[\"db_path\"]),\n",
|
||||
" layout_id=int(payload[\"layout_id\"]),\n",
|
||||
" max_angle=None if payload.get(\"max_angle\") is None else float(payload[\"max_angle\"]),\n",
|
||||
" min_fa_date=payload.get(\"min_fa_date\"),\n",
|
||||
" placement_y_max=None if payload.get(\"placement_y_max\") is None else float(payload[\"placement_y_max\"]),\n",
|
||||
" include_mirror_placement_id=bool(payload.get(\"include_mirror_placement_id\", False)),\n",
|
||||
" role_definitions={str(k): int(v) for k, v in payload[\"role_definitions\"].items()},\n",
|
||||
" boardlib_database_command=payload.get(\"boardlib_database_command\"),\n",
|
||||
" boardlib_images_command=payload.get(\"boardlib_images_command\"),\n",
|
||||
" notes=tuple(payload.get(\"notes\", [])),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"def load_board_configs(board_keys: list[str] | tuple[str, ...]) -> list[BoardConfig]:\n",
|
||||
" \"\"\"Load multiple board configurations.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" board_keys: List of board identifiers\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" List of BoardConfig dataclasses\n",
|
||||
" \"\"\"\n",
|
||||
" return [load_board_config(board_key) for board_key in board_keys]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94d352c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generation helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "abdabe8e",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.220501Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.220197Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.241455Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.240589Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse generated hold tokens back into structured hold records.\n",
|
||||
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
|
||||
"\n",
|
||||
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from model tokens using the shared hold-token grammar.\"\"\"\n",
|
||||
" rows: list[dict[str, object]] = []\n",
|
||||
" for token in tokens:\n",
|
||||
" match = HOLD_TOKEN_PATTERN.match(str(token))\n",
|
||||
" if match is None:\n",
|
||||
" continue\n",
|
||||
" board_prefix = match.group(1)\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"token\": str(token),\n",
|
||||
" \"board_token_prefix\": board_prefix,\n",
|
||||
" \"board_prefix\": board_prefix,\n",
|
||||
" \"placement_id\": int(match.group(2)),\n",
|
||||
" \"role\": match.group(3),\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" return rows\n",
|
||||
"\n",
|
||||
"# Sample routes from the trained GPT model and convert them back to frames strings.\n",
|
||||
"def top_k_filter(logits: torch.Tensor, k: int | None) -> torch.Tensor:\n",
|
||||
" \"\"\"Mask logits outside the top ``k`` choices for each batch row.\"\"\"\n",
|
||||
" if k is None or k <= 0 or k >= logits.size(-1):\n",
|
||||
" return logits\n",
|
||||
" values, _ = torch.topk(logits, k)\n",
|
||||
" cutoff = values[:, [-1]]\n",
|
||||
" return torch.where(logits < cutoff, torch.full_like(logits, -float(\"inf\")), logits)\n",
|
||||
"\n",
|
||||
"@torch.no_grad()\n",
|
||||
"def sample_ids(\n",
|
||||
" model,\n",
|
||||
" prompt_ids: list[int],\n",
|
||||
" device: torch.device,\n",
|
||||
" max_new_tokens: int = 40,\n",
|
||||
" temperature: float = 0.9,\n",
|
||||
" top_k: int | None = 50,\n",
|
||||
" eos_id: int | None = None,\n",
|
||||
" forbidden_ids: Iterable[int] | None = None,\n",
|
||||
") -> list[int]:\n",
|
||||
" \"\"\"Autoregressively sample token IDs from a trained route generator.\n",
|
||||
"\n",
|
||||
" The returned list includes the prompt IDs and all sampled IDs up to either\n",
|
||||
" ``max_new_tokens`` or the first sampled ``eos_id``.\n",
|
||||
" \"\"\"\n",
|
||||
" model.eval()\n",
|
||||
" sequence = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n",
|
||||
" forbidden_ids = set(forbidden_ids or [])\n",
|
||||
"\n",
|
||||
" for _ in range(max_new_tokens):\n",
|
||||
" idx_cond = sequence[:, -model.block_size :]\n",
|
||||
" logits, _ = model(idx_cond)\n",
|
||||
" logits = logits[:, -1, :] / max(temperature, 1e-6)\n",
|
||||
"\n",
|
||||
" # Special tokens like <PAD> and <CLS> are valid vocabulary entries but\n",
|
||||
" # should never be emitted in the middle of a generated climb.\n",
|
||||
" for token_id in forbidden_ids:\n",
|
||||
" logits[:, int(token_id)] = -float(\"inf\")\n",
|
||||
"\n",
|
||||
" logits = top_k_filter(logits, top_k)\n",
|
||||
" probs = F.softmax(logits, dim=-1)\n",
|
||||
" next_id = torch.multinomial(probs, num_samples=1)\n",
|
||||
" sequence = torch.cat([sequence, next_id], dim=1)\n",
|
||||
"\n",
|
||||
" if eos_id is not None and int(next_id.item()) == int(eos_id):\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" return sequence[0].detach().cpu().tolist()\n",
|
||||
"\n",
|
||||
"def prompt_tokens(board_prefix: str, angle: int, grouped_v: int) -> list[str]:\n",
|
||||
" \"\"\"Build the conditioning prefix used before sampling hold tokens.\"\"\"\n",
|
||||
" return [\n",
|
||||
" \"<BOS>\",\n",
|
||||
" f\"<BOARD_{board_prefix}>\",\n",
|
||||
" f\"<ANGLE_{int(angle)}>\",\n",
|
||||
" f\"<GRADE_V{int(grouped_v)}>\",\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from generated tokens.\"\"\"\n",
|
||||
" return tokens_to_hold_records(tokens)\n",
|
||||
"\n",
|
||||
"def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
|
||||
" \"\"\"Summarize basic structural validity for generated token sequences.\"\"\"\n",
|
||||
" records = hold_records(tokens)\n",
|
||||
" placements = [record[\"placement_id\"] for record in records]\n",
|
||||
" roles = [record[\"role\"] for record in records]\n",
|
||||
" prefixes = [record[\"board_prefix\"] for record in records]\n",
|
||||
"\n",
|
||||
" one_board_only = len(set(prefixes)) <= 1\n",
|
||||
" matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)\n",
|
||||
" no_duplicates = len(placements) == len(set(placements))\n",
|
||||
" has_start = \"start\" in roles\n",
|
||||
" has_finish = \"finish\" in roles\n",
|
||||
" enough_holds = len(records) >= 3\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"n_hold_tokens\": len(records),\n",
|
||||
" \"n_unique_placements\": len(set(placements)),\n",
|
||||
" \"has_duplicate_placements\": not no_duplicates,\n",
|
||||
" \"one_board_only\": one_board_only,\n",
|
||||
" \"matches_requested_board\": matches_requested_board,\n",
|
||||
" \"has_start\": has_start,\n",
|
||||
" \"has_middle\": \"middle\" in roles,\n",
|
||||
" \"has_finish\": has_finish,\n",
|
||||
" \"n_start\": roles.count(\"start\"),\n",
|
||||
" \"n_middle\": roles.count(\"middle\"),\n",
|
||||
" \"n_foot\": roles.count(\"foot\"),\n",
|
||||
" \"n_finish\": roles.count(\"finish\"),\n",
|
||||
" \"basic_valid\": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:\n",
|
||||
" \"\"\"Convert generated hold tokens back into a frames string.\n",
|
||||
"\n",
|
||||
" Duplicate placements, roles missing from ``role_name_to_id``, and holds from\n",
|
||||
" boards other than ``board_prefix`` (when given) are skipped, matching the\n",
|
||||
" forgiving cleanup used by the demo scripts and webapp.\n",
|
||||
" \"\"\"\n",
|
||||
" pieces = []\n",
|
||||
" seen = set()\n",
|
||||
" for record in hold_records(tokens):\n",
|
||||
" if board_prefix is not None and str(record[\"board_prefix\"]) != board_prefix:\n",
|
||||
" continue\n",
|
||||
" placement_id = int(record[\"placement_id\"])\n",
|
||||
" role = str(record[\"role\"])\n",
|
||||
" if placement_id in seen or role not in role_name_to_id:\n",
|
||||
" continue\n",
|
||||
" seen.add(placement_id)\n",
|
||||
" pieces.append(f\"p{placement_id}r{int(role_name_to_id[role])}\")\n",
|
||||
" return \"\".join(pieces)\n",
|
||||
"\n",
|
||||
"def generate_one(\n",
|
||||
" model,\n",
|
||||
" stoi: dict[str, int],\n",
|
||||
" itos: dict[int, str],\n",
|
||||
" device: torch.device,\n",
|
||||
" board_prefix: str,\n",
|
||||
" angle: int,\n",
|
||||
" grouped_v: int,\n",
|
||||
" role_name_to_id: dict[str, int],\n",
|
||||
" temperature: float = 0.9,\n",
|
||||
" top_k: int | None = 50,\n",
|
||||
" max_new_tokens: int = 40,\n",
|
||||
") -> dict[str, object]:\n",
|
||||
" \"\"\"Generate one route and return tokens, frames, request metadata, validity.\"\"\"\n",
|
||||
" unk_id = stoi[\"<UNK>\"]\n",
|
||||
" eos_id = stoi[\"<EOS>\"]\n",
|
||||
" forbidden_ids = [\n",
|
||||
" stoi[\"<PAD>\"],\n",
|
||||
" stoi[\"<UNK>\"],\n",
|
||||
" stoi[\"<BOS>\"],\n",
|
||||
" stoi[\"<CLS>\"],\n",
|
||||
" stoi[\"<MASK>\"],\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" prompt = prompt_tokens(board_prefix, angle, grouped_v)\n",
|
||||
" prompt_ids = [stoi.get(token, unk_id) for token in prompt]\n",
|
||||
" token_ids = sample_ids(\n",
|
||||
" model=model,\n",
|
||||
" prompt_ids=prompt_ids,\n",
|
||||
" device=device,\n",
|
||||
" max_new_tokens=max_new_tokens,\n",
|
||||
" temperature=temperature,\n",
|
||||
" top_k=top_k,\n",
|
||||
" eos_id=eos_id,\n",
|
||||
" forbidden_ids=forbidden_ids,\n",
|
||||
" )\n",
|
||||
" tokens = [itos.get(int(idx), \"<UNK>\") for idx in token_ids]\n",
|
||||
" validity = validity_summary(tokens, requested_board_prefix=board_prefix)\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"requested_board_prefix\": board_prefix,\n",
|
||||
" \"requested_angle\": int(angle),\n",
|
||||
" \"requested_grouped_v\": int(grouped_v),\n",
|
||||
" \"temperature\": float(temperature),\n",
|
||||
" \"top_k\": None if top_k is None else int(top_k),\n",
|
||||
" \"tokens\": tokens,\n",
|
||||
" \"sequence\": \" \".join(tokens),\n",
|
||||
" \"frames\": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),\n",
|
||||
" **validity,\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "69926180",
|
||||
@@ -356,14 +939,22 @@
|
||||
"- **Temperature** (default 0.9): Controls randomness. Lower = more deterministic, higher = more random\n",
|
||||
"- **Top-k** (default 50): Only consider the k most likely tokens. This prevents the model from generating very unlikely tokens.\n",
|
||||
"\n",
|
||||
"These are the same techniques used in language models like GPT-3 to control output diversity."
|
||||
"These are the same techniques used in language models like GPT-3 to control output diversity.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "029eb911",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.244254Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.244037Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.680983Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.679992Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate sample routes for both boards\n",
|
||||
@@ -400,14 +991,22 @@
|
||||
"source": [
|
||||
"## Generate More Routes for Evaluation\n",
|
||||
"\n",
|
||||
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards."
|
||||
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "generate_bulk",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.684476Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.683935Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:53.779260Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:53.778391Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate routes across multiple angles and grades for evaluation\n",
|
||||
@@ -454,14 +1053,22 @@
|
||||
"source": [
|
||||
"## Save Model and Generated Routes\n",
|
||||
"\n",
|
||||
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation)."
|
||||
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation).\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "save_outputs",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:53.782874Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:53.782303Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:53.831685Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:53.830785Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
@@ -509,8 +1116,16 @@
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"version": "3.11"
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user