Update notebook results and README stats
This commit is contained in:
@@ -19,8 +19,8 @@
|
||||
"### Validity checks\n",
|
||||
"\n",
|
||||
"A \"basic valid\" route must have:\n",
|
||||
"- At least 3 holds (you need at least 2 hands + 1 foot to climb)\n",
|
||||
"- No duplicate placements (you can't use the same hold twice)\n",
|
||||
"- At least 3 holds\n",
|
||||
"- No duplicate placements\n",
|
||||
"- At least one start hold and one finish hold\n",
|
||||
"- All holds from the same board (no mixing TB2 and Kilter holds)\n",
|
||||
"\n",
|
||||
@@ -35,46 +35,55 @@
|
||||
"- Jaccard similarity = |A intersection B| / |A union B|\n",
|
||||
"- Novelty distance = 1 - Jaccard similarity\n",
|
||||
"\n",
|
||||
"A novelty distance of 1.0 means the generated route shares no holds with any real route. A distance of 0.0 means it's identical to an existing route."
|
||||
"A novelty distance of 1.0 means the generated route shares no holds with any real route. A distance of 0.0 means it's identical to an existing route.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "726b846f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:02.200057Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:02.199717Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:04.626359Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:04.625624Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import annotations\n",
|
||||
"\n",
|
||||
"import ast\n",
|
||||
"import re\n",
|
||||
"from pathlib import Path\n",
|
||||
"import sys\n",
|
||||
"from typing import Iterable\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from scipy.spatial.distance import pdist\n",
|
||||
"\n",
|
||||
"ROOT = Path.cwd().resolve()\n",
|
||||
"if ROOT.name == \"notebooks\":\n",
|
||||
" ROOT = ROOT.parent\n",
|
||||
"sys.path.insert(0, str(ROOT / \"src\"))\n",
|
||||
"\n",
|
||||
"from climbingboardgpt.evaluation import (\n",
|
||||
" build_placement_coords,\n",
|
||||
" frames_to_holds,\n",
|
||||
" holds_to_placement_set,\n",
|
||||
" nearest_real_route_same_board,\n",
|
||||
" parse_token_list,\n",
|
||||
" simple_route_features,\n",
|
||||
" tokens_to_hold_records,\n",
|
||||
" validity_from_records,\n",
|
||||
")\n",
|
||||
"from climbingboardgpt.grades import to_grouped_v\n",
|
||||
"from climbingboardgpt.models import JointRouteTransformerRegressor"
|
||||
" ROOT = ROOT.parent"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7f8bb61f",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:04.629832Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:04.629390Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.364160Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.363335Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load generated routes and real routes for comparison\n",
|
||||
@@ -111,6 +120,107 @@
|
||||
"print(f\"Real routes: {len(df_real):,}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6dc0ac67",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Token parsing and validity helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c32f7ced",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.368243Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.367603Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.380028Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.379312Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse generated token strings and compute basic route-validity flags.\n",
|
||||
"HOLD_TOKEN_PATTERN = re.compile(r\"^<([A-Z0-9_]+)_p(\\d+)_(start|middle|finish|foot|unknown)>$\")\n",
|
||||
"\n",
|
||||
"def parse_tokens(value) -> list[str]:\n",
|
||||
" \"\"\"Parse tokens from a list, repr-style list string, or whitespace sequence.\"\"\"\n",
|
||||
" if isinstance(value, list):\n",
|
||||
" return [str(v) for v in value]\n",
|
||||
" if not isinstance(value, str):\n",
|
||||
" return []\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" parsed = ast.literal_eval(value)\n",
|
||||
" if isinstance(parsed, list):\n",
|
||||
" return [str(v) for v in parsed]\n",
|
||||
" except (SyntaxError, ValueError):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" return value.split()\n",
|
||||
"\n",
|
||||
"def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:\n",
|
||||
" \"\"\"Extract hold records from model tokens using the shared hold-token grammar.\"\"\"\n",
|
||||
" rows: list[dict[str, object]] = []\n",
|
||||
" for token in tokens:\n",
|
||||
" match = HOLD_TOKEN_PATTERN.match(str(token))\n",
|
||||
" if match is None:\n",
|
||||
" continue\n",
|
||||
" board_prefix = match.group(1)\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"token\": str(token),\n",
|
||||
" \"board_token_prefix\": board_prefix,\n",
|
||||
" \"board_prefix\": board_prefix,\n",
|
||||
" \"placement_id\": int(match.group(2)),\n",
|
||||
" \"role\": match.group(3),\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" return rows\n",
|
||||
"\n",
|
||||
"def parse_token_list(value) -> list[str]:\n",
|
||||
" \"\"\"Compatibility wrapper around the shared token parser.\"\"\"\n",
|
||||
" return parse_tokens(value)\n",
|
||||
"\n",
|
||||
"def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:\n",
|
||||
" \"\"\"Compute evaluation-specific route-validity flags from hold records.\"\"\"\n",
|
||||
" placements = [int(record[\"placement_id\"]) for record in records]\n",
|
||||
" roles = [str(record[\"role\"]) for record in records]\n",
|
||||
" prefixes = [str(record[\"board_token_prefix\"]) for record in records]\n",
|
||||
" one_board_only = len(set(prefixes)) <= 1\n",
|
||||
" matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)\n",
|
||||
"\n",
|
||||
" out = {\n",
|
||||
" \"n_holds_eval\": len(records),\n",
|
||||
" \"n_unique_placements_eval\": len(set(placements)),\n",
|
||||
" \"has_duplicate_placements_eval\": len(records) != len(set(placements)),\n",
|
||||
" \"one_board_only_eval\": one_board_only,\n",
|
||||
" \"matches_requested_board_eval\": matches_requested_board,\n",
|
||||
" \"n_start_eval\": roles.count(\"start\"),\n",
|
||||
" \"n_middle_eval\": roles.count(\"middle\"),\n",
|
||||
" \"n_foot_eval\": roles.count(\"foot\"),\n",
|
||||
" \"n_finish_eval\": roles.count(\"finish\"),\n",
|
||||
" \"has_start_eval\": \"start\" in roles,\n",
|
||||
" \"has_middle_eval\": \"middle\" in roles,\n",
|
||||
" \"has_finish_eval\": \"finish\" in roles,\n",
|
||||
" }\n",
|
||||
" out[\"basic_valid_eval\"] = (\n",
|
||||
" one_board_only\n",
|
||||
" and out[\"n_holds_eval\"] >= 3\n",
|
||||
" and out[\"n_holds_eval\"] == out[\"n_unique_placements_eval\"]\n",
|
||||
" and out[\"has_start_eval\"]\n",
|
||||
" and out[\"has_finish_eval\"]\n",
|
||||
" )\n",
|
||||
" out[\"strict_valid_eval\"] = (\n",
|
||||
" out[\"basic_valid_eval\"]\n",
|
||||
" and out[\"has_middle_eval\"]\n",
|
||||
" and out[\"n_holds_eval\"] >= 4\n",
|
||||
" )\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0091bafb",
|
||||
@@ -118,14 +228,22 @@
|
||||
"source": [
|
||||
"## Parse generated tokens and check validity\n",
|
||||
"\n",
|
||||
"We parse the generated token sequences and check each route for validity."
|
||||
"We parse the generated token sequences and check each route for validity.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f5c2b25a",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.383121Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.382759Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.430410Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.429741Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Parse the token strings into structured records\n",
|
||||
@@ -149,6 +267,89 @@
|
||||
"print(validity_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0ff31b72",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Novelty helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "10a40f48",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.434117Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.433720Z",
|
||||
"iopub.status.idle": "2026-06-07T23:44:10.443045Z",
|
||||
"shell.execute_reply": "2026-06-07T23:44:10.442319Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Compare generated hold sets to real routes using Jaccard similarity.\n",
|
||||
"def frames_to_holds(frames: str | None) -> list[tuple[int, int]]:\n",
|
||||
" \"\"\"Parse a frames string into ``(placement_id, role_id)`` pairs.\"\"\"\n",
|
||||
" if not isinstance(frames, str):\n",
|
||||
" return []\n",
|
||||
" return [(int(p), int(r)) for p, r in re.findall(r\"p(\\d+)r(\\d+)\", frames)]\n",
|
||||
"\n",
|
||||
"def holds_to_placement_set(holds: Iterable[tuple[int, int]]) -> frozenset[int]:\n",
|
||||
" \"\"\"Drop role IDs and represent a route by its unique placement IDs.\"\"\"\n",
|
||||
" return frozenset(int(placement_id) for placement_id, _ in holds)\n",
|
||||
"\n",
|
||||
"def jaccard(a: frozenset[int], b: frozenset[int]) -> float:\n",
|
||||
" \"\"\"Return Jaccard similarity between two placement sets.\"\"\"\n",
|
||||
" if not a and not b:\n",
|
||||
" return 1.0\n",
|
||||
" if not a or not b:\n",
|
||||
" return 0.0\n",
|
||||
" return len(a & b) / len(a | b)\n",
|
||||
"\n",
|
||||
"def nearest_real_route_same_board(\n",
|
||||
" generated_set: frozenset[int],\n",
|
||||
" generated_board_key: str,\n",
|
||||
" real_df: pd.DataFrame,\n",
|
||||
") -> dict[str, object]:\n",
|
||||
" \"\"\"Find the most similar real route on the same board by Jaccard score.\n",
|
||||
"\n",
|
||||
" .. note::\n",
|
||||
"\n",
|
||||
" This function performs an O(n) linear scan over all real routes for\n",
|
||||
" the matching board, computing a Jaccard similarity for each one. With\n",
|
||||
" ~256K training examples, evaluating 400 generated routes costs roughly\n",
|
||||
"    100M Jaccard comparisons in total. This is acceptable for evaluation scripts\n",
|
||||
" but would not scale to a real-time or high-throughput setting without\n",
|
||||
" an approximate nearest-neighbour index.\n",
|
||||
" \"\"\"\n",
|
||||
" board_frame = real_df[real_df[\"board_key\"] == generated_board_key]\n",
|
||||
" if board_frame.empty:\n",
|
||||
" return {\n",
|
||||
" \"nearest_real_jaccard\": np.nan,\n",
|
||||
" \"nearest_real_uuid\": None,\n",
|
||||
" \"nearest_real_name\": None,\n",
|
||||
" \"nearest_real_grouped_v\": None,\n",
|
||||
" \"nearest_real_angle\": None,\n",
|
||||
" \"novelty_distance\": np.nan,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" similarities = board_frame[\"hold_set\"].map(lambda hold_set: jaccard(generated_set, hold_set))\n",
|
||||
" best_idx = similarities.idxmax()\n",
|
||||
" row = board_frame.loc[best_idx]\n",
|
||||
"\n",
|
||||
" nearest_real_jaccard = float(similarities.loc[best_idx])\n",
|
||||
" return {\n",
|
||||
" \"nearest_real_jaccard\": nearest_real_jaccard,\n",
|
||||
" \"nearest_real_uuid\": row[\"uuid\"],\n",
|
||||
" \"nearest_real_name\": row[\"climb_name\"],\n",
|
||||
" \"nearest_real_grouped_v\": row[\"grouped_v\"],\n",
|
||||
" \"nearest_real_angle\": row[\"angle\"],\n",
|
||||
" \"novelty_distance\": 1.0 - nearest_real_jaccard,\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0cf2170e",
|
||||
@@ -156,14 +357,22 @@
|
||||
"source": [
|
||||
"## Novelty against real climbs\n",
|
||||
"\n",
|
||||
"For each generated route, we find the most similar real route from the same board (by Jaccard similarity of hold sets). A good generator should produce routes that are novel (low Jaccard similarity to existing routes) while still being valid."
|
||||
"For each generated route, we find the most similar real route from the same board (by Jaccard similarity of hold sets). A good generator should produce routes that are novel (low Jaccard similarity to existing routes) while still being valid.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e7f34524",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:44:10.446422Z",
|
||||
"iopub.status.busy": "2026-06-07T23:44:10.445998Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:41.914124Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:41.913292Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Convert hold sets to frozensets for fast comparison\n",
|
||||
@@ -201,6 +410,105 @@
|
||||
"print(novelty_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ad70ff4c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Geometry helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "85ddaf53",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:41.918125Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:41.917658Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:41.929570Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:41.928790Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Compute simple geometric descriptors from placement coordinates.\n",
|
||||
"def build_placement_coords(df_token_meta: pd.DataFrame) -> dict[tuple[str, int], dict[str, float]]:\n",
|
||||
" \"\"\"Build a placement-coordinate lookup from token metadata.\"\"\"\n",
|
||||
" hold_meta = df_token_meta[df_token_meta[\"kind\"] == \"hold\"].dropna(subset=[\"placement_id\"]).copy()\n",
|
||||
" coords = {}\n",
|
||||
" for _, row in hold_meta.drop_duplicates([\"board_key\", \"placement_id\"]).iterrows():\n",
|
||||
" key = (str(row[\"board_key\"]), int(row[\"placement_id\"]))\n",
|
||||
" coords[key] = {\n",
|
||||
" \"x\": float(row[\"x\"]),\n",
|
||||
" \"y\": float(row[\"y\"]),\n",
|
||||
" }\n",
|
||||
" return coords\n",
|
||||
"\n",
|
||||
"def simple_route_features(\n",
|
||||
" board_key: str,\n",
|
||||
" records: list[dict[str, object]],\n",
|
||||
" placement_coords: dict[tuple[str, int], dict[str, float]],\n",
|
||||
") -> dict[str, float]:\n",
|
||||
" \"\"\"Compute simple geometric route features from hold coordinates.\n",
|
||||
"\n",
|
||||
" These features are descriptive rather than a full climbing-physics model:\n",
|
||||
" height/width describe route spread, and hand-reach distances summarize the\n",
|
||||
" pairwise spacing among start/middle/finish holds.\n",
|
||||
" \"\"\"\n",
|
||||
" rows = []\n",
|
||||
" for record in records:\n",
|
||||
" key = (str(board_key), int(record[\"placement_id\"]))\n",
|
||||
" coord = placement_coords.get(key)\n",
|
||||
" if coord is None:\n",
|
||||
" continue\n",
|
||||
" x = float(coord[\"x\"])\n",
|
||||
" y = float(coord[\"y\"])\n",
|
||||
" if np.isnan(x) or np.isnan(y):\n",
|
||||
" continue\n",
|
||||
" role = str(record[\"role\"])\n",
|
||||
" rows.append(\n",
|
||||
" {\n",
|
||||
" \"x\": x,\n",
|
||||
" \"y\": y,\n",
|
||||
" \"role\": role,\n",
|
||||
" \"is_hand\": role in {\"start\", \"middle\", \"finish\"},\n",
|
||||
" \"is_foot\": role == \"foot\",\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" if not rows:\n",
|
||||
" return {\n",
|
||||
" \"geom_n_holds\": 0.0,\n",
|
||||
" \"geom_height\": np.nan,\n",
|
||||
" \"geom_width\": np.nan,\n",
|
||||
" \"geom_mean_y\": np.nan,\n",
|
||||
" \"geom_mean_x_abs\": np.nan,\n",
|
||||
" \"geom_mean_hand_reach\": np.nan,\n",
|
||||
" \"geom_max_hand_reach\": np.nan,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" d = pd.DataFrame(rows)\n",
|
||||
" out = {\n",
|
||||
" \"geom_n_holds\": float(len(d)),\n",
|
||||
" \"geom_height\": float(d[\"y\"].max() - d[\"y\"].min()),\n",
|
||||
" \"geom_width\": float(d[\"x\"].max() - d[\"x\"].min()),\n",
|
||||
" \"geom_mean_y\": float(d[\"y\"].mean()),\n",
|
||||
" \"geom_mean_x_abs\": float(d[\"x\"].abs().mean()),\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" hands = d[d[\"is_hand\"]].sort_values([\"y\", \"x\"])\n",
|
||||
" if len(hands) >= 2:\n",
|
||||
" distances = pdist(hands[[\"x\", \"y\"]].values)\n",
|
||||
" out[\"geom_mean_hand_reach\"] = float(distances.mean())\n",
|
||||
" out[\"geom_max_hand_reach\"] = float(distances.max())\n",
|
||||
" else:\n",
|
||||
" out[\"geom_mean_hand_reach\"] = np.nan\n",
|
||||
" out[\"geom_max_hand_reach\"] = np.nan\n",
|
||||
"\n",
|
||||
" return out"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b581705d",
|
||||
@@ -215,14 +523,22 @@
|
||||
"- `geom_width`: Horizontal extent\n",
|
||||
"- `geom_mean_hand_reach`: Average pairwise distance between hand holds\n",
|
||||
"\n",
|
||||
"These features help us understand whether generated routes have reasonable spatial properties."
|
||||
"These features help us understand whether generated routes have reasonable spatial properties.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d74d4cad",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:41.932520Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:41.932262Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:42.775565Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:42.774476Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Build coordinate lookup from token metadata\n",
|
||||
@@ -252,6 +568,134 @@
|
||||
"print(geom_summary)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "44036a1e",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Critic model and grade helpers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9cfff1f4",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:42.779195Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:42.778895Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:42.791727Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:42.790706Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Map BoardLib display difficulties into grouped V-grade tokens.\n",
|
||||
"GRADE_TO_V = {\n",
|
||||
" 10: 0, 11: 0, 12: 0,\n",
|
||||
" 13: 1, 14: 1,\n",
|
||||
" 15: 2,\n",
|
||||
" 16: 3, 17: 3,\n",
|
||||
" 18: 4, 19: 4,\n",
|
||||
" 20: 5, 21: 5,\n",
|
||||
" 22: 6,\n",
|
||||
" 23: 7,\n",
|
||||
" 24: 8, 25: 8,\n",
|
||||
" 26: 9,\n",
|
||||
" 27: 10,\n",
|
||||
" 28: 11,\n",
|
||||
" 29: 12,\n",
|
||||
" 30: 13,\n",
|
||||
" 31: 14,\n",
|
||||
" 32: 15,\n",
|
||||
" 33: 16,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"def to_grouped_v(display_difficulty: float) -> int:\n",
|
||||
" \"\"\"Map a continuous display difficulty to the nearest grouped V grade.\"\"\"\n",
|
||||
" rounded = int(round(float(display_difficulty)))\n",
|
||||
" rounded = max(min(rounded, max(GRADE_TO_V)), min(GRADE_TO_V))\n",
|
||||
" return GRADE_TO_V[rounded]\n",
|
||||
"\n",
|
||||
"def grade_token(display_difficulty: float) -> str:\n",
|
||||
" \"\"\"Return the grade-conditioning token for a display difficulty value.\"\"\"\n",
|
||||
" return f\"<GRADE_V{to_grouped_v(display_difficulty)}>\"\n",
|
||||
"\n",
|
||||
"# Transformer encoder used as a continuous grade regressor.\n",
|
||||
"class JointRouteTransformerRegressor(nn.Module):\n",
|
||||
" \"\"\"Transformer encoder for joint TB2/Kilter route difficulty prediction.\n",
|
||||
"\n",
|
||||
" Inputs are token IDs plus an attention mask. Token, position, and learned\n",
|
||||
" projections of coordinate metadata are added before the encoder. The first\n",
|
||||
" ``<CLS>`` position is then used as a pooled route representation for scalar\n",
|
||||
" difficulty regression.\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" vocab_size: int,\n",
|
||||
" max_len: int,\n",
|
||||
" coord_features: torch.Tensor,\n",
|
||||
" d_model: int = 128,\n",
|
||||
" nhead: int = 4,\n",
|
||||
" num_layers: int = 4,\n",
|
||||
" dim_feedforward: int = 256,\n",
|
||||
" dropout: float = 0.10,\n",
|
||||
" pad_id: int = 0,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Create the encoder, coordinate projection, and regression head.\"\"\"\n",
|
||||
" super().__init__()\n",
|
||||
" self.vocab_size = vocab_size\n",
|
||||
" self.max_len = max_len\n",
|
||||
" self.d_model = d_model\n",
|
||||
" self.pad_id = pad_id\n",
|
||||
"\n",
|
||||
" self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)\n",
|
||||
" self.pos_emb = nn.Embedding(max_len, d_model)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\"coord_features\", coord_features.clone().float())\n",
|
||||
" self.coord_proj = nn.Linear(coord_features.shape[1], d_model)\n",
|
||||
"\n",
|
||||
" encoder_layer = nn.TransformerEncoderLayer(\n",
|
||||
" d_model=d_model,\n",
|
||||
" nhead=nhead,\n",
|
||||
" dim_feedforward=dim_feedforward,\n",
|
||||
" dropout=dropout,\n",
|
||||
" activation=\"gelu\",\n",
|
||||
" batch_first=True,\n",
|
||||
" norm_first=True,\n",
|
||||
" )\n",
|
||||
" self.encoder = nn.TransformerEncoder(\n",
|
||||
" encoder_layer,\n",
|
||||
" num_layers=num_layers,\n",
|
||||
" enable_nested_tensor=False,\n",
|
||||
" )\n",
|
||||
" self.norm = nn.LayerNorm(d_model)\n",
|
||||
" self.head = nn.Sequential(\n",
|
||||
" nn.Linear(d_model, d_model),\n",
|
||||
" nn.GELU(),\n",
|
||||
" nn.Dropout(dropout),\n",
|
||||
" nn.Linear(d_model, 1),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:\n",
|
||||
" \"\"\"Return one continuous difficulty prediction per input sequence.\"\"\"\n",
|
||||
" batch_size, seq_len = input_ids.shape\n",
|
||||
" positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_len)\n",
|
||||
"\n",
|
||||
" # Coordinate features are indexed by token ID, so every occurrence of a\n",
|
||||
" # hold token gets the same physical x/y hint wherever it appears.\n",
|
||||
" x = self.token_emb(input_ids) + self.pos_emb(positions)\n",
|
||||
" x = x + self.coord_proj(self.coord_features[input_ids])\n",
|
||||
"\n",
|
||||
" key_padding_mask = ~attention_mask.bool()\n",
|
||||
" h = self.encoder(x, src_key_padding_mask=key_padding_mask)\n",
|
||||
" h = self.norm(h)\n",
|
||||
"\n",
|
||||
" cls_state = h[:, 0, :]\n",
|
||||
" return self.head(cls_state).squeeze(-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4455557a",
|
||||
@@ -259,16 +703,21 @@
|
||||
"source": [
|
||||
"## Grade consistency (using the trained critic)\n",
|
||||
"\n",
|
||||
"If we have a trained grade predictor (from notebook 02), we can use it as a **critic** to check whether generated routes have grades consistent with what was requested.\n",
|
||||
"\n",
|
||||
"This is similar to how GANs use a discriminator to evaluate generated samples, except our critic is a regression model rather than a binary classifier."
|
||||
"If we have a trained grade predictor (from notebook 02), we can use it as a **critic** to check whether generated routes have grades consistent with what was requested.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88747d6e",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:42.795099Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:42.794788Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:43.323348Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:43.321923Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Try to load the grade critic from notebook 02\n",
|
||||
@@ -355,7 +804,14 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "critic_eval",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:43.327454Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:43.326834Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.473105Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.472309Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Apply the critic to evaluate grade consistency\n",
|
||||
@@ -390,14 +846,22 @@
|
||||
"- **Basic validity** (required): At least 3 holds, start/finish, no duplicates, one board\n",
|
||||
"- **Strict validity** (bonus): Also has middle holds and 4+ holds\n",
|
||||
"- **Novelty** (higher is better): Distance from nearest real route\n",
|
||||
"- **Grade consistency** (if critic available): Predicted grade close to requested grade"
|
||||
"- **Grade consistency** (if critic available): Predicted grade close to requested grade\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "88747d6e2",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:44.476183Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:44.475814Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.489525Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.488845Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Rank candidates by composite score\n",
|
||||
@@ -427,14 +891,22 @@
|
||||
"source": [
|
||||
"## Save evaluation results\n",
|
||||
"\n",
|
||||
"We save the full evaluation DataFrame and the top candidates for further analysis."
|
||||
"We save the full evaluation DataFrame and the top candidates for further analysis.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "save_results",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2026-06-07T23:46:44.492676Z",
|
||||
"iopub.status.busy": "2026-06-07T23:46:44.492218Z",
|
||||
"iopub.status.idle": "2026-06-07T23:46:44.561651Z",
|
||||
"shell.execute_reply": "2026-06-07T23:46:44.560880Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save evaluation results\n",
|
||||
@@ -485,7 +957,8 @@
|
||||
"\n",
|
||||
"- Validity checks are structural, not semantic. A route might have valid start/finish holds but still be impossible.\n",
|
||||
"- Geometric features are simple. More sophisticated analysis could check reachability and move sequences.\n",
|
||||
"- The critic model was trained on real data, so it may not generalize well to novel route structures."
|
||||
"- The critic model was trained on real data, so it may not generalize well to novel route structures.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -505,7 +978,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.14.4"
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user