Update notebook results and README stats

This commit is contained in:
2026-06-08 13:20:19 -04:00
parent f87d116c03
commit 874de6c0fb
5 changed files with 2679 additions and 154 deletions
+643 -28
View File
@@ -43,40 +43,58 @@
"- `<ANGLE_40>`: At 40 degrees\n",
"- `<GRADE_V6>`: At V6 difficulty\n",
"\n",
"This is analogous to how ChatGPT uses a system prompt to condition its responses."
"This is analogous to how ChatGPT uses a system prompt to condition its responses.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6590822",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:14.101804Z",
"iopub.status.busy": "2026-06-07T19:12:14.101439Z",
"iopub.status.idle": "2026-06-07T19:12:16.162395Z",
"shell.execute_reply": "2026-06-07T19:12:16.161684Z"
}
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"import sys\n",
"from __future__ import annotations\n",
"\n",
"import ast\n",
"import json\n",
"import math\n",
"import re\n",
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from typing import Iterable\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from torch.utils.data import DataLoader\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"ROOT = Path.cwd().resolve()\n",
"if ROOT.name == \"notebooks\":\n",
" ROOT = ROOT.parent\n",
"sys.path.insert(0, str(ROOT / \"src\"))\n",
"\n",
"from climbingboardgpt.config import load_board_configs\n",
"from climbingboardgpt.datasets import RouteGPTDataset\n",
"from climbingboardgpt.generation import generate_one\n",
"from climbingboardgpt.models import JointRouteGPT"
" ROOT = ROOT.parent"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f09fdf54",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:16.166017Z",
"iopub.status.busy": "2026-06-07T19:12:16.165651Z",
"iopub.status.idle": "2026-06-07T19:12:21.900885Z",
"shell.execute_reply": "2026-06-07T19:12:21.899985Z"
}
},
"outputs": [],
"source": [
"TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
@@ -92,6 +110,59 @@
"print(f\"Total routes: {len(df_routes):,}\")"
]
},
{
"cell_type": "markdown",
"id": "4fcba532",
"metadata": {},
"source": [
"### Causal dataset helper"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40021fc1",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:21.904008Z",
"iopub.status.busy": "2026-06-07T19:12:21.903750Z",
"iopub.status.idle": "2026-06-07T19:12:21.910270Z",
"shell.execute_reply": "2026-06-07T19:12:21.909595Z"
}
},
"outputs": [],
"source": [
"# Pad route-token sequences and create shifted input/target pairs for causal modeling.\n",
"class RouteGPTDataset(Dataset):\n",
"    \"\"\"Map-style dataset yielding shifted (input, target) pairs for next-token training.\n",
"\n",
"    Each route's ``gpt_ids`` list is clipped to ``max_len``, right-padded with\n",
"    ``pad_id`` up to ``max_len``, and then offset by one position so that\n",
"    ``target_ids[i]`` is the token that follows ``input_ids[i]``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, df, max_len: int, pad_id: int):\n",
"        \"\"\"Keep the ``gpt_ids`` column of ``df`` as a plain list of ID sequences.\"\"\"\n",
"        self.ids = df[\"gpt_ids\"].tolist()\n",
"        self.max_len = int(max_len)\n",
"        self.pad_id = int(pad_id)\n",
"\n",
"    def __len__(self) -> int:\n",
"        \"\"\"Number of stored route sequences.\"\"\"\n",
"        return len(self.ids)\n",
"\n",
"    def __getitem__(self, idx: int):\n",
"        \"\"\"Build the padded, one-step-shifted tensors for route ``idx``.\"\"\"\n",
"        window = list(self.ids[idx])[: self.max_len]\n",
"        # Multiplying by a non-positive shortfall yields an empty list, so a\n",
"        # full-length window passes through unchanged.\n",
"        shortfall = self.max_len - len(window)\n",
"        padded = window + [self.pad_id] * shortfall\n",
"\n",
"        inputs = torch.tensor(padded[:-1], dtype=torch.long)\n",
"        targets = torch.tensor(padded[1:], dtype=torch.long)\n",
"        return {\"input_ids\": inputs, \"target_ids\": targets}"
]
},
{
"cell_type": "markdown",
"id": "fe4b0faf",
@@ -114,14 +185,22 @@
"\n",
"For the grade predictor (notebook 02), we excluded the grade because the model needed to predict it. But for the generator, we **include** the grade (`<GRADE_V6>`) in the training data so the model learns the relationship between grade and hold selection.\n",
"\n",
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade."
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ad61dbd",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:21.913286Z",
"iopub.status.busy": "2026-06-07T19:12:21.912788Z",
"iopub.status.idle": "2026-06-07T19:12:25.369590Z",
"shell.execute_reply": "2026-06-07T19:12:25.368643Z"
}
},
"outputs": [],
"source": [
"def encode(tokens):\n",
@@ -153,6 +232,132 @@
"print(f\"Validation samples: {len(val_ds):,}\")"
]
},
{
"cell_type": "markdown",
"id": "552b9f69",
"metadata": {},
"source": [
"### GPT model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4085a314",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:25.373057Z",
"iopub.status.busy": "2026-06-07T19:12:25.372812Z",
"iopub.status.idle": "2026-06-07T19:12:25.388342Z",
"shell.execute_reply": "2026-06-07T19:12:25.387476Z"
}
},
"outputs": [],
"source": [
"# GPT-style causal transformer for route generation.\n",
"class JointRouteGPT(nn.Module):\n",
"    \"\"\"Small GPT-style causal transformer for board-conditioned route generation.\n",
"\n",
"    The stack is an ``nn.TransformerEncoder`` driven with an explicit causal\n",
"    mask, so it acts as a decoder-only language model. ``nn.TransformerDecoder``\n",
"    is not used because its layers expect a second 'memory' input for\n",
"    cross-attention, and this model has no encoder to supply one.\n",
"\n",
"    Implementation notes:\n",
"      * ``TransformerEncoder`` does not mask future positions on its own, so\n",
"        ``forward`` builds the upper-triangular mask on every call.\n",
"      * ``enable_nested_tensor=False`` is passed explicitly: with\n",
"        ``norm_first=True`` layers PyTorch does not take the nested-tensor\n",
"        path and warns if the flag is left enabled.\n",
"      * The LM head shares its weight matrix with the token embedding\n",
"        (weight tying).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        vocab_size: int,\n",
"        block_size: int,\n",
"        n_embd: int = 128,\n",
"        n_head: int = 4,\n",
"        n_layer: int = 4,\n",
"        dropout: float = 0.10,\n",
"        pad_id: int = 0,\n",
"    ):\n",
"        \"\"\"Create the token/position embeddings, causal blocks, and LM head.\n",
"\n",
"        Args:\n",
"            vocab_size: Number of rows in the token embedding / LM head.\n",
"            block_size: Longest context the position embedding can index.\n",
"            n_embd: Embedding (model) width.\n",
"            n_head: Attention heads per layer.\n",
"            n_layer: Number of transformer layers.\n",
"            dropout: Dropout probability for the embeddings and the layers.\n",
"            pad_id: Padding token ID; masked as an attention key and ignored\n",
"                by the loss.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        self.vocab_size = vocab_size\n",
"        self.block_size = block_size\n",
"        self.pad_id = pad_id\n",
"\n",
"        self.token_emb = nn.Embedding(vocab_size, n_embd, padding_idx=pad_id)\n",
"        self.pos_emb = nn.Embedding(block_size, n_embd)\n",
"        self.drop = nn.Dropout(dropout)\n",
"\n",
"        layer = nn.TransformerEncoderLayer(\n",
"            d_model=n_embd,\n",
"            nhead=n_head,\n",
"            dim_feedforward=4 * n_embd,\n",
"            dropout=dropout,\n",
"            activation=\"gelu\",\n",
"            batch_first=True,\n",
"            norm_first=True,\n",
"        )\n",
"        self.blocks = nn.TransformerEncoder(\n",
"            layer,\n",
"            num_layers=n_layer,\n",
"            enable_nested_tensor=False,\n",
"        )\n",
"        self.ln_f = nn.LayerNorm(n_embd)\n",
"        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)\n",
"        # Weight tying: the output projection reuses the embedding matrix.\n",
"        self.lm_head.weight = self.token_emb.weight\n",
"\n",
"    def forward(\n",
"        self,\n",
"        idx: torch.Tensor,\n",
"        targets: torch.Tensor | None = None,\n",
"    ) -> tuple[torch.Tensor, torch.Tensor | None]:\n",
"        \"\"\"Return next-token logits and, when targets are supplied, CE loss.\n",
"\n",
"        Args:\n",
"            idx: ``(batch, seq_len)`` token IDs. Sequences longer than\n",
"                ``block_size`` keep only their last ``block_size`` positions.\n",
"            targets: Optional IDs aligned position-by-position with ``idx``.\n",
"                They are cropped the same way as ``idx``; positions equal to\n",
"                ``pad_id`` are ignored by the loss.\n",
"\n",
"        Returns:\n",
"            ``(logits, loss)`` where ``logits`` is\n",
"            ``(batch, seq_len, vocab_size)`` and ``loss`` is the mean\n",
"            cross-entropy, or ``None`` when ``targets`` is ``None``.\n",
"        \"\"\"\n",
"        _, seq_len = idx.shape\n",
"        if seq_len > self.block_size:\n",
"            idx = idx[:, -self.block_size :]\n",
"            seq_len = idx.shape[1]\n",
"            if targets is not None:\n",
"                # Crop targets together with the inputs; otherwise the\n",
"                # flattened logits and targets differ in length and\n",
"                # cross_entropy raises a batch-size mismatch.\n",
"                targets = targets[:, -self.block_size :]\n",
"\n",
"        positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)\n",
"        x = self.drop(self.token_emb(idx) + self.pos_emb(positions))\n",
"\n",
"        # True above the diagonal = position i may not attend to j > i.\n",
"        causal_mask = torch.triu(\n",
"            torch.ones(seq_len, seq_len, device=idx.device, dtype=torch.bool),\n",
"            diagonal=1,\n",
"        )\n",
"        # Padding masks suppress attention to right-padded context tokens while\n",
"        # the causal mask suppresses attention to future positions.\n",
"        key_padding_mask = idx.eq(self.pad_id)\n",
"\n",
"        h = self.blocks(\n",
"            x,\n",
"            mask=causal_mask,\n",
"            src_key_padding_mask=key_padding_mask,\n",
"        )\n",
"        h = self.ln_f(h)\n",
"        logits = self.lm_head(h)\n",
"\n",
"        loss = None\n",
"        if targets is not None:\n",
"            loss = F.cross_entropy(\n",
"                logits.reshape(-1, logits.size(-1)),\n",
"                targets.reshape(-1),\n",
"                ignore_index=self.pad_id,\n",
"            )\n",
"\n",
"        return logits, loss"
]
},
{
"cell_type": "markdown",
"id": "66d98641",
@@ -177,21 +382,29 @@
"- `n_layer=4`: Number of transformer layers (GPT-2 small uses 12)\n",
"- `dropout=0.10`: Dropout probability\n",
"\n",
"This is intentionally small — we're training on ~40K short sequences, not billions of long documents.\n",
"This is intentionally small — we're training on a few hundred thousand short sequences, not billions of long documents.\n",
"\n",
"### Weight tying\n",
"\n",
"The output projection layer shares weights with the token embedding layer (`self.lm_head.weight = self.token_emb.weight`). This is a common technique that:\n",
"- Reduces parameter count\n",
"- Acts as a regularizer\n",
"- Is used in GPT-2 and many other language models"
"- Is used in GPT-2 and many other language models\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3eec6f35",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:25.391551Z",
"iopub.status.busy": "2026-06-07T19:12:25.391044Z",
"iopub.status.idle": "2026-06-07T19:12:27.304182Z",
"shell.execute_reply": "2026-06-07T19:12:27.303257Z"
}
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
@@ -216,7 +429,14 @@
"cell_type": "code",
"execution_count": null,
"id": "f999cf05",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:27.307681Z",
"iopub.status.busy": "2026-06-07T19:12:27.307034Z",
"iopub.status.idle": "2026-06-07T19:12:27.314453Z",
"shell.execute_reply": "2026-06-07T19:12:27.313491Z"
}
},
"outputs": [],
"source": [
"def train_epoch():\n",
@@ -275,14 +495,22 @@
"- A model that picks uniformly from a 1000-token vocab has perplexity = 1000\n",
"- Good language models on English text achieve perplexity ~15-20\n",
"\n",
"Our vocabulary is ~4000+ tokens, so a perplexity significantly below that indicates the model is learning meaningful patterns."
"Our vocabulary is ~4000+ tokens, so a perplexity significantly below that indicates the model is learning meaningful patterns.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "70b38b02",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T19:12:27.317432Z",
"iopub.status.busy": "2026-06-07T19:12:27.317008Z",
"iopub.status.idle": "2026-06-07T23:43:38.199890Z",
"shell.execute_reply": "2026-06-07T23:43:38.198963Z"
}
},
"outputs": [],
"source": [
"history = []\n",
@@ -335,6 +563,361 @@
"print(f\"Best validation perplexity: {math.exp(min(best_val_loss, 20)):.1f}\")"
]
},
{
"cell_type": "markdown",
"id": "20096c62",
"metadata": {},
"source": [
"### Board configuration helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d26d6d4",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.203421Z",
"iopub.status.busy": "2026-06-07T23:43:38.203126Z",
"iopub.status.idle": "2026-06-07T23:43:38.217533Z",
"shell.execute_reply": "2026-06-07T23:43:38.216653Z"
}
},
"outputs": [],
"source": [
"# Find the project root and load board configuration JSON files.\n",
"def find_project_root(start: str | Path | None = None) -> Path:\n",
"    \"\"\"Locate the repository root by searching ``start`` and its ancestors.\n",
"\n",
"    A directory counts as the root when it contains both ``pyproject.toml``\n",
"    and ``configs``. The search begins at ``start`` (the current working\n",
"    directory when ``start`` is None) and moves upward. If no directory on\n",
"    that path holds both markers, the resolved starting directory is returned\n",
"    so callers still get a deterministic base path.\n",
"    \"\"\"\n",
"    origin = Path.cwd().resolve() if start is None else Path(start).resolve()\n",
"    for directory in (origin, *origin.parents):\n",
"        has_pyproject = (directory / \"pyproject.toml\").exists()\n",
"        if has_pyproject and (directory / \"configs\").exists():\n",
"            return directory\n",
"    return origin\n",
"\n",
"@dataclass(frozen=True)\n",
"class BoardConfig:\n",
"    \"\"\"Immutable bundle of settings for one climbing board.\n",
"\n",
"    Holds the board-specific values needed for data loading, tokenization,\n",
"    and model training.\n",
"\n",
"    Attributes:\n",
"        board_key: Short identifier, e.g. 'tb2' or 'kilter'.\n",
"        display_name: Human-readable name, e.g. 'Tension Board 2 Mirror'.\n",
"        token_prefix: Namespace for hold tokens, e.g. 'TB2' or 'KILTER'.\n",
"        db_path: Location of the SQLite database.\n",
"        layout_id: Layout within the database to use.\n",
"        max_angle: Routes steeper than this are filtered out (None = no filter).\n",
"        min_fa_date: Routes first ascended before this date are filtered out.\n",
"        placement_y_max: Placements above this Y coordinate are filtered out.\n",
"        include_mirror_placement_id: Whether mirror info is included (TB2 only).\n",
"        role_definitions: Semantic role name -> numeric role ID.\n",
"        boardlib_database_command: Command that downloads the database.\n",
"        boardlib_images_command: Command that downloads the board images.\n",
"        notes: Free-form notes about the configuration.\n",
"    \"\"\"\n",
"\n",
"    board_key: str\n",
"    display_name: str\n",
"    token_prefix: str\n",
"    db_path: Path\n",
"    layout_id: int\n",
"    max_angle: float | None\n",
"    min_fa_date: str | None\n",
"    placement_y_max: float | None\n",
"    include_mirror_placement_id: bool\n",
"    role_definitions: dict[str, int]\n",
"    boardlib_database_command: str | None = None\n",
"    boardlib_images_command: str | None = None\n",
"    notes: tuple[str, ...] = ()\n",
"\n",
"    @property\n",
"    def role_id_to_name(self) -> dict[int, str]:\n",
"        \"\"\"Inverse of ``role_definitions``: numeric role ID -> semantic role name.\n",
"\n",
"        For instance ``{5: 'start', 6: 'middle', 7: 'finish', 8: 'foot'}`` for TB2.\n",
"        \"\"\"\n",
"        inverse: dict[int, str] = {}\n",
"        for role_name, role_id in self.role_definitions.items():\n",
"            inverse[int(role_id)] = role_name\n",
"        return inverse\n",
"\n",
"    @property\n",
"    def board_token(self) -> str:\n",
"        \"\"\"Special token naming this board, e.g. ``<BOARD_TB2>`` or ``<BOARD_KILTER>``.\"\"\"\n",
"        return f\"<BOARD_{self.token_prefix}>\"\n",
"\n",
"    def resolve_db_path(self, project_root: Path | None = None) -> Path:\n",
"        \"\"\"Return the database path, anchoring relative paths at the project root.\n",
"\n",
"        An absolute ``db_path`` is returned unchanged. A relative one is joined\n",
"        onto ``project_root``, which falls back to ``find_project_root()`` when\n",
"        it is not supplied.\n",
"        \"\"\"\n",
"        base = project_root or find_project_root()\n",
"        if self.db_path.is_absolute():\n",
"            return self.db_path\n",
"        return base / self.db_path\n",
"\n",
"def load_board_config(board_key: str, config_dir: str | Path | None = None) -> BoardConfig:\n",
"    \"\"\"Read ``<config_dir>/<board_key>.json`` and build a ``BoardConfig`` from it.\n",
"\n",
"    Args:\n",
"        board_key: Board identifier such as 'tb2' or 'kilter'; also the JSON\n",
"            file stem.\n",
"        config_dir: Folder holding the config JSON files. Defaults to the\n",
"            ``configs`` directory under the detected project root.\n",
"\n",
"    Returns:\n",
"        The populated ``BoardConfig``.\n",
"\n",
"    Raises:\n",
"        FileNotFoundError: When no JSON file exists for ``board_key``; the\n",
"            message lists the config stems that are available.\n",
"    \"\"\"\n",
"    project_root = find_project_root()\n",
"    if config_dir is None:\n",
"        config_dir = project_root / \"configs\"\n",
"    else:\n",
"        config_dir = Path(config_dir)\n",
"\n",
"    path = config_dir / f\"{board_key}.json\"\n",
"    if not path.exists():\n",
"        available = sorted(candidate.stem for candidate in config_dir.glob(\"*.json\"))\n",
"        raise FileNotFoundError(\n",
"            f\"Unknown board config '{board_key}'. Available: {available}\"\n",
"        )\n",
"\n",
"    payload = json.loads(path.read_text(encoding=\"utf-8\"))\n",
"\n",
"    def optional_float(key: str) -> float | None:\n",
"        \"\"\"Return ``payload[key]`` as a float, or None when it is missing or null.\"\"\"\n",
"        raw = payload.get(key)\n",
"        return float(raw) if raw is not None else None\n",
"\n",
"    role_definitions = {\n",
"        str(role_name): int(role_id)\n",
"        for role_name, role_id in payload[\"role_definitions\"].items()\n",
"    }\n",
"    return BoardConfig(\n",
"        board_key=str(payload[\"board_key\"]),\n",
"        display_name=str(payload[\"display_name\"]),\n",
"        token_prefix=str(payload[\"token_prefix\"]),\n",
"        db_path=Path(payload[\"db_path\"]),\n",
"        layout_id=int(payload[\"layout_id\"]),\n",
"        max_angle=optional_float(\"max_angle\"),\n",
"        min_fa_date=payload.get(\"min_fa_date\"),\n",
"        placement_y_max=optional_float(\"placement_y_max\"),\n",
"        include_mirror_placement_id=bool(payload.get(\"include_mirror_placement_id\", False)),\n",
"        role_definitions=role_definitions,\n",
"        boardlib_database_command=payload.get(\"boardlib_database_command\"),\n",
"        boardlib_images_command=payload.get(\"boardlib_images_command\"),\n",
"        notes=tuple(payload.get(\"notes\", [])),\n",
"    )\n",
"\n",
"def load_board_configs(board_keys: list[str] | tuple[str, ...]) -> list[BoardConfig]:\n",
"    \"\"\"Load one ``BoardConfig`` per key, preserving the order of ``board_keys``.\n",
"\n",
"    Args:\n",
"        board_keys: Board identifiers to load.\n",
"\n",
"    Returns:\n",
"        The loaded configurations, one per key.\n",
"    \"\"\"\n",
"    configs: list[BoardConfig] = []\n",
"    for key in board_keys:\n",
"        configs.append(load_board_config(key))\n",
"    return configs"
]
},
{
"cell_type": "markdown",
"id": "94d352c6",
"metadata": {},
"source": [
"### Generation helpers"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abdabe8e",
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.220501Z",
"iopub.status.busy": "2026-06-07T23:43:38.220197Z",
"iopub.status.idle": "2026-06-07T23:43:38.241455Z",
"shell.execute_reply": "2026-06-07T23:43:38.240589Z"
}
},
"outputs": [],
"source": [
"# Parse generated hold tokens back into structured hold records.\n",
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
"\n",
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
"    \"\"\"Turn hold tokens into structured records, ignoring every other token.\n",
"\n",
"    A token is a hold token when it matches ``HOLD_TOKEN_PATTERN``, i.e.\n",
"    ``<PREFIX_p<placement>_<role>>``. Each record carries the original token,\n",
"    the board prefix (under both ``board_token_prefix`` and ``board_prefix``),\n",
"    the integer placement ID, and the role name.\n",
"    \"\"\"\n",
"    candidates = ((str(raw), HOLD_TOKEN_PATTERN.match(str(raw))) for raw in tokens)\n",
"    return [\n",
"        {\n",
"            \"token\": text,\n",
"            \"board_token_prefix\": hit.group(1),\n",
"            \"board_prefix\": hit.group(1),\n",
"            \"placement_id\": int(hit.group(2)),\n",
"            \"role\": hit.group(3),\n",
"        }\n",
"        for text, hit in candidates\n",
"        if hit is not None\n",
"    ]\n",
"\n",
"# Sample routes from the trained GPT model and convert them back to frames strings.\n",
"def top_k_filter(logits: torch.Tensor, k: int | None) -> torch.Tensor:\n",
"    \"\"\"Mask logits outside the top ``k`` choices along the last dimension.\n",
"\n",
"    Works for any leading shape: a single ``(vocab,)`` vector, a\n",
"    ``(batch, vocab)`` matrix, or higher-rank input.\n",
"\n",
"    Args:\n",
"        logits: Scores whose last dimension indexes the vocabulary.\n",
"        k: Number of highest-scoring entries to keep per row. ``None``, a\n",
"            non-positive value, or a value >= the vocabulary size disables\n",
"            filtering.\n",
"\n",
"    Returns:\n",
"        ``logits`` itself when filtering is disabled; otherwise a new tensor\n",
"        in which every entry below the row's k-th largest value is ``-inf``\n",
"        (entries tied with the k-th value are kept).\n",
"    \"\"\"\n",
"    if k is None or k <= 0 or k >= logits.size(-1):\n",
"        return logits\n",
"    # topk returns values sorted in descending order, so the last kept value\n",
"    # is the per-row cutoff; ``[..., -1:]`` keeps a trailing axis so the\n",
"    # comparison broadcasts for any number of leading dimensions.\n",
"    cutoff = torch.topk(logits, k, dim=-1).values[..., -1:]\n",
"    return logits.masked_fill(logits < cutoff, -float(\"inf\"))\n",
"\n",
"@torch.no_grad()\n",
"def sample_ids(\n",
"    model,\n",
"    prompt_ids: list[int],\n",
"    device: torch.device,\n",
"    max_new_tokens: int = 40,\n",
"    temperature: float = 0.9,\n",
"    top_k: int | None = 50,\n",
"    eos_id: int | None = None,\n",
"    forbidden_ids: Iterable[int] | None = None,\n",
") -> list[int]:\n",
"    \"\"\"Draw a token-ID continuation of ``prompt_ids`` one token at a time.\n",
"\n",
"    The model is switched to eval mode, then up to ``max_new_tokens`` IDs are\n",
"    sampled from the temperature-scaled, top-k-filtered distribution over the\n",
"    last position. Sampling stops early once ``eos_id`` is drawn.\n",
"\n",
"    Returns:\n",
"        The prompt IDs followed by every sampled ID (a drawn ``eos_id`` is\n",
"        included).\n",
"    \"\"\"\n",
"    model.eval()\n",
"    sequence = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n",
"    banned = sorted({int(token_id) for token_id in (forbidden_ids or [])})\n",
"    # Clamp so a zero temperature cannot divide by zero.\n",
"    scale = max(temperature, 1e-6)\n",
"\n",
"    for _ in range(max_new_tokens):\n",
"        # Only the most recent ``block_size`` tokens fit in the model context.\n",
"        context = sequence[:, -model.block_size :]\n",
"        step_logits, _ = model(context)\n",
"        step_logits = step_logits[:, -1, :] / scale\n",
"\n",
"        # Special tokens like <PAD> and <CLS> are valid vocabulary entries but\n",
"        # should never be emitted in the middle of a generated climb.\n",
"        if banned:\n",
"            step_logits[:, banned] = -float(\"inf\")\n",
"\n",
"        probs = F.softmax(top_k_filter(step_logits, top_k), dim=-1)\n",
"        next_id = torch.multinomial(probs, num_samples=1)\n",
"        sequence = torch.cat([sequence, next_id], dim=1)\n",
"\n",
"        drew_eos = eos_id is not None and int(next_id.item()) == int(eos_id)\n",
"        if drew_eos:\n",
"            break\n",
"\n",
"    return sequence[0].detach().cpu().tolist()\n",
"\n",
"def prompt_tokens(board_prefix: str, angle: int, grouped_v: int) -> list[str]:\n",
"    \"\"\"Return the conditioning prefix: BOS, board, angle, and grade tokens in that order.\n",
"\n",
"    ``angle`` and ``grouped_v`` are passed through ``int`` before formatting.\n",
"    \"\"\"\n",
"    board_tok = f\"<BOARD_{board_prefix}>\"\n",
"    angle_tok = f\"<ANGLE_{int(angle)}>\"\n",
"    grade_tok = f\"<GRADE_V{int(grouped_v)}>\"\n",
"    return [\"<BOS>\", board_tok, angle_tok, grade_tok]\n",
"\n",
"def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
"    \"\"\"Return the hold records parsed from ``tokens`` by ``tokens_to_hold_records``.\"\"\"\n",
"    records = tokens_to_hold_records(tokens)\n",
"    return records\n",
"\n",
"def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
"    \"\"\"Report simple structural checks on a generated token sequence.\n",
"\n",
"    ``basic_valid`` is true only when every hold token belongs to a single\n",
"    board, that board matches ``requested_board_prefix`` (when one is given),\n",
"    no placement repeats, a start and a finish hold are present, and there are\n",
"    at least three hold tokens.\n",
"    \"\"\"\n",
"    records = hold_records(tokens)\n",
"    placements, roles, prefixes = [], [], []\n",
"    for rec in records:\n",
"        placements.append(rec[\"placement_id\"])\n",
"        roles.append(rec[\"role\"])\n",
"        prefixes.append(rec[\"board_prefix\"])\n",
"\n",
"    unique_placements = set(placements)\n",
"    has_duplicates = len(unique_placements) != len(placements)\n",
"    single_board = len(set(prefixes)) <= 1\n",
"    if requested_board_prefix is None:\n",
"        on_requested_board = True\n",
"    else:\n",
"        on_requested_board = all(prefix == requested_board_prefix for prefix in prefixes)\n",
"\n",
"    role_counts = {name: roles.count(name) for name in (\"start\", \"middle\", \"foot\", \"finish\")}\n",
"    has_start = role_counts[\"start\"] > 0\n",
"    has_finish = role_counts[\"finish\"] > 0\n",
"    checks = (\n",
"        single_board,\n",
"        on_requested_board,\n",
"        not has_duplicates,\n",
"        has_start,\n",
"        has_finish,\n",
"        len(records) >= 3,\n",
"    )\n",
"\n",
"    return {\n",
"        \"n_hold_tokens\": len(records),\n",
"        \"n_unique_placements\": len(unique_placements),\n",
"        \"has_duplicate_placements\": has_duplicates,\n",
"        \"one_board_only\": single_board,\n",
"        \"matches_requested_board\": on_requested_board,\n",
"        \"has_start\": has_start,\n",
"        \"has_middle\": role_counts[\"middle\"] > 0,\n",
"        \"has_finish\": has_finish,\n",
"        \"n_start\": role_counts[\"start\"],\n",
"        \"n_middle\": role_counts[\"middle\"],\n",
"        \"n_foot\": role_counts[\"foot\"],\n",
"        \"n_finish\": role_counts[\"finish\"],\n",
"        \"basic_valid\": all(checks),\n",
"    }\n",
"\n",
"def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:\n",
"    \"\"\"Serialise generated hold tokens as a ``p<placement>r<role_id>...`` frames string.\n",
"\n",
"    Holds from a board other than ``board_prefix`` (when given), repeated\n",
"    placements, and roles missing from ``role_name_to_id`` are dropped; this\n",
"    matches the forgiving cleanup used by the demo scripts and webapp.\n",
"    \"\"\"\n",
"    used_placements = set()\n",
"    chunks = []\n",
"    for rec in hold_records(tokens):\n",
"        wrong_board = board_prefix is not None and str(rec[\"board_prefix\"]) != board_prefix\n",
"        if wrong_board:\n",
"            continue\n",
"        placement_id = int(rec[\"placement_id\"])\n",
"        role = str(rec[\"role\"])\n",
"        if placement_id not in used_placements and role in role_name_to_id:\n",
"            used_placements.add(placement_id)\n",
"            role_id = int(role_name_to_id[role])\n",
"            chunks.append(f\"p{placement_id}r{role_id}\")\n",
"    return \"\".join(chunks)\n",
"\n",
"def generate_one(\n",
"    model,\n",
"    stoi: dict[str, int],\n",
"    itos: dict[int, str],\n",
"    device: torch.device,\n",
"    board_prefix: str,\n",
"    angle: int,\n",
"    grouped_v: int,\n",
"    role_name_to_id: dict[str, int],\n",
"    temperature: float = 0.9,\n",
"    top_k: int | None = 50,\n",
"    max_new_tokens: int = 40,\n",
") -> dict[str, object]:\n",
"    \"\"\"Sample a single route for the requested board, angle, and grade.\n",
"\n",
"    The prompt tokens are encoded with ``stoi`` (unknown tokens fall back to\n",
"    ``<UNK>``), a continuation is sampled with ``sample_ids``, and the IDs are\n",
"    decoded with ``itos``.\n",
"\n",
"    Returns:\n",
"        A dict with the request parameters, the decoded ``tokens``, the\n",
"        space-joined ``sequence``, the ``frames`` string, and every field of\n",
"        ``validity_summary``.\n",
"    \"\"\"\n",
"    unk_id = stoi[\"<UNK>\"]\n",
"    eos_id = stoi[\"<EOS>\"]\n",
"    # Special tokens that must never be sampled inside a route.\n",
"    forbidden_ids = [\n",
"        stoi[name] for name in (\"<PAD>\", \"<UNK>\", \"<BOS>\", \"<CLS>\", \"<MASK>\")\n",
"    ]\n",
"\n",
"    prompt_ids = [\n",
"        stoi.get(token, unk_id)\n",
"        for token in prompt_tokens(board_prefix, angle, grouped_v)\n",
"    ]\n",
"    token_ids = sample_ids(\n",
"        model=model,\n",
"        prompt_ids=prompt_ids,\n",
"        device=device,\n",
"        max_new_tokens=max_new_tokens,\n",
"        temperature=temperature,\n",
"        top_k=top_k,\n",
"        eos_id=eos_id,\n",
"        forbidden_ids=forbidden_ids,\n",
"    )\n",
"    tokens = [itos.get(int(token_id), \"<UNK>\") for token_id in token_ids]\n",
"    validity = validity_summary(tokens, requested_board_prefix=board_prefix)\n",
"    frames = generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix)\n",
"\n",
"    result: dict[str, object] = {\n",
"        \"requested_board_prefix\": board_prefix,\n",
"        \"requested_angle\": int(angle),\n",
"        \"requested_grouped_v\": int(grouped_v),\n",
"        \"temperature\": float(temperature),\n",
"        \"top_k\": int(top_k) if top_k is not None else None,\n",
"        \"tokens\": tokens,\n",
"        \"sequence\": \" \".join(tokens),\n",
"        \"frames\": frames,\n",
"    }\n",
"    # Validity fields come last, as in a ``**validity`` expansion.\n",
"    result.update(validity)\n",
"    return result"
]
},
{
"cell_type": "markdown",
"id": "69926180",
@@ -356,14 +939,22 @@
"- **Temperature** (default 0.9): Controls randomness. Lower = more deterministic, higher = more random\n",
"- **Top-k** (default 50): Only consider the k most likely tokens. This prevents the model from generating very unlikely tokens.\n",
"\n",
"These are the same techniques used in language models like GPT-3 to control output diversity."
"These are the same techniques used in language models like GPT-3 to control output diversity.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "029eb911",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.244254Z",
"iopub.status.busy": "2026-06-07T23:43:38.244037Z",
"iopub.status.idle": "2026-06-07T23:43:38.680983Z",
"shell.execute_reply": "2026-06-07T23:43:38.679992Z"
}
},
"outputs": [],
"source": [
"# Generate sample routes for both boards\n",
@@ -400,14 +991,22 @@
"source": [
"## Generate More Routes for Evaluation\n",
"\n",
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards."
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "generate_bulk",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:38.684476Z",
"iopub.status.busy": "2026-06-07T23:43:38.683935Z",
"iopub.status.idle": "2026-06-07T23:43:53.779260Z",
"shell.execute_reply": "2026-06-07T23:43:53.778391Z"
}
},
"outputs": [],
"source": [
"# Generate routes across multiple angles and grades for evaluation\n",
@@ -454,14 +1053,22 @@
"source": [
"## Save Model and Generated Routes\n",
"\n",
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation)."
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation).\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "save_outputs",
"metadata": {},
"metadata": {
"execution": {
"iopub.execute_input": "2026-06-07T23:43:53.782874Z",
"iopub.status.busy": "2026-06-07T23:43:53.782303Z",
"iopub.status.idle": "2026-06-07T23:43:53.831685Z",
"shell.execute_reply": "2026-06-07T23:43:53.830785Z"
}
},
"outputs": [],
"source": [
"import os\n",
@@ -509,8 +1116,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.11"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,