Update notebook results and README stats
This commit is contained in:
@@ -90,13 +90,13 @@ The route generator is a small **GPT-style model** — the same general idea as
|
||||
|
||||
and it predicts what hold token should come next, then the next, then the next, until it produces an `<EOS>` token. The result is a novel sequence of holds that the model thinks is a plausible V6 on the Kilter at 40°.
|
||||
|
||||
**~91% of generated routes pass basic structural checks** (has a start hold, has a finish hold, holds exist on the right board, no duplicates).
|
||||
**~89% of generated routes pass basic structural checks** (has a start hold, has a finish hold, holds exist on the right board, no duplicates).
|
||||
|
||||
---
|
||||
|
||||
## Quantitative results
|
||||
|
||||
These numbers are from the full training run documented by this repository.
|
||||
These numbers are from the full training run documented by this repository. The notebooks in `notebooks/` are self-contained walkthroughs of the pipeline stages. The reported pipeline run was executed on Kaggle; notebooks 01-04 took about 8h 1m 59s total using GPU T4 x2.
|
||||
|
||||
In practice: the grade model is usually within one V-grade, and the generator usually makes structurally valid routes, but exact grade control is still imperfect.
|
||||
|
||||
@@ -112,27 +112,27 @@ Shared vocabulary: **4,438 tokens** (6 special + 2 board + 12 angle + 16 grade +
|
||||
|
||||
### Grade prediction accuracy
|
||||
|
||||
The model has ~1.17M parameters. Early stopping selected epoch 8 (validation MAE ≈ 1.480).
|
||||
The model has ~1.17M parameters. Early stopping selected epoch 11 (validation MAE ≈ 1.488).
|
||||
|
||||
| Metric | Overall | TB2 | Kilter |
|
||||
|---|---:|---:|---:|
|
||||
| Exact V-grade | 36.0% | 37.3% | 35.8% |
|
||||
| Within ±1 V-grade | 79.3% | 80.0% | 79.2% |
|
||||
| Within ±2 V-grades | 94.8% | 95.5% | 94.7% |
|
||||
| R² | 0.768 | 0.800 | 0.763 |
|
||||
| Exact V-grade | 35.8% | 35.8% | 35.8% |
|
||||
| Within ±1 V-grade | 79.2% | 79.4% | 79.1% |
|
||||
| Within ±2 V-grades | 94.9% | 95.5% | 94.8% |
|
||||
| R² | 0.763 | 0.793 | 0.758 |
|
||||
|
||||
### Route generation
|
||||
|
||||
The generator has ~1.41M parameters. Best validation perplexity: 24.2.
|
||||
The generator has ~1.41M parameters. Best validation perplexity: 24.3.
|
||||
|
||||
| Metric | TB2 | Kilter |
|
||||
|---|---:|---:|
|
||||
| Routes evaluated | 200 | 200 |
|
||||
| Structurally valid | 89.0% | 94.0% |
|
||||
| Exact requested grade (critic) | 29.5% | 27.0% |
|
||||
| Within ±1 V-grade (critic) | 68.5% | 73.0% |
|
||||
| Within ±2 V-grades (critic) | 90.5% | 93.5% |
|
||||
| Mean novelty (Jaccard distance) | 0.656 | 0.634 |
|
||||
| Structurally valid | 91.5% | 86.0% |
|
||||
| Exact requested grade (critic) | 34.5% | 37.0% |
|
||||
| Within ±1 V-grade (critic) | 73.0% | 79.5% |
|
||||
| Within ±2 V-grades (critic) | 91.0% | 96.5% |
|
||||
| Mean novelty (Jaccard distance) | 0.656 | 0.643 |
|
||||
|
||||
---
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,7 +30,7 @@
|
||||
"A climb's difficulty depends on the *relationships between holds*, not just individual holds. Self-attention naturally captures these relationships:\n",
|
||||
"\n",
|
||||
"- A start hold far from the first middle hold suggests a big opening move\n",
|
||||
"- Two hand holds close together with a foot hold far away suggests a dyno\n",
|
||||
"- Two holds that are very far apart suggest a dyno\n",
|
||||
"- The overall spatial distribution determines the \"flow\" of the climb\n",
|
||||
"\n",
|
||||
"The transformer can learn these spatial relationships through attention, without us having to manually engineer features like \"mean hand reach\" or \"height gained\" (though those features were useful in the classical model).\n",
|
||||
@@ -47,40 +47,57 @@
|
||||
"\n",
|
||||
"```text\n",
|
||||
"display_difficulty (continuous value, e.g., 20.5)\n",
|
||||
"```"
|
||||
"```\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3dfd6081",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:37.490884Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:37.490209Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:42.972689Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:42.971662Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"import sys\n",
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import json\n",
|
||||
"import math\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Any\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
|
||||
"from torch.utils.data import DataLoader, Dataset\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if ROOT.name == \"notebooks\":\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"sys.path.insert(0, str(ROOT / \"src\"))\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.datasets import RouteGradeDataset\n",
|
||||
"from climbingboardgpt.metrics import regression_metrics, metrics_by_board\n",
|
||||
"from climbingboardgpt.models import JointRouteTransformerRegressor"
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8a9e2443",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:42.976137Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:42.975792Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:48.768984Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:48.768115Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
|
||||
@@ -95,7 +112,8 @@
|
||||
"unk_id = stoi[\"<UNK>\"]\n",
|
||||
"\n",
|
||||
"print(f\"Vocabulary size: {len(stoi):,}\")\n",
|
||||
"print(f\"Total routes: {len(df_routes):,}\")"
|
||||
"print(f\"Total routes: {len(df_routes):,}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -114,14 +132,22 @@
|
||||
"2. `y_norm`: Normalized vertical position on the board (-1 to 1)\n",
|
||||
"3. `is_hold`: 1 if this token represents a hold, 0 otherwise\n",
|
||||
"\n",
|
||||
"These features are projected through a linear layer and added to the token embeddings. This is similar to how some vision-language models inject spatial features from images alongside text tokens."
|
||||
"These features are projected through a linear layer and added to the token embeddings. This is similar to how some vision-language models inject spatial features from images alongside text tokens.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95bb745f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:48.772384Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:48.771749Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:52.916642Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:52.915616Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def encode(tokens):\n",
|
||||
@@ -153,7 +179,73 @@
|
||||
"\n",
|
||||
"print(f\"Max sequence length: {max_len}\")\n",
|
||||
"print(f\"Coordinate features shape: {coord_features.shape}\")\n",
|
||||
"print(f\"Vocabulary size: {len(stoi)}\")"
|
||||
"print(f\"Vocabulary size: {len(stoi)}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9033f9e8",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dataset helper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c55c1d26",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:52.920221Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:52.919793Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:52.927627Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:52.926737Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Pad route-token sequences for transformer grade prediction.\n",
|
||||
"class RouteGradeDataset(Dataset):\n",
|
||||
" \"\"\"Dataset for transformer encoder grade prediction.\n",
|
||||
"\n",
|
||||
" Each item returns a padded token sequence, a boolean attention mask, the\n",
|
||||
" continuous display-difficulty target, and a small amount of route identity\n",
|
||||
" metadata used when writing prediction CSVs.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, df, max_len: int, pad_id: int):\n",
|
||||
" \"\"\"Store model IDs and labels from a tokenized route DataFrame.\"\"\"\n",
|
||||
" self.row_ids = df[\"row_id\"].tolist() if \"row_id\" in df.columns else df.index.tolist()\n",
|
||||
" self.ids = df[\"model_ids\"].tolist()\n",
|
||||
" self.targets = df[\"display_difficulty\"].astype(float).values\n",
|
||||
" self.uuids = df[\"uuid\"].tolist()\n",
|
||||
" self.boards = df[\"board_key\"].astype(str).tolist()\n",
|
||||
" self.max_len = int(max_len)\n",
|
||||
" self.pad_id = int(pad_id)\n",
|
||||
"\n",
|
||||
" def __len__(self) -> int:\n",
|
||||
" \"\"\"Return the number of route examples.\"\"\"\n",
|
||||
" return len(self.ids)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx: int):\n",
|
||||
" \"\"\"Return one padded encoder example and its regression target.\"\"\"\n",
|
||||
" ids = list(self.ids[idx])[: self.max_len]\n",
|
||||
" mask = [1] * len(ids)\n",
|
||||
" if len(ids) < self.max_len:\n",
|
||||
" pad_n = self.max_len - len(ids)\n",
|
||||
" ids += [self.pad_id] * pad_n\n",
|
||||
" mask += [0] * pad_n\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"input_ids\": torch.tensor(ids, dtype=torch.long),\n",
|
||||
" \"attention_mask\": torch.tensor(mask, dtype=torch.bool),\n",
|
||||
" \"target\": torch.tensor(self.targets[idx], dtype=torch.float32),\n",
|
||||
" \"row_id\": int(self.row_ids[idx]),\n",
|
||||
" \"uuid\": self.uuids[idx],\n",
|
||||
" \"board_key\": self.boards[idx],\n",
|
||||
" }\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -178,14 +270,22 @@
|
||||
"- `input_ids`: Integer token IDs, padded to `max_len`\n",
|
||||
"- `attention_mask`: 1 for real tokens, 0 for padding\n",
|
||||
"- `target`: The difficulty score we want to predict\n",
|
||||
"- `uuid`, `board_key`: Metadata for evaluation"
|
||||
"- `uuid`, `board_key`: Metadata for evaluation\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2c9e5543",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:52.930809Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:52.930299Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:53.612170Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:53.611156Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_df = df_routes[df_routes[\"split\"] == \"train\"].reset_index(drop=True)\n",
|
||||
@@ -202,7 +302,106 @@
|
||||
"\n",
|
||||
"print(f\"Training samples: {len(train_ds):,}\")\n",
|
||||
"print(f\"Validation samples: {len(val_ds):,}\")\n",
|
||||
"print(f\"Test samples: {len(test_ds):,}\")"
|
||||
"print(f\"Test samples: {len(test_ds):,}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "03091a62",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Transformer regressor model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "78612fe7",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:53.616012Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:53.615396Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:53.640842Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:53.639849Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Transformer encoder used as a continuous grade regressor.\n",
|
||||
"class JointRouteTransformerRegressor(nn.Module):\n",
|
||||
" \"\"\"Transformer encoder for joint TB2/Kilter route difficulty prediction.\n",
|
||||
"\n",
|
||||
" Inputs are token IDs plus an attention mask. Token, position, and learned\n",
|
||||
" projections of coordinate metadata are added before the encoder. The first\n",
|
||||
" ``<CLS>`` position is then used as a pooled route representation for scalar\n",
|
||||
" difficulty regression.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" vocab_size: int,\n",
|
||||
" max_len: int,\n",
|
||||
" coord_features: torch.Tensor,\n",
|
||||
" d_model: int = 128,\n",
|
||||
" nhead: int = 4,\n",
|
||||
" num_layers: int = 4,\n",
|
||||
" dim_feedforward: int = 256,\n",
|
||||
" dropout: float = 0.10,\n",
|
||||
" pad_id: int = 0,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Create the encoder, coordinate projection, and regression head.\"\"\"\n",
|
||||
" super().__init__()\n",
|
||||
" self.vocab_size = vocab_size\n",
|
||||
" self.max_len = max_len\n",
|
||||
" self.d_model = d_model\n",
|
||||
" self.pad_id = pad_id\n",
|
||||
"\n",
|
||||
" self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)\n",
|
||||
" self.pos_emb = nn.Embedding(max_len, d_model)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\"coord_features\", coord_features.clone().float())\n",
|
||||
" self.coord_proj = nn.Linear(coord_features.shape[1], d_model)\n",
|
||||
"\n",
|
||||
" encoder_layer = nn.TransformerEncoderLayer(\n",
|
||||
" d_model=d_model,\n",
|
||||
" nhead=nhead,\n",
|
||||
" dim_feedforward=dim_feedforward,\n",
|
||||
" dropout=dropout,\n",
|
||||
" activation=\"gelu\",\n",
|
||||
" batch_first=True,\n",
|
||||
" norm_first=True,\n",
|
||||
" )\n",
|
||||
" self.encoder = nn.TransformerEncoder(\n",
|
||||
" encoder_layer,\n",
|
||||
" num_layers=num_layers,\n",
|
||||
" enable_nested_tensor=False,\n",
|
||||
" )\n",
|
||||
" self.norm = nn.LayerNorm(d_model)\n",
|
||||
" self.head = nn.Sequential(\n",
|
||||
" nn.Linear(d_model, d_model),\n",
|
||||
" nn.GELU(),\n",
|
||||
" nn.Dropout(dropout),\n",
|
||||
" nn.Linear(d_model, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:\n",
|
||||
" \"\"\"Return one continuous difficulty prediction per input sequence.\"\"\"\n",
|
||||
" batch_size, seq_len = input_ids.shape\n",
|
||||
" positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_len)\n",
|
||||
"\n",
|
||||
" # Coordinate features are indexed by token ID, so every occurrence of a\n",
|
||||
" # hold token gets the same physical x/y hint wherever it appears.\n",
|
||||
" x = self.token_emb(input_ids) + self.pos_emb(positions)\n",
|
||||
" x = x + self.coord_proj(self.coord_features[input_ids])\n",
|
||||
"\n",
|
||||
" key_padding_mask = ~attention_mask.bool()\n",
|
||||
" h = self.encoder(x, src_key_padding_mask=key_padding_mask)\n",
|
||||
" h = self.norm(h)\n",
|
||||
"\n",
|
||||
" cls_state = h[:, 0, :]\n",
|
||||
" return self.head(cls_state).squeeze(-1)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -235,14 +434,22 @@
|
||||
"- `nhead=4`: Number of attention heads (multi-head attention)\n",
|
||||
"- `num_layers=4`: Number of transformer layers\n",
|
||||
"- `dim_feedforward=256`: Dimension of the feedforward network inside each layer\n",
|
||||
"- `dropout=0.10`: Dropout probability for regularization"
|
||||
"- `dropout=0.10`: Dropout probability for regularization\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "62c2db48",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:53.644453Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:53.643654Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:59.327913Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:59.326972Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
@@ -262,7 +469,8 @@
|
||||
"optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)\n",
|
||||
"\n",
|
||||
"print(f\"Device: {device}\")\n",
|
||||
"print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")"
|
||||
"print(f\"Parameters: {sum(p.numel() for p in model.parameters()):,}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -284,14 +492,22 @@
|
||||
"\n",
|
||||
"### Early stopping\n",
|
||||
"\n",
|
||||
"We stop training if validation loss doesn't improve for `patience` epochs. This prevents overfitting and saves compute."
|
||||
"We stop training if validation loss doesn't improve for `patience` epochs. This prevents overfitting and saves compute.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "665deadb",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:59.331996Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:59.331485Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:59.340181Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:59.339495Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def run_epoch(model, loader, device, optimizer=None):\n",
|
||||
@@ -341,7 +557,90 @@
|
||||
"patience = 12\n",
|
||||
"\n",
|
||||
"print(f\"Max epochs: {num_epochs}\")\n",
|
||||
"print(f\"Early stopping patience: {patience}\")"
|
||||
"print(f\"Early stopping patience: {patience}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0e5bb77f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Grade metrics helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aeeb2294",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:59.343447Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:59.342978Z",
|
||||
"iopub.status.idle": "2026-06-07T15:48:59.353066Z",
|
||||
"shell.execute_reply": "2026-06-07T15:48:59.352152Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Map BoardLib display difficulties into grouped V-grade tokens.\n",
|
||||
"GRADE_TO_V = {\n",
|
||||
" 10: 0, 11: 0, 12: 0,\n",
|
||||
" 13: 1, 14: 1,\n",
|
||||
" 15: 2,\n",
|
||||
" 16: 3, 17: 3,\n",
|
||||
" 18: 4, 19: 4,\n",
|
||||
" 20: 5, 21: 5,\n",
|
||||
" 22: 6,\n",
|
||||
" 23: 7,\n",
|
||||
" 24: 8, 25: 8,\n",
|
||||
" 26: 9,\n",
|
||||
" 27: 10,\n",
|
||||
" 28: 11,\n",
|
||||
" 29: 12,\n",
|
||||
" 30: 13,\n",
|
||||
" 31: 14,\n",
|
||||
" 32: 15,\n",
|
||||
" 33: 16,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def to_grouped_v(display_difficulty: float) -> int:\n",
|
||||
" \"\"\"Map a continuous display difficulty to the nearest grouped V grade.\"\"\"\n",
|
||||
" rounded = int(round(float(display_difficulty)))\n",
|
||||
" rounded = max(min(rounded, max(GRADE_TO_V)), min(GRADE_TO_V))\n",
|
||||
" return GRADE_TO_V[rounded]\n",
|
||||
"\n",
|
||||
"def grade_token(display_difficulty: float) -> str:\n",
|
||||
" \"\"\"Return the grade-conditioning token for a display difficulty value.\"\"\"\n",
|
||||
" return f\"<GRADE_V{to_grouped_v(display_difficulty)}>\"\n",
|
||||
"\n",
|
||||
"# Evaluate difficulty regression and grouped V-grade accuracy.\n",
|
||||
"def regression_metrics(y_true, y_pred) -> dict[str, float]:\n",
|
||||
" \"\"\"Compute difficulty-scale and grouped-V-grade prediction metrics.\"\"\"\n",
|
||||
" y_true = np.asarray(y_true)\n",
|
||||
" y_pred = np.asarray(y_pred)\n",
|
||||
" true_v = np.asarray([to_grouped_v(x) for x in y_true])\n",
|
||||
" pred_v = np.asarray([to_grouped_v(x) for x in y_pred])\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"mae\": float(mean_absolute_error(y_true, y_pred)),\n",
|
||||
" \"rmse\": float(math.sqrt(mean_squared_error(y_true, y_pred))),\n",
|
||||
" \"r2\": float(r2_score(y_true, y_pred)),\n",
|
||||
" \"within_1_difficulty\": float(np.mean(np.abs(y_true - y_pred) <= 1) * 100),\n",
|
||||
" \"within_2_difficulty\": float(np.mean(np.abs(y_true - y_pred) <= 2) * 100),\n",
|
||||
" \"exact_grouped_v\": float(np.mean(true_v == pred_v) * 100),\n",
|
||||
" \"within_1_vgrade\": float(np.mean(np.abs(true_v - pred_v) <= 1) * 100),\n",
|
||||
" \"within_2_vgrades\": float(np.mean(np.abs(true_v - pred_v) <= 2) * 100),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"def metrics_by_board(pred_df: pd.DataFrame) -> pd.DataFrame:\n",
|
||||
" \"\"\"Compute regression metrics separately for each board in a prediction table.\"\"\"\n",
|
||||
" rows = []\n",
|
||||
" for board_key, frame in pred_df.groupby(\"board_key\"):\n",
|
||||
" metrics = regression_metrics(frame[\"y_true\"].values, frame[\"y_pred\"].values)\n",
|
||||
" rows.append({\"board_key\": board_key, **metrics})\n",
|
||||
" return pd.DataFrame(rows)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -360,14 +659,22 @@
|
||||
"5. **Validate**: Check performance on held-out validation data\n",
|
||||
"6. **Early stopping**: Stop if validation loss stops improving\n",
|
||||
"\n",
|
||||
"We track both fine-grained metrics (MAE, RMSE) and practical metrics (V-grade accuracy within ±1 grade)."
|
||||
"We track both fine-grained metrics (MAE, RMSE) and practical metrics (V-grade accuracy within ±1 grade).\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "476b158d",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T15:48:59.356313Z",
|
||||
"iopub.status.busy": "2026-06-07T15:48:59.355799Z",
|
||||
"iopub.status.idle": "2026-06-07T19:11:46.644946Z",
|
||||
"shell.execute_reply": "2026-06-07T19:11:46.644060Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history = []\n",
|
||||
@@ -420,7 +727,8 @@
|
||||
"if best_state is not None:\n",
|
||||
" model.load_state_dict(best_state)\n",
|
||||
"\n",
|
||||
"print(f\"\\nTraining complete. Best epoch: {best_epoch}, Best val MAE: {best_val_mae:.4f}\")"
|
||||
"print(f\"\\nTraining complete. Best epoch: {best_epoch}, Best val MAE: {best_val_mae:.4f}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -438,14 +746,22 @@
|
||||
"- **Within ±1 difficulty**: Percentage of predictions within 1 point\n",
|
||||
"- **Within ±1 V-grade**: Percentage of predictions within 1 V-grade\n",
|
||||
"\n",
|
||||
"We also break down performance by board (TB2 vs Kilter) to check for bias."
|
||||
"We also break down performance by board (TB2 vs Kilter) to check for bias.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9abc3a72",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:11:46.648067Z",
|
||||
"iopub.status.busy": "2026-06-07T19:11:46.647798Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:05.427217Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:05.426288Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_loss, test_pred, test_true, test_uuid, test_board = run_epoch(model, test_loader, device, optimizer=None)\n",
|
||||
@@ -467,7 +783,60 @@
|
||||
" print(f\"{key:24s}: {value:8.4f}{suffix}\")\n",
|
||||
"\n",
|
||||
"print(\"\\nBoard-specific test performance:\")\n",
|
||||
"print(board_metrics_df.to_string(index=False))"
|
||||
"print(board_metrics_df.to_string(index=False))\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "01c90e93",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### JSON output helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3027d982",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:05.430611Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:05.430084Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:05.436838Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:05.436135Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Write JSON artifacts after converting NumPy/pandas values to plain Python values.\n",
|
||||
"def json_safe(obj: Any) -> Any:\n",
|
||||
" \"\"\"Convert NumPy/pandas values into JSON-serializable Python objects.\"\"\"\n",
|
||||
" if isinstance(obj, dict):\n",
|
||||
" return {str(k): json_safe(v) for k, v in obj.items()}\n",
|
||||
" if isinstance(obj, (list, tuple)):\n",
|
||||
" return [json_safe(v) for v in obj]\n",
|
||||
" if isinstance(obj, np.integer):\n",
|
||||
" return int(obj)\n",
|
||||
" if isinstance(obj, np.floating):\n",
|
||||
" if np.isnan(obj):\n",
|
||||
" return None\n",
|
||||
" return float(obj)\n",
|
||||
" if isinstance(obj, np.ndarray):\n",
|
||||
" return json_safe(obj.tolist())\n",
|
||||
" try:\n",
|
||||
" if pd.isna(obj):\n",
|
||||
" return None\n",
|
||||
" except Exception:\n",
|
||||
" pass\n",
|
||||
" return obj\n",
|
||||
"\n",
|
||||
"def write_json(path: str | Path, payload: Any) -> None:\n",
|
||||
" \"\"\"Write an object as indented UTF-8 JSON after ``json_safe`` cleanup.\"\"\"\n",
|
||||
" path = Path(path)\n",
|
||||
" path.parent.mkdir(parents=True, exist_ok=True)\n",
|
||||
" path.write_text(json.dumps(json_safe(payload), indent=2), encoding=\"utf-8\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -477,14 +846,22 @@
|
||||
"source": [
|
||||
"## Save Model and Artifacts\n",
|
||||
"\n",
|
||||
"We save the trained model checkpoint and evaluation metrics for use in notebook 04 (route evaluation) and for future inference."
|
||||
"We save the trained model checkpoint and evaluation metrics for use in notebook 04 (route evaluation) and for future inference.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "save_model",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:05.439746Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:05.439205Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:05.604325Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:05.603607Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save model checkpoint\n",
|
||||
@@ -520,13 +897,14 @@
|
||||
"pred_df.to_csv(OUT_DIR / \"test_predictions.csv\", index=False)\n",
|
||||
"board_metrics_df.to_csv(OUT_DIR / \"board_metrics.csv\", index=False)\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.utils import write_json\n",
|
||||
"# write_json is defined in the JSON output helper cell above.\n",
|
||||
"write_json(OUT_DIR / \"overall_metrics.json\", overall_metrics)\n",
|
||||
"\n",
|
||||
"print(f\"Saved model checkpoint to: {model_path}\")\n",
|
||||
"print(f\"Saved training history to: {OUT_DIR / 'training_history.csv'}\")\n",
|
||||
"print(f\"Saved test predictions to: {OUT_DIR / 'test_predictions.csv'}\")\n",
|
||||
"print(f\"Saved board metrics to: {OUT_DIR / 'board_metrics.csv'}\")"
|
||||
"print(f\"Saved board metrics to: {OUT_DIR / 'board_metrics.csv'}\")\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -542,7 +920,8 @@
|
||||
"\n",
|
||||
"3. **Joint training across boards**: By training on both TB2 and Kilter data simultaneously, the model can share statistical strength. The board token (`<BOARD_TB2>` vs `<BOARD_KILTER>`) tells it which \"language\" it's operating in.\n",
|
||||
"\n",
|
||||
"4. **The gap between fine-grained and grouped metrics**: Being off by 1 difficulty point often stays within the same V-grade bucket. This is why the ±1 V-grade accuracy is much higher than the ±1 difficulty accuracy."
|
||||
"4. **The gap between fine-grained and grouped metrics**: Being off by 1 difficulty point often stays within the same V-grade bucket. This is why the ±1 V-grade accuracy is much higher than the ±1 difficulty accuracy.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -553,8 +932,16 @@
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"version": "3.14.4"
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -43,40 +43,58 @@
|
||||
"- `<ANGLE_40>`: At 40 degrees\n",
|
||||
"- `<GRADE_V6>`: At V6 difficulty\n",
|
||||
"\n",
|
||||
"This is analogous to how ChatGPT uses a system prompt to condition its responses."
|
||||
"This is analogous to how ChatGPT uses a system prompt to condition its responses.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b6590822",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:14.101804Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:14.101439Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:16.162395Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:16.161684Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"import sys\n",
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import ast\n",
|
||||
"import json\n",
|
||||
"import math\n",
|
||||
"import re\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from pathlib import Path\n",
|
||||
"from typing import Iterable\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"from torch.utils.data import DataLoader\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from torch.utils.data import DataLoader, Dataset\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if ROOT.name == \"notebooks\":\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"sys.path.insert(0, str(ROOT / \"src\"))\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.config import load_board_configs\n",
|
||||
"from climbingboardgpt.datasets import RouteGPTDataset\n",
|
||||
"from climbingboardgpt.generation import generate_one\n",
|
||||
"from climbingboardgpt.models import JointRouteGPT"
|
||||
" ROOT = ROOT.parent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f09fdf54",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:16.166017Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:16.165651Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:21.900885Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:21.899985Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"TOKENIZED = ROOT / \"data\" / \"processed\" / \"tokenized\"\n",
|
||||
@@ -92,6 +110,59 @@
|
||||
"print(f\"Total routes: {len(df_routes):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4fcba532",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Causal dataset helper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "40021fc1",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:21.904008Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:21.903750Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:21.910270Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:21.909595Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Pad route-token sequences and create shifted input/target pairs for causal modeling.\n",
|
||||
"class RouteGPTDataset(Dataset):\n",
|
||||
" \"\"\"Dataset for causal next-token route generation.\n",
|
||||
"\n",
|
||||
" The full sequence is padded once, then split into ``input_ids`` and\n",
|
||||
" ``target_ids`` shifted by one position for teacher-forced language-model\n",
|
||||
" training.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, df, max_len: int, pad_id: int):\n",
|
||||
" \"\"\"Store GPT token ID sequences from a tokenized route DataFrame.\"\"\"\n",
|
||||
" self.ids = df[\"gpt_ids\"].tolist()\n",
|
||||
" self.max_len = int(max_len)\n",
|
||||
" self.pad_id = int(pad_id)\n",
|
||||
"\n",
|
||||
" def __len__(self) -> int:\n",
|
||||
" \"\"\"Return the number of route examples.\"\"\"\n",
|
||||
" return len(self.ids)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx: int):\n",
|
||||
" \"\"\"Return one padded causal-language-model training example.\"\"\"\n",
|
||||
" ids = list(self.ids[idx])[: self.max_len]\n",
|
||||
" if len(ids) < self.max_len:\n",
|
||||
" ids += [self.pad_id] * (self.max_len - len(ids))\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"input_ids\": torch.tensor(ids[:-1], dtype=torch.long),\n",
|
||||
" \"target_ids\": torch.tensor(ids[1:], dtype=torch.long),\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fe4b0faf",
|
||||
@@ -114,14 +185,22 @@
|
||||
"\n",
|
||||
"For the grade predictor (notebook 02), we excluded the grade because the model needed to predict it. But for the generator, we **include** the grade (`<GRADE_V6>`) in the training data so the model learns the relationship between grade and hold selection.\n",
|
||||
"\n",
|
||||
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade."
|
||||
"At generation time, we provide the grade as part of the prompt, and the model generates holds that are appropriate for that grade.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7ad61dbd",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:21.913286Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:21.912788Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:25.369590Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:25.368643Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def encode(tokens):\n",
|
||||
@@ -153,6 +232,132 @@
|
||||
"print(f\"Validation samples: {len(val_ds):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "552b9f69",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### GPT model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4085a314",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:25.373057Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:25.372812Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:25.388342Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:25.387476Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# GPT-style causal transformer for route generation.\n",
|
||||
"class JointRouteGPT(nn.Module):\n",
|
||||
" \"\"\"Tiny GPT-style causal transformer for board-conditioned route generation.\n",
|
||||
"\n",
|
||||
" PyTorch's ``TransformerEncoder`` is used with a causal mask, which makes it\n",
|
||||
" behave like a decoder-only language model for short route sequences.\n",
|
||||
"\n",
|
||||
" Why use ``TransformerEncoder`` rather than ``TransformerDecoder``?\n",
|
||||
" -------------------------------------------------------------------\n",
|
||||
" PyTorch's ``TransformerDecoderLayer`` expects two inputs: a decoder\n",
|
||||
" sequence and a separate encoder memory for cross-attention. For\n",
|
||||
" unconditional or prompt-conditioned generation there is no encoder,\n",
|
||||
" so ``TransformerDecoderLayer`` would always ignore the second input\n",
|
||||
" or require a dummy placeholder. Using ``TransformerEncoder`` with a\n",
|
||||
" causal mask avoids this mismatch, keeps the module list uniform,\n",
|
||||
" and produces identical behaviour for short autoregressive generation.\n",
|
||||
"\n",
|
||||
" The trade-off is that ``TransformerEncoder`` does not natively prevent\n",
|
||||
" attention to future positions — the causal mask must be constructed\n",
|
||||
" manually (see ``forward``). For the sequence lengths seen here\n",
|
||||
" (at most ~400 tokens) the overhead of the upper-triangular mask is\n",
|
||||
" negligible, and ``enable_nested_tensor=False`` is set to avoid SDPA\n",
|
||||
" optimisations that do not support masked encoders.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" vocab_size: int,\n",
|
||||
" block_size: int,\n",
|
||||
" n_embd: int = 128,\n",
|
||||
" n_head: int = 4,\n",
|
||||
" n_layer: int = 4,\n",
|
||||
" dropout: float = 0.10,\n",
|
||||
" pad_id: int = 0,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Create the token/position embeddings, causal blocks, and LM head.\"\"\"\n",
|
||||
" super().__init__()\n",
|
||||
" self.vocab_size = vocab_size\n",
|
||||
" self.block_size = block_size\n",
|
||||
" self.pad_id = pad_id\n",
|
||||
"\n",
|
||||
" self.token_emb = nn.Embedding(vocab_size, n_embd, padding_idx=pad_id)\n",
|
||||
" self.pos_emb = nn.Embedding(block_size, n_embd)\n",
|
||||
" self.drop = nn.Dropout(dropout)\n",
|
||||
"\n",
|
||||
" layer = nn.TransformerEncoderLayer(\n",
|
||||
" d_model=n_embd,\n",
|
||||
" nhead=n_head,\n",
|
||||
" dim_feedforward=4 * n_embd,\n",
|
||||
" dropout=dropout,\n",
|
||||
" activation=\"gelu\",\n",
|
||||
" batch_first=True,\n",
|
||||
" norm_first=True,\n",
|
||||
" )\n",
|
||||
" self.blocks = nn.TransformerEncoder(\n",
|
||||
" layer,\n",
|
||||
" num_layers=n_layer,\n",
|
||||
" enable_nested_tensor=False,\n",
|
||||
" )\n",
|
||||
" self.ln_f = nn.LayerNorm(n_embd)\n",
|
||||
" self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)\n",
|
||||
" self.lm_head.weight = self.token_emb.weight\n",
|
||||
"\n",
|
||||
" def forward(\n",
|
||||
" self,\n",
|
||||
" idx: torch.Tensor,\n",
|
||||
" targets: torch.Tensor | None = None,\n",
|
||||
" ) -> tuple[torch.Tensor, torch.Tensor | None]:\n",
|
||||
" \"\"\"Return next-token logits and, when targets are supplied, CE loss.\"\"\"\n",
|
||||
" _, seq_len = idx.shape\n",
|
||||
" if seq_len > self.block_size:\n",
|
||||
" idx = idx[:, -self.block_size :]\n",
|
||||
" seq_len = idx.shape[1]\n",
|
||||
"\n",
|
||||
" positions = torch.arange(seq_len, device=idx.device).unsqueeze(0)\n",
|
||||
" x = self.drop(self.token_emb(idx) + self.pos_emb(positions))\n",
|
||||
"\n",
|
||||
" causal_mask = torch.triu(\n",
|
||||
" torch.ones(seq_len, seq_len, device=idx.device, dtype=torch.bool),\n",
|
||||
" diagonal=1,\n",
|
||||
" )\n",
|
||||
" # Padding masks suppress attention to right-padded context tokens while\n",
|
||||
" # the causal mask suppresses attention to future positions.\n",
|
||||
" key_padding_mask = idx.eq(self.pad_id)\n",
|
||||
"\n",
|
||||
" h = self.blocks(\n",
|
||||
" x,\n",
|
||||
" mask=causal_mask,\n",
|
||||
" src_key_padding_mask=key_padding_mask,\n",
|
||||
" )\n",
|
||||
" h = self.ln_f(h)\n",
|
||||
" logits = self.lm_head(h)\n",
|
||||
"\n",
|
||||
" loss = None\n",
|
||||
" if targets is not None:\n",
|
||||
" loss = F.cross_entropy(\n",
|
||||
" logits.reshape(-1, logits.size(-1)),\n",
|
||||
" targets.reshape(-1),\n",
|
||||
" ignore_index=self.pad_id,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return logits, loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "66d98641",
|
||||
@@ -177,21 +382,29 @@
|
||||
"- `n_layer=4`: Number of transformer layers (GPT-2 small uses 12)\n",
|
||||
"- `dropout=0.10`: Dropout probability\n",
|
||||
"\n",
|
||||
"This is intentionally small — we're training on ~40K short sequences, not billions of long documents.\n",
|
||||
"This is intentionally small — we're training on a few hundred thousand short sequences, not billions of long documents.\n",
|
||||
"\n",
|
||||
"### Weight tying\n",
|
||||
"\n",
|
||||
"The output projection layer shares weights with the token embedding layer (`self.lm_head.weight = self.token_emb.weight`). This is a common technique that:\n",
|
||||
"- Reduces parameter count\n",
|
||||
"- Acts as a regularizer\n",
|
||||
"- Is used in GPT-2 and many other language models"
|
||||
"- Is used in GPT-2 and many other language models\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3eec6f35",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:25.391551Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:25.391044Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:27.304182Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:27.303257Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
@@ -216,7 +429,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f999cf05",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:27.307681Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:27.307034Z",
|
||||
"iopub.status.idle": "2026-06-07T19:12:27.314453Z",
|
||||
"shell.execute_reply": "2026-06-07T19:12:27.313491Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def train_epoch():\n",
|
||||
@@ -275,14 +495,22 @@
|
||||
"- A model that picks uniformly from a 1000-token vocab has perplexity = 1000\n",
|
||||
"- Good language models on English text achieve perplexity ~15-20\n",
|
||||
"\n",
|
||||
"Our vocabulary is ~4000+ tokens, so a perplexity significantly below that indicates the model is learning meaningful patterns."
|
||||
"Our vocabulary is ~4000+ tokens, so a perplexity significantly below that indicates the model is learning meaningful patterns.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "70b38b02",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T19:12:27.317432Z",
|
||||
"iopub.status.busy": "2026-06-07T19:12:27.317008Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.199890Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.198963Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"history = []\n",
|
||||
@@ -335,6 +563,361 @@
|
||||
"print(f\"Best validation perplexity: {math.exp(min(best_val_loss, 20)):.1f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "20096c62",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Board configuration helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8d26d6d4",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.203421Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.203126Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.217533Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.216653Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Find the project root and load board configuration JSON files.\n",
|
||||
"def find_project_root(start: str | Path | None = None) -> Path:\n",
|
||||
" \"\"\"Walk upward until the repository root markers are found.\n",
|
||||
"\n",
|
||||
" The project root is identified by both ``pyproject.toml`` and ``configs``.\n",
|
||||
" If neither marker pair is found, the resolved starting directory is returned\n",
|
||||
" so callers still have a deterministic base path.\n",
|
||||
" \"\"\"\n",
|
||||
" current = Path(start).resolve() if start is not None else Path.cwd().resolve()\n",
|
||||
" for candidate in [current, *current.parents]:\n",
|
||||
" if (candidate / \"pyproject.toml\").exists() and (candidate / \"configs\").exists():\n",
|
||||
" return candidate\n",
|
||||
" return current\n",
|
||||
"\n",
|
||||
"@dataclass(frozen=True)\n",
|
||||
"class BoardConfig:\n",
|
||||
" \"\"\"Configuration for a single climbing board.\n",
|
||||
" \n",
|
||||
" This dataclass stores all board-specific settings needed for\n",
|
||||
" data loading, tokenization, and model training.\n",
|
||||
" \n",
|
||||
" Attributes:\n",
|
||||
" board_key: Short identifier (e.g., \"tb2\", \"kilter\")\n",
|
||||
" display_name: Human-readable name (e.g., \"Tension Board 2 Mirror\")\n",
|
||||
" token_prefix: Namespace for hold tokens (e.g., \"TB2\", \"KILTER\")\n",
|
||||
" db_path: Path to the SQLite database\n",
|
||||
" layout_id: Which layout in the database to use\n",
|
||||
" max_angle: Filter out routes steeper than this (None = no filter)\n",
|
||||
" min_fa_date: Filter out routes first ascended before this date\n",
|
||||
" placement_y_max: Filter out placements above this Y coordinate\n",
|
||||
" include_mirror_placement_id: Whether to include mirror info (TB2 only)\n",
|
||||
" role_definitions: Maps semantic role names to numeric IDs\n",
|
||||
" boardlib_database_command: Command to download the database\n",
|
||||
" boardlib_images_command: Command to download board images\n",
|
||||
" notes: Additional notes about the configuration\n",
|
||||
" \"\"\"\n",
|
||||
" board_key: str\n",
|
||||
" display_name: str\n",
|
||||
" token_prefix: str\n",
|
||||
" db_path: Path\n",
|
||||
" layout_id: int\n",
|
||||
" max_angle: float | None\n",
|
||||
" min_fa_date: str | None\n",
|
||||
" placement_y_max: float | None\n",
|
||||
" include_mirror_placement_id: bool\n",
|
||||
" role_definitions: dict[str, int]\n",
|
||||
" boardlib_database_command: str | None = None\n",
|
||||
" boardlib_images_command: str | None = None\n",
|
||||
" notes: tuple[str, ...] = ()\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def role_id_to_name(self) -> dict[int, str]:\n",
|
||||
" \"\"\"Reverse mapping from numeric role IDs to semantic role names.\n",
|
||||
" \n",
|
||||
" Example: {5: 'start', 6: 'middle', 7: 'finish', 8: 'foot'} for TB2\n",
|
||||
" \"\"\"\n",
|
||||
" return {int(role_id): name for name, role_id in self.role_definitions.items()}\n",
|
||||
"\n",
|
||||
" @property\n",
|
||||
" def board_token(self) -> str:\n",
|
||||
" \"\"\"The special token representing this board.\n",
|
||||
" \n",
|
||||
" Example: \"<BOARD_TB2>\" or \"<BOARD_KILTER>\"\n",
|
||||
" \"\"\"\n",
|
||||
" return f\"<BOARD_{self.token_prefix}>\"\n",
|
||||
"\n",
|
||||
" def resolve_db_path(self, project_root: Path | None = None) -> Path:\n",
|
||||
" \"\"\"Resolve the database path relative to the project root.\n",
|
||||
" \n",
|
||||
" If db_path is absolute, return it as-is.\n",
|
||||
" Otherwise, resolve it relative to the project root.\n",
|
||||
" \"\"\"\n",
|
||||
" project_root = project_root or find_project_root()\n",
|
||||
" return self.db_path if self.db_path.is_absolute() else project_root / self.db_path\n",
|
||||
"\n",
|
||||
"def load_board_config(board_key: str, config_dir: str | Path | None = None) -> BoardConfig:\n",
|
||||
" \"\"\"Load a single board configuration from a JSON file.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" board_key: Board identifier (e.g., \"tb2\", \"kilter\")\n",
|
||||
" config_dir: Directory containing config JSON files\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" BoardConfig dataclass with all board settings\n",
|
||||
" \n",
|
||||
" Raises:\n",
|
||||
" FileNotFoundError: If the config file doesn't exist\n",
|
||||
" \"\"\"\n",
|
||||
" project_root = find_project_root()\n",
|
||||
" config_dir = Path(config_dir) if config_dir is not None else project_root / \"configs\"\n",
|
||||
" path = config_dir / f\"{board_key}.json\"\n",
|
||||
" if not path.exists():\n",
|
||||
" available = sorted(p.stem for p in config_dir.glob(\"*.json\"))\n",
|
||||
" raise FileNotFoundError(\n",
|
||||
" f\"Unknown board config '{board_key}'. Available: {available}\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" payload = json.loads(path.read_text(encoding=\"utf-8\"))\n",
|
||||
" return BoardConfig(\n",
|
||||
" board_key=str(payload[\"board_key\"]),\n",
|
||||
" display_name=str(payload[\"display_name\"]),\n",
|
||||
" token_prefix=str(payload[\"token_prefix\"]),\n",
|
||||
" db_path=Path(payload[\"db_path\"]),\n",
|
||||
" layout_id=int(payload[\"layout_id\"]),\n",
|
||||
" max_angle=None if payload.get(\"max_angle\") is None else float(payload[\"max_angle\"]),\n",
|
||||
" min_fa_date=payload.get(\"min_fa_date\"),\n",
|
||||
" placement_y_max=None if payload.get(\"placement_y_max\") is None else float(payload[\"placement_y_max\"]),\n",
|
||||
" include_mirror_placement_id=bool(payload.get(\"include_mirror_placement_id\", False)),\n",
|
||||
" role_definitions={str(k): int(v) for k, v in payload[\"role_definitions\"].items()},\n",
|
||||
" boardlib_database_command=payload.get(\"boardlib_database_command\"),\n",
|
||||
" boardlib_images_command=payload.get(\"boardlib_images_command\"),\n",
|
||||
" notes=tuple(payload.get(\"notes\", [])),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"def load_board_configs(board_keys: list[str] | tuple[str, ...]) -> list[BoardConfig]:\n",
|
||||
" \"\"\"Load multiple board configurations.\n",
|
||||
" \n",
|
||||
" Args:\n",
|
||||
" board_keys: List of board identifiers\n",
|
||||
" \n",
|
||||
" Returns:\n",
|
||||
" List of BoardConfig dataclasses\n",
|
||||
" \"\"\"\n",
|
||||
" return [load_board_config(board_key) for board_key in board_keys]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94d352c6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Generation helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "abdabe8e",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.220501Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.220197Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.241455Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.240589Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse generated hold tokens back into structured hold records.\n",
|
||||
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
|
||||
"\n",
|
||||
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from model tokens using the shared hold-token grammar.\"\"\"\n",
|
||||
" rows: list[dict[str, object]] = []\n",
|
||||
" for token in tokens:\n",
|
||||
" match = HOLD_TOKEN_PATTERN.match(str(token))\n",
|
||||
" if match is None:\n",
|
||||
" continue\n",
|
||||
" board_prefix = match.group(1)\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"token\": str(token),\n",
|
||||
" \"board_token_prefix\": board_prefix,\n",
|
||||
" \"board_prefix\": board_prefix,\n",
|
||||
" \"placement_id\": int(match.group(2)),\n",
|
||||
" \"role\": match.group(3),\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" return rows\n",
|
||||
"\n",
|
||||
"# Sample routes from the trained GPT model and convert them back to frames strings.\n",
|
||||
"def top_k_filter(logits: torch.Tensor, k: int | None) -> torch.Tensor:\n",
|
||||
" \"\"\"Mask logits outside the top ``k`` choices for each batch row.\"\"\"\n",
|
||||
" if k is None or k <= 0 or k >= logits.size(-1):\n",
|
||||
" return logits\n",
|
||||
" values, _ = torch.topk(logits, k)\n",
|
||||
" cutoff = values[:, [-1]]\n",
|
||||
" return torch.where(logits < cutoff, torch.full_like(logits, -float(\"inf\")), logits)\n",
|
||||
"\n",
|
||||
"@torch.no_grad()\n",
|
||||
"def sample_ids(\n",
|
||||
" model,\n",
|
||||
" prompt_ids: list[int],\n",
|
||||
" device: torch.device,\n",
|
||||
" max_new_tokens: int = 40,\n",
|
||||
" temperature: float = 0.9,\n",
|
||||
" top_k: int | None = 50,\n",
|
||||
" eos_id: int | None = None,\n",
|
||||
" forbidden_ids: Iterable[int] | None = None,\n",
|
||||
") -> list[int]:\n",
|
||||
" \"\"\"Autoregressively sample token IDs from a trained route generator.\n",
|
||||
"\n",
|
||||
" The returned list includes the prompt IDs and all sampled IDs up to either\n",
|
||||
" ``max_new_tokens`` or the first sampled ``eos_id``.\n",
|
||||
" \"\"\"\n",
|
||||
" model.eval()\n",
|
||||
" sequence = torch.tensor([prompt_ids], dtype=torch.long, device=device)\n",
|
||||
" forbidden_ids = set(forbidden_ids or [])\n",
|
||||
"\n",
|
||||
" for _ in range(max_new_tokens):\n",
|
||||
" idx_cond = sequence[:, -model.block_size :]\n",
|
||||
" logits, _ = model(idx_cond)\n",
|
||||
" logits = logits[:, -1, :] / max(temperature, 1e-6)\n",
|
||||
"\n",
|
||||
" # Special tokens like <PAD> and <CLS> are valid vocabulary entries but\n",
|
||||
" # should never be emitted in the middle of a generated climb.\n",
|
||||
" for token_id in forbidden_ids:\n",
|
||||
" logits[:, int(token_id)] = -float(\"inf\")\n",
|
||||
"\n",
|
||||
" logits = top_k_filter(logits, top_k)\n",
|
||||
" probs = F.softmax(logits, dim=-1)\n",
|
||||
" next_id = torch.multinomial(probs, num_samples=1)\n",
|
||||
" sequence = torch.cat([sequence, next_id], dim=1)\n",
|
||||
"\n",
|
||||
" if eos_id is not None and int(next_id.item()) == int(eos_id):\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" return sequence[0].detach().cpu().tolist()\n",
|
||||
"\n",
|
||||
"def prompt_tokens(board_prefix: str, angle: int, grouped_v: int) -> list[str]:\n",
|
||||
" \"\"\"Build the conditioning prefix used before sampling hold tokens.\"\"\"\n",
|
||||
" return [\n",
|
||||
" \"<BOS>\",\n",
|
||||
" f\"<BOARD_{board_prefix}>\",\n",
|
||||
" f\"<ANGLE_{int(angle)}>\",\n",
|
||||
" f\"<GRADE_V{int(grouped_v)}>\",\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
"def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from generated tokens.\"\"\"\n",
|
||||
" return tokens_to_hold_records(tokens)\n",
|
||||
"\n",
|
||||
"def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
|
||||
" \"\"\"Summarize basic structural validity for generated token sequences.\"\"\"\n",
|
||||
" records = hold_records(tokens)\n",
|
||||
" placements = [record[\"placement_id\"] for record in records]\n",
|
||||
" roles = [record[\"role\"] for record in records]\n",
|
||||
" prefixes = [record[\"board_prefix\"] for record in records]\n",
|
||||
"\n",
|
||||
" one_board_only = len(set(prefixes)) <= 1\n",
|
||||
" matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)\n",
|
||||
" no_duplicates = len(placements) == len(set(placements))\n",
|
||||
" has_start = \"start\" in roles\n",
|
||||
" has_finish = \"finish\" in roles\n",
|
||||
" enough_holds = len(records) >= 3\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"n_hold_tokens\": len(records),\n",
|
||||
" \"n_unique_placements\": len(set(placements)),\n",
|
||||
" \"has_duplicate_placements\": not no_duplicates,\n",
|
||||
" \"one_board_only\": one_board_only,\n",
|
||||
" \"matches_requested_board\": matches_requested_board,\n",
|
||||
" \"has_start\": has_start,\n",
|
||||
" \"has_middle\": \"middle\" in roles,\n",
|
||||
" \"has_finish\": has_finish,\n",
|
||||
" \"n_start\": roles.count(\"start\"),\n",
|
||||
" \"n_middle\": roles.count(\"middle\"),\n",
|
||||
" \"n_foot\": roles.count(\"foot\"),\n",
|
||||
" \"n_finish\": roles.count(\"finish\"),\n",
|
||||
" \"basic_valid\": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:\n",
|
||||
" \"\"\"Convert generated hold tokens back into a frames string.\n",
|
||||
"\n",
|
||||
" Duplicate placements and unknown roles are skipped, matching the forgiving\n",
|
||||
" cleanup used by the demo scripts and webapp.\n",
|
||||
" \"\"\"\n",
|
||||
" pieces = []\n",
|
||||
" seen = set()\n",
|
||||
" for record in hold_records(tokens):\n",
|
||||
" if board_prefix is not None and str(record[\"board_prefix\"]) != board_prefix:\n",
|
||||
" continue\n",
|
||||
" placement_id = int(record[\"placement_id\"])\n",
|
||||
" role = str(record[\"role\"])\n",
|
||||
" if placement_id in seen or role not in role_name_to_id:\n",
|
||||
" continue\n",
|
||||
" seen.add(placement_id)\n",
|
||||
" pieces.append(f\"p{placement_id}r{int(role_name_to_id[role])}\")\n",
|
||||
" return \"\".join(pieces)\n",
|
||||
"\n",
|
||||
"def generate_one(\n",
|
||||
" model,\n",
|
||||
" stoi: dict[str, int],\n",
|
||||
" itos: dict[int, str],\n",
|
||||
" device: torch.device,\n",
|
||||
" board_prefix: str,\n",
|
||||
" angle: int,\n",
|
||||
" grouped_v: int,\n",
|
||||
" role_name_to_id: dict[str, int],\n",
|
||||
" temperature: float = 0.9,\n",
|
||||
" top_k: int | None = 50,\n",
|
||||
" max_new_tokens: int = 40,\n",
|
||||
") -> dict[str, object]:\n",
|
||||
" \"\"\"Generate one route and return tokens, frames, request metadata, validity.\"\"\"\n",
|
||||
" unk_id = stoi[\"<UNK>\"]\n",
|
||||
" eos_id = stoi[\"<EOS>\"]\n",
|
||||
" forbidden_ids = [\n",
|
||||
" stoi[\"<PAD>\"],\n",
|
||||
" stoi[\"<UNK>\"],\n",
|
||||
" stoi[\"<BOS>\"],\n",
|
||||
" stoi[\"<CLS>\"],\n",
|
||||
" stoi[\"<MASK>\"],\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" prompt = prompt_tokens(board_prefix, angle, grouped_v)\n",
|
||||
" prompt_ids = [stoi.get(token, unk_id) for token in prompt]\n",
|
||||
" token_ids = sample_ids(\n",
|
||||
" model=model,\n",
|
||||
" prompt_ids=prompt_ids,\n",
|
||||
" device=device,\n",
|
||||
" max_new_tokens=max_new_tokens,\n",
|
||||
" temperature=temperature,\n",
|
||||
" top_k=top_k,\n",
|
||||
" eos_id=eos_id,\n",
|
||||
" forbidden_ids=forbidden_ids,\n",
|
||||
" )\n",
|
||||
" tokens = [itos.get(int(idx), \"<UNK>\") for idx in token_ids]\n",
|
||||
" validity = validity_summary(tokens, requested_board_prefix=board_prefix)\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" \"requested_board_prefix\": board_prefix,\n",
|
||||
" \"requested_angle\": int(angle),\n",
|
||||
" \"requested_grouped_v\": int(grouped_v),\n",
|
||||
" \"temperature\": float(temperature),\n",
|
||||
" \"top_k\": None if top_k is None else int(top_k),\n",
|
||||
" \"tokens\": tokens,\n",
|
||||
" \"sequence\": \" \".join(tokens),\n",
|
||||
" \"frames\": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),\n",
|
||||
" **validity,\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "69926180",
|
||||
@@ -356,14 +939,22 @@
|
||||
"- **Temperature** (default 0.9): Controls randomness. Lower = more deterministic, higher = more random\n",
|
||||
"- **Top-k** (default 50): Only consider the k most likely tokens. This prevents the model from generating very unlikely tokens.\n",
|
||||
"\n",
|
||||
"These are the same techniques used in language models like GPT-3 to control output diversity."
|
||||
"These are the same techniques used in language models like GPT-3 to control output diversity.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "029eb911",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.244254Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.244037Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:38.680983Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:38.679992Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate sample routes for both boards\n",
|
||||
@@ -400,14 +991,22 @@
|
||||
"source": [
|
||||
"## Generate More Routes for Evaluation\n",
|
||||
"\n",
|
||||
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards."
|
||||
"Notebook 04 needs a larger set of generated routes for meaningful evaluation. Let's generate routes across multiple angles and grades for both boards.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "generate_bulk",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:38.684476Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:38.683935Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:53.779260Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:53.778391Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Generate routes across multiple angles and grades for evaluation\n",
|
||||
@@ -454,14 +1053,22 @@
|
||||
"source": [
|
||||
"## Save Model and Generated Routes\n",
|
||||
"\n",
|
||||
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation)."
|
||||
"We save the trained model checkpoint and generated routes for use in notebook 04 (evaluation).\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "save_outputs",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:43:53.782874Z",
|
||||
"iopub.status.busy": "2026-06-07T23:43:53.782303Z",
|
||||
"iopub.status.idle": "2026-06-07T23:43:53.831685Z",
|
||||
"shell.execute_reply": "2026-06-07T23:43:53.830785Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
@@ -509,8 +1116,16 @@
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"version": "3.11"
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -19,8 +19,8 @@
|
||||
"### Validity checks\n",
|
||||
"\n",
|
||||
"A \"basic valid\" route must have:\n",
|
||||
"- At least 3 holds (you need at least 2 hands + 1 foot to climb)\n",
|
||||
"- No duplicate placements (you can't use the same hold twice)\n",
|
||||
"- At least 3 holds\n",
|
||||
"- No duplicate placements\n",
|
||||
"- At least one start hold and one finish hold\n",
|
||||
"- All holds from the same board (no mixing TB2 and Kilter holds)\n",
|
||||
"\n",
|
||||
@@ -35,46 +35,55 @@
|
||||
"- Jaccard similarity = |A intersection B| / |A union B|\n",
|
||||
"- Novelty distance = 1 - Jaccard similarity\n",
|
||||
"\n",
|
||||
"A novelty distance of 1.0 means the generated route shares no holds with any real route. A distance of 0.0 means it's identical to an existing route."
|
||||
"A novelty distance of 1.0 means the generated route shares no holds with any real route. A distance of 0.0 means it's identical to an existing route.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "726b846f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:02.200057Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:02.199717Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:04.626359Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:04.625624Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import ast\n",
|
||||
"import re\n",
|
||||
"from pathlib import Path\n",
|
||||
"import sys\n",
|
||||
"from typing import Iterable\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from scipy.spatial.distance import pdist\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if ROOT.name == \"notebooks\":\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"sys.path.insert(0, str(ROOT / \"src\"))\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.evaluation import (\n",
|
||||
" build_placement_coords,\n",
|
||||
" frames_to_holds,\n",
|
||||
" holds_to_placement_set,\n",
|
||||
" nearest_real_route_same_board,\n",
|
||||
" parse_token_list,\n",
|
||||
" simple_route_features,\n",
|
||||
" tokens_to_hold_records,\n",
|
||||
" validity_from_records,\n",
|
||||
")\n",
|
||||
"from climbingboardgpt.grades import to_grouped_v\n",
|
||||
"from climbingboardgpt.models import JointRouteTransformerRegressor"
|
||||
" ROOT = ROOT.parent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7f8bb61f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:04.629832Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:04.629390Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.364160Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.363335Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load generated routes and real routes for comparison\n",
|
||||
@@ -111,6 +120,107 @@
|
||||
"print(f\"Real routes: {len(df_real):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6dc0ac67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Token parsing and validity helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c32f7ced",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.368243Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.367603Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.380028Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.379312Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse generated token strings and compute basic route-validity flags.\n",
|
||||
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
|
||||
"\n",
|
||||
"def parse_tokens(value) -> list[str]:\n",
|
||||
" \"\"\"Parse tokens from a list, repr-style list string, or whitespace sequence.\"\"\"\n",
|
||||
" if isinstance(value, list):\n",
|
||||
" return [str(v) for v in value]\n",
|
||||
" if not isinstance(value, str):\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" parsed = ast.literal_eval(value)\n",
|
||||
" if isinstance(parsed, list):\n",
|
||||
" return [str(v) for v in parsed]\n",
|
||||
" except (SyntaxError, ValueError):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" return value.split()\n",
|
||||
"\n",
|
||||
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from model tokens using the shared hold-token grammar.\"\"\"\n",
|
||||
" rows: list[dict[str, object]] = []\n",
|
||||
" for token in tokens:\n",
|
||||
" match = HOLD_TOKEN_PATTERN.match(str(token))\n",
|
||||
" if match is None:\n",
|
||||
" continue\n",
|
||||
" board_prefix = match.group(1)\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"token\": str(token),\n",
|
||||
" \"board_token_prefix\": board_prefix,\n",
|
||||
" \"board_prefix\": board_prefix,\n",
|
||||
" \"placement_id\": int(match.group(2)),\n",
|
||||
" \"role\": match.group(3),\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" return rows\n",
|
||||
"\n",
|
||||
"def parse_token_list(value) -> list[str]:\n",
|
||||
" \"\"\"Compatibility wrapper around the shared token parser.\"\"\"\n",
|
||||
" return parse_tokens(value)\n",
|
||||
"\n",
|
||||
"def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
|
||||
" \"\"\"Compute evaluation-specific route-validity flags from hold records.\"\"\"\n",
|
||||
" placements = [int(record[\"placement_id\"]) for record in records]\n",
|
||||
" roles = [str(record[\"role\"]) for record in records]\n",
|
||||
" prefixes = [str(record[\"board_token_prefix\"]) for record in records]\n",
|
||||
" one_board_only = len(set(prefixes)) <= 1\n",
|
||||
" matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)\n",
|
||||
"\n",
|
||||
" out = {\n",
|
||||
" \"n_holds_eval\": len(records),\n",
|
||||
" \"n_unique_placements_eval\": len(set(placements)),\n",
|
||||
" \"has_duplicate_placements_eval\": len(records) != len(set(placements)),\n",
|
||||
" \"one_board_only_eval\": one_board_only,\n",
|
||||
" \"matches_requested_board_eval\": matches_requested_board,\n",
|
||||
" \"n_start_eval\": roles.count(\"start\"),\n",
|
||||
" \"n_middle_eval\": roles.count(\"middle\"),\n",
|
||||
" \"n_foot_eval\": roles.count(\"foot\"),\n",
|
||||
" \"n_finish_eval\": roles.count(\"finish\"),\n",
|
||||
" \"has_start_eval\": \"start\" in roles,\n",
|
||||
" \"has_middle_eval\": \"middle\" in roles,\n",
|
||||
" \"has_finish_eval\": \"finish\" in roles,\n",
|
||||
" }\n",
|
||||
" out[\"basic_valid_eval\"] = (\n",
|
||||
" one_board_only\n",
|
||||
" and out[\"n_holds_eval\"] >= 3\n",
|
||||
" and out[\"n_holds_eval\"] == out[\"n_unique_placements_eval\"]\n",
|
||||
" and out[\"has_start_eval\"]\n",
|
||||
" and out[\"has_finish_eval\"]\n",
|
||||
" )\n",
|
||||
" out[\"strict_valid_eval\"] = (\n",
|
||||
" out[\"basic_valid_eval\"]\n",
|
||||
" and out[\"has_middle_eval\"]\n",
|
||||
" and out[\"n_holds_eval\"] >= 4\n",
|
||||
" )\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0091bafb",
|
||||
@@ -118,14 +228,22 @@
|
||||
"source": [
|
||||
"## Parse generated tokens and check validity\n",
|
||||
"\n",
|
||||
"We parse the generated token sequences and check each route for validity."
|
||||
"We parse the generated token sequences and check each route for validity.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f5c2b25a",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.383121Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.382759Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.430410Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.429741Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse the token strings into structured records\n",
|
||||
@@ -149,6 +267,89 @@
|
||||
"print(validity_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0ff31b72",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Novelty helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "10a40f48",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.434117Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.433720Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.443045Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.442319Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Compare generated hold sets to real routes using Jaccard similarity.\n",
|
||||
"def frames_to_holds(frames: str | None) -> list[tuple[int, int]]:\n",
|
||||
" \"\"\"Parse a frames string into ``(placement_id, role_id)`` pairs.\"\"\"\n",
|
||||
" if not isinstance(frames, str):\n",
|
||||
" return []\n",
|
||||
" return [(int(p), int(r)) for p, r in re.findall(r\"p(\\d+)r(\\d+)\", frames)]\n",
|
||||
"\n",
|
||||
"def holds_to_placement_set(holds: Iterable[tuple[int, int]]) -> frozenset[int]:\n",
|
||||
" \"\"\"Drop role IDs and represent a route by its unique placement IDs.\"\"\"\n",
|
||||
" return frozenset(int(placement_id) for placement_id, _ in holds)\n",
|
||||
"\n",
|
||||
"def jaccard(a: frozenset[int], b: frozenset[int]) -> float:\n",
|
||||
" \"\"\"Return Jaccard similarity between two placement sets.\"\"\"\n",
|
||||
" if not a and not b:\n",
|
||||
" return 1.0\n",
|
||||
" if not a or not b:\n",
|
||||
" return 0.0\n",
|
||||
" return len(a & b) / len(a | b)\n",
|
||||
"\n",
|
||||
"def nearest_real_route_same_board(\n",
|
||||
" generated_set: frozenset[int],\n",
|
||||
" generated_board_key: str,\n",
|
||||
" real_df: pd.DataFrame,\n",
|
||||
") -> dict[str, object]:\n",
|
||||
" \"\"\"Find the most similar real route on the same board by Jaccard score.\n",
|
||||
"\n",
|
||||
" .. note::\n",
|
||||
"\n",
|
||||
" This function performs an O(n) linear scan over all real routes for\n",
|
||||
" the matching board, computing a Jaccard similarity for each one. With\n",
|
||||
" ~256K training examples, evaluating 400 generated routes costs roughly\n",
|
||||
" O(100M) Jaccard comparisons. This is acceptable for evaluation scripts\n",
|
||||
" but would not scale to a real-time or high-throughput setting without\n",
|
||||
" an approximate nearest-neighbour index.\n",
|
||||
" \"\"\"\n",
|
||||
" board_frame = real_df[real_df[\"board_key\"] == generated_board_key]\n",
|
||||
" if board_frame.empty:\n",
|
||||
" return {\n",
|
||||
" \"nearest_real_jaccard\": np.nan,\n",
|
||||
" \"nearest_real_uuid\": None,\n",
|
||||
" \"nearest_real_name\": None,\n",
|
||||
" \"nearest_real_grouped_v\": None,\n",
|
||||
" \"nearest_real_angle\": None,\n",
|
||||
" \"novelty_distance\": np.nan,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" similarities = board_frame[\"hold_set\"].map(lambda hold_set: jaccard(generated_set, hold_set))\n",
|
||||
" best_idx = similarities.idxmax()\n",
|
||||
" row = board_frame.loc[best_idx]\n",
|
||||
"\n",
|
||||
" nearest_real_jaccard = float(similarities.loc[best_idx])\n",
|
||||
" return {\n",
|
||||
" \"nearest_real_jaccard\": nearest_real_jaccard,\n",
|
||||
" \"nearest_real_uuid\": row[\"uuid\"],\n",
|
||||
" \"nearest_real_name\": row[\"climb_name\"],\n",
|
||||
" \"nearest_real_grouped_v\": row[\"grouped_v\"],\n",
|
||||
" \"nearest_real_angle\": row[\"angle\"],\n",
|
||||
" \"novelty_distance\": 1.0 - nearest_real_jaccard,\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0cf2170e",
|
||||
@@ -156,14 +357,22 @@
|
||||
"source": [
|
||||
"## Novelty against real climbs\n",
|
||||
"\n",
|
||||
"For each generated route, we find the most similar real route from the same board (by Jaccard similarity of hold sets). A good generator should produce routes that are novel (low Jaccard similarity to existing routes) while still being valid."
|
||||
"For each generated route, we find the most similar real route from the same board (by Jaccard similarity of hold sets). A good generator should produce routes that are novel (low Jaccard similarity to existing routes) while still being valid.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7f34524",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.446422Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.445998Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:41.914124Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:41.913292Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert hold sets to frozensets for fast comparison\n",
|
||||
@@ -201,6 +410,105 @@
|
||||
"print(novelty_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ad70ff4c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Geometry helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "85ddaf53",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:41.918125Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:41.917658Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:41.929570Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:41.928790Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Compute simple geometric descriptors from placement coordinates.\n",
|
||||
"def build_placement_coords(df_token_meta: pd.DataFrame) -> dict[tuple[str, int], dict[str, float]]:\n",
|
||||
" \"\"\"Build a placement-coordinate lookup from token metadata.\"\"\"\n",
|
||||
" hold_meta = df_token_meta[df_token_meta[\"kind\"] == \"hold\"].dropna(subset=[\"placement_id\"]).copy()\n",
|
||||
" coords = {}\n",
|
||||
" for _, row in hold_meta.drop_duplicates([\"board_key\", \"placement_id\"]).iterrows():\n",
|
||||
" key = (str(row[\"board_key\"]), int(row[\"placement_id\"]))\n",
|
||||
" coords[key] = {\n",
|
||||
" \"x\": float(row[\"x\"]),\n",
|
||||
" \"y\": float(row[\"y\"]),\n",
|
||||
" }\n",
|
||||
" return coords\n",
|
||||
"\n",
|
||||
"def simple_route_features(\n",
|
||||
" board_key: str,\n",
|
||||
" records: list[dict[str, object]],\n",
|
||||
" placement_coords: dict[tuple[str, int], dict[str, float]],\n",
|
||||
") -> dict[str, float]:\n",
|
||||
" \"\"\"Compute simple geometric route features from hold coordinates.\n",
|
||||
"\n",
|
||||
" These features are descriptive rather than a full climbing-physics model:\n",
|
||||
" height/width describe route spread, and hand-reach distances summarize the\n",
|
||||
" pairwise spacing among start/middle/finish holds.\n",
|
||||
" \"\"\"\n",
|
||||
" rows = []\n",
|
||||
" for record in records:\n",
|
||||
" key = (str(board_key), int(record[\"placement_id\"]))\n",
|
||||
" coord = placement_coords.get(key)\n",
|
||||
" if coord is None:\n",
|
||||
" continue\n",
|
||||
" x = float(coord[\"x\"])\n",
|
||||
" y = float(coord[\"y\"])\n",
|
||||
" if np.isnan(x) or np.isnan(y):\n",
|
||||
" continue\n",
|
||||
" role = str(record[\"role\"])\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"x\": x,\n",
|
||||
" \"y\": y,\n",
|
||||
" \"role\": role,\n",
|
||||
" \"is_hand\": role in {\"start\", \"middle\", \"finish\"},\n",
|
||||
" \"is_foot\": role == \"foot\",\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if not rows:\n",
|
||||
" return {\n",
|
||||
" \"geom_n_holds\": 0.0,\n",
|
||||
" \"geom_height\": np.nan,\n",
|
||||
" \"geom_width\": np.nan,\n",
|
||||
" \"geom_mean_y\": np.nan,\n",
|
||||
" \"geom_mean_x_abs\": np.nan,\n",
|
||||
" \"geom_mean_hand_reach\": np.nan,\n",
|
||||
" \"geom_max_hand_reach\": np.nan,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" d = pd.DataFrame(rows)\n",
|
||||
" out = {\n",
|
||||
" \"geom_n_holds\": float(len(d)),\n",
|
||||
" \"geom_height\": float(d[\"y\"].max() - d[\"y\"].min()),\n",
|
||||
" \"geom_width\": float(d[\"x\"].max() - d[\"x\"].min()),\n",
|
||||
" \"geom_mean_y\": float(d[\"y\"].mean()),\n",
|
||||
" \"geom_mean_x_abs\": float(d[\"x\"].abs().mean()),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" hands = d[d[\"is_hand\"]].sort_values([\"y\", \"x\"])\n",
|
||||
" if len(hands) >= 2:\n",
|
||||
" distances = pdist(hands[[\"x\", \"y\"]].values)\n",
|
||||
" out[\"geom_mean_hand_reach\"] = float(distances.mean())\n",
|
||||
" out[\"geom_max_hand_reach\"] = float(distances.max())\n",
|
||||
" else:\n",
|
||||
" out[\"geom_mean_hand_reach\"] = np.nan\n",
|
||||
" out[\"geom_max_hand_reach\"] = np.nan\n",
|
||||
"\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b581705d",
|
||||
@@ -215,14 +523,22 @@
|
||||
"- `geom_width`: Horizontal extent\n",
|
||||
"- `geom_mean_hand_reach`: Average distance between hand holds\n",
|
||||
"\n",
|
||||
"These features help us understand whether generated routes have reasonable spatial properties."
|
||||
"These features help us understand whether generated routes have reasonable spatial properties.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d74d4cad",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:41.932520Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:41.932262Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:42.775565Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:42.774476Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build coordinate lookup from token metadata\n",
|
||||
@@ -252,6 +568,134 @@
|
||||
"print(geom_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "44036a1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Critic model and grade helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9cfff1f4",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:42.779195Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:42.778895Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:42.791727Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:42.790706Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Map BoardLib display difficulties into grouped V-grade tokens.\n",
|
||||
"GRADE_TO_V = {\n",
|
||||
" 10: 0, 11: 0, 12: 0,\n",
|
||||
" 13: 1, 14: 1,\n",
|
||||
" 15: 2,\n",
|
||||
" 16: 3, 17: 3,\n",
|
||||
" 18: 4, 19: 4,\n",
|
||||
" 20: 5, 21: 5,\n",
|
||||
" 22: 6,\n",
|
||||
" 23: 7,\n",
|
||||
" 24: 8, 25: 8,\n",
|
||||
" 26: 9,\n",
|
||||
" 27: 10,\n",
|
||||
" 28: 11,\n",
|
||||
" 29: 12,\n",
|
||||
" 30: 13,\n",
|
||||
" 31: 14,\n",
|
||||
" 32: 15,\n",
|
||||
" 33: 16,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def to_grouped_v(display_difficulty: float) -> int:\n",
|
||||
" \"\"\"Map a continuous display difficulty to the nearest grouped V grade.\"\"\"\n",
|
||||
" rounded = int(round(float(display_difficulty)))\n",
|
||||
" rounded = max(min(rounded, max(GRADE_TO_V)), min(GRADE_TO_V))\n",
|
||||
" return GRADE_TO_V[rounded]\n",
|
||||
"\n",
|
||||
"def grade_token(display_difficulty: float) -> str:\n",
|
||||
" \"\"\"Return the grade-conditioning token for a display difficulty value.\"\"\"\n",
|
||||
" return f\"<GRADE_V{to_grouped_v(display_difficulty)}>\"\n",
|
||||
"\n",
|
||||
"# Transformer encoder used as a continuous grade regressor.\n",
|
||||
"class JointRouteTransformerRegressor(nn.Module):\n",
|
||||
" \"\"\"Transformer encoder for joint TB2/Kilter route difficulty prediction.\n",
|
||||
"\n",
|
||||
" Inputs are token IDs plus an attention mask. Token, position, and learned\n",
|
||||
" projections of coordinate metadata are added before the encoder. The first\n",
|
||||
" ``<CLS>`` position is then used as a pooled route representation for scalar\n",
|
||||
" difficulty regression.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" vocab_size: int,\n",
|
||||
" max_len: int,\n",
|
||||
" coord_features: torch.Tensor,\n",
|
||||
" d_model: int = 128,\n",
|
||||
" nhead: int = 4,\n",
|
||||
" num_layers: int = 4,\n",
|
||||
" dim_feedforward: int = 256,\n",
|
||||
" dropout: float = 0.10,\n",
|
||||
" pad_id: int = 0,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Create the encoder, coordinate projection, and regression head.\"\"\"\n",
|
||||
" super().__init__()\n",
|
||||
" self.vocab_size = vocab_size\n",
|
||||
" self.max_len = max_len\n",
|
||||
" self.d_model = d_model\n",
|
||||
" self.pad_id = pad_id\n",
|
||||
"\n",
|
||||
" self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)\n",
|
||||
" self.pos_emb = nn.Embedding(max_len, d_model)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\"coord_features\", coord_features.clone().float())\n",
|
||||
" self.coord_proj = nn.Linear(coord_features.shape[1], d_model)\n",
|
||||
"\n",
|
||||
" encoder_layer = nn.TransformerEncoderLayer(\n",
|
||||
" d_model=d_model,\n",
|
||||
" nhead=nhead,\n",
|
||||
" dim_feedforward=dim_feedforward,\n",
|
||||
" dropout=dropout,\n",
|
||||
" activation=\"gelu\",\n",
|
||||
" batch_first=True,\n",
|
||||
" norm_first=True,\n",
|
||||
" )\n",
|
||||
" self.encoder = nn.TransformerEncoder(\n",
|
||||
" encoder_layer,\n",
|
||||
" num_layers=num_layers,\n",
|
||||
" enable_nested_tensor=False,\n",
|
||||
" )\n",
|
||||
" self.norm = nn.LayerNorm(d_model)\n",
|
||||
" self.head = nn.Sequential(\n",
|
||||
" nn.Linear(d_model, d_model),\n",
|
||||
" nn.GELU(),\n",
|
||||
" nn.Dropout(dropout),\n",
|
||||
" nn.Linear(d_model, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:\n",
|
||||
" \"\"\"Return one continuous difficulty prediction per input sequence.\"\"\"\n",
|
||||
" batch_size, seq_len = input_ids.shape\n",
|
||||
" positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_len)\n",
|
||||
"\n",
|
||||
" # Coordinate features are indexed by token ID, so every occurrence of a\n",
|
||||
" # hold token gets the same physical x/y hint wherever it appears.\n",
|
||||
" x = self.token_emb(input_ids) + self.pos_emb(positions)\n",
|
||||
" x = x + self.coord_proj(self.coord_features[input_ids])\n",
|
||||
"\n",
|
||||
" key_padding_mask = ~attention_mask.bool()\n",
|
||||
" h = self.encoder(x, src_key_padding_mask=key_padding_mask)\n",
|
||||
" h = self.norm(h)\n",
|
||||
"\n",
|
||||
" cls_state = h[:, 0, :]\n",
|
||||
" return self.head(cls_state).squeeze(-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4455557a",
|
||||
@@ -259,16 +703,21 @@
|
||||
"source": [
|
||||
"## Grade consistency (using the trained critic)\n",
|
||||
"\n",
|
||||
"If we have a trained grade predictor (from notebook 02), we can use it as a **critic** to check whether generated routes have grades consistent with what was requested.\n",
|
||||
"\n",
|
||||
"This is similar to how GANs use a discriminator to evaluate generated samples, except our critic is a regression model rather than a binary classifier."
|
||||
"If we have a trained grade predictor (from notebook 02), we can use it as a **critic** to check whether generated routes have grades consistent with what was requested.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88747d6e",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:42.795099Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:42.794788Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:43.323348Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:43.321923Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Try to load the grade critic from notebook 02\n",
|
||||
@@ -355,7 +804,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "critic_eval",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:43.327454Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:43.326834Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.473105Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.472309Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Apply the critic to evaluate grade consistency\n",
|
||||
@@ -390,14 +846,22 @@
|
||||
"- **Basic validity** (required): At least 3 holds, start/finish, no duplicates, one board\n",
|
||||
"- **Strict validity** (bonus): Also has middle holds and 4+ holds\n",
|
||||
"- **Novelty** (higher is better): Distance from nearest real route\n",
|
||||
"- **Grade consistency** (if critic available): Predicted grade close to requested grade"
|
||||
"- **Grade consistency** (if critic available): Predicted grade close to requested grade\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88747d6e2",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:44.476183Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:44.475814Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.489525Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.488845Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Rank candidates by composite score\n",
|
||||
@@ -427,14 +891,22 @@
|
||||
"source": [
|
||||
"## Save evaluation results\n",
|
||||
"\n",
|
||||
"We save the full evaluation DataFrame and the top candidates for further analysis."
|
||||
"We save the full evaluation DataFrame and the top candidates for further analysis.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "save_results",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:44.492676Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:44.492218Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.561651Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.560880Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save evaluation results\n",
|
||||
@@ -485,7 +957,8 @@
|
||||
"\n",
|
||||
"- Validity checks are structural, not semantic. A route might have valid start/finish holds but still be impossible.\n",
|
||||
"- Geometric features are simple. More sophisticated analysis could check reachability and move sequences.\n",
|
||||
"- The critic model was trained on real data, so it may not generalize well to novel route structures."
|
||||
"- The critic model was trained on real data, so it may not generalize well to novel route structures.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -505,7 +978,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.14.4"
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user