Next version. Models + scripts updated. 2
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,4 +1,6 @@
|
||||
data/*
|
||||
models/*
|
||||
images/*
|
||||
notebooks/*_executed.ipynb
|
||||
notebooks/*_executed.ipynb
|
||||
src/climbingboardgpt/__pycache__
|
||||
outputs/
|
||||
22
LICENSE
22
LICENSE
@@ -1,21 +1,9 @@
|
||||
MIT License
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2026
|
||||
Copyright © 2026 Pawel Sarkowicz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
713
README.md
713
README.md
@@ -1,25 +1,25 @@
|
||||
# ClimbingBoardGPT
|
||||
|
||||
**Applying LLM-style transformer techniques to climbing board route generation and grade prediction.**
|
||||
**ClimbingBoardGPT** is a unified transformer-style modeling project for climbing-board routes on:
|
||||
|
||||
This project treats climbing routes as language and trains transformer models on data from the **Tension Board 2 Mirror** and **Kilter Board Original** — learning to predict grades and generate entirely new routes.
|
||||
- **Tension Board 2 Mirror**
|
||||
- **Kilter Board Original**
|
||||
|
||||
The project treats climbing-board problems as symbolic sequences of board-aware hold-role tokens. It supports:
|
||||
|
||||
1. joint route tokenization for TB2 and Kilter,
|
||||
2. transformer-based grade prediction,
|
||||
3. GPT-style route generation conditioned on board, wall angle, and target grade,
|
||||
4. calibrated board-background visualization,
|
||||
5. command-line demo scripts for generation and grade prediction.
|
||||
|
||||
This repo is the transformer/GPT follow-up project to [Tension-Board-2-Analysis] and [Kilter-Board-Analysis].
|
||||
|
||||
---
|
||||
|
||||
## The Core Idea
|
||||
## Core idea
|
||||
|
||||
Large language models process text as sequences of tokens and learn statistical patterns from billions of examples. Climbing routes have the same structure:
|
||||
|
||||
| NLP Concept | Climbing Analog |
|
||||
|---|---|
|
||||
| Word / Subword | Hold token (`TB2_p344_start`) |
|
||||
| Sentence | Route (sequence of holds) |
|
||||
| Document language | Board type (TB2 vs Kilter) |
|
||||
| POS tag | Semantic role (start / middle / finish / foot) |
|
||||
| Genre / Domain | Angle + Grade conditioning |
|
||||
| Special tokens | `<BOS>`, `<EOS>`, `<PAD>`, `<CLS>`, `<MASK>`, `<UNK>` |
|
||||
|
||||
A route becomes a symbolic sequence:
|
||||
A route is represented as a sequence like:
|
||||
|
||||
```text
|
||||
<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>
|
||||
@@ -27,295 +27,576 @@ A route becomes a symbolic sequence:
|
||||
<EOS>
|
||||
```
|
||||
|
||||
The same transformer architectures that power GPT and BERT can learn "climb grammar" — which holds tend to follow which, how start holds differ from finish holds, and how difficulty emerges from spatial relationships.
|
||||
or:
|
||||
|
||||
```text
|
||||
<BOS> <BOARD_KILTER> <ANGLE_40> <GRADE_V6>
|
||||
<KILTER_p1084_start> <KILTER_p1231_middle> <KILTER_p1395_finish>
|
||||
<EOS>
|
||||
```
|
||||
|
||||
Hold tokens are **board-namespaced**, so a TB2 placement ID and a Kilter placement ID never collide.
|
||||
|
||||
For grade prediction, the grade token is removed:
|
||||
|
||||
```text
|
||||
<CLS> <BOARD_TB2> <ANGLE_40>
|
||||
<TB2_p344_start> <TB2_p369_middle> <TB2_p603_finish>
|
||||
<EOS>
|
||||
```
|
||||
|
||||
The model then predicts the climb difficulty from the board, angle, and hold-role tokens.
|
||||
|
||||
|
||||
---
|
||||
|
||||
## What This Repo Does
|
||||
## Quantitative results from the executed notebooks
|
||||
|
||||
### 1. Tokenization (`01_tokenize_routes`)
|
||||
These numbers come from the executed four-notebook run included with the project. They should be treated as the current benchmark for this checkpoint/data snapshot; rerun the pipeline if the raw databases, tokenization, model sizes, or train/validation/test split change.
|
||||
|
||||
Converts raw SQLite data into tokenized sequences:
|
||||
### Dataset and tokenization scale
|
||||
|
||||
- Parses `frames` strings (e.g., `p344r5p369r6p603r7`) into structured hold records
|
||||
- Maps board-specific role IDs to shared semantic roles (TB2: 5/6/7/8 → Kilter: 12/13/14/15 → `start`/`middle`/`finish`/`foot`)
|
||||
- Sorts holds canonically by (role priority, y-position, x-position)
|
||||
- Generates two sequence versions:
|
||||
- **With grade** — for GPT generation training
|
||||
- **Without grade** — for BERT-style grade prediction
|
||||
- Builds a shared vocabulary (~4,400 tokens), stratified train/val/test splits, and coordinate metadata
|
||||
The unified tokenizer builds one shared corpus across TB2 and Kilter.
|
||||
|
||||
### 2. Grade Prediction (`02_train_grade_predictor`)
|
||||
| Quantity | Value |
|
||||
|---|---:|
|
||||
| Total route/angle entries | 321,085 |
|
||||
| TB2 entries | 42,596 |
|
||||
| Kilter entries | 278,489 |
|
||||
| Placement metadata rows | 1,139 |
|
||||
| Shared vocabulary size | 4,438 tokens |
|
||||
| Special tokens | 6 |
|
||||
| Board tokens | 2 |
|
||||
| Angle tokens | 12 |
|
||||
| Grade tokens | 16 |
|
||||
| Hold-role tokens | 4,402 |
|
||||
| Grade-predictor max sequence length | 398 |
|
||||
| GPT-generator max sequence length | 399 |
|
||||
|
||||
Trains a **transformer encoder** (BERT-style) to predict climb difficulty:
|
||||
The train/validation/test split used in the executed notebooks was:
|
||||
|
||||
- Input: `<CLS> <BOARD_TB2> <ANGLE_40> <TB2_p344_start> ...` (grade excluded)
|
||||
- Output: Single difficulty score (regression)
|
||||
- Coordinate features (x, y, is_hold) are projected and added to token embeddings
|
||||
- Joint training across both boards with board-conditioning tokens
|
||||
| Board | Train | Validation | Test |
|
||||
|---|---:|---:|---:|
|
||||
| TB2 | 33,719 | 4,430 | 4,447 |
|
||||
| Kilter | 223,112 | 27,555 | 27,822 |
|
||||
| **Total** | **256,831** | **31,985** | **32,269** |
|
||||
|
||||
**Results (joint model, test set):**
|
||||
### Grade prediction performance
|
||||
|
||||
The grade predictor is a transformer encoder trained jointly on both boards. It receives board, angle, hold-role tokens, and coordinate features, but **does not receive the grade token**.
|
||||
|
||||
| Metric | Overall | TB2 | Kilter |
|
||||
|---|---|---|---|
|
||||
| MAE | 1.47 | 1.42 | 1.48 |
|
||||
| R² | 0.787 | 0.816 | 0.782 |
|
||||
| Within ±1 V-grade | 80.1% | 81.3% | 80.0% |
|
||||
| Within ±2 V-grades | 95.3% | 96.1% | 95.2% |
|
||||
|---|---:|---:|---:|
|
||||
| MAE | 1.481 | 1.420 | 1.490 |
|
||||
| RMSE | 1.941 | 1.845 | 1.956 |
|
||||
| R² | 0.768 | 0.800 | 0.763 |
|
||||
| Exact grouped V-grade | 36.0% | 37.3% | 35.8% |
|
||||
| Within ±1 V-grade | 79.3% | 80.0% | 79.2% |
|
||||
| Within ±2 V-grades | 94.8% | 95.5% | 94.7% |
|
||||
|
||||
### 3. Route Generation (`03_train_route_generator`)
|
||||
The model has about **1.17M parameters**. In the executed run, early stopping selected epoch 8 with validation MAE ≈ **1.480**.
|
||||
|
||||
Trains a **GPT-style causal transformer** to generate new routes:
|
||||
### Route generator training
|
||||
|
||||
- Input prompt: `<BOS> <BOARD_TB2> <ANGLE_40> <GRADE_V6>`
|
||||
- Output: Sequence of hold tokens ending with `<EOS>`
|
||||
- Uses causal masking (each position attends only to previous positions)
|
||||
- Generation uses temperature sampling and top-k filtering
|
||||
The route generator is a GPT-style causal transformer trained on grade-conditioned route sequences.
|
||||
|
||||
**Training results:**
|
||||
- Best validation perplexity: ~24.6
|
||||
- 88.8% basic validity rate for generated routes
|
||||
| Quantity | Value |
|
||||
|---|---:|
|
||||
| Model size | ~1.41M parameters |
|
||||
| Best validation loss | 3.187 |
|
||||
| Best validation perplexity | 24.2 |
|
||||
| Evaluation sample size | 400 generated routes |
|
||||
| Overall basic validity | 91.5% |
|
||||
| Overall strict validity | 91.5% |
|
||||
|
||||
### 4. Evaluation (`04_evaluate_generated_routes`)
|
||||
During the generator evaluation run, routes were sampled across both boards, common angles, and target grades V1–V8.
|
||||
|
||||
Evaluates generated routes on four dimensions:
|
||||
### Generated-route evaluation
|
||||
|
||||
- **Validity**: Structural correctness (start/finish holds, no duplicates, single board)
|
||||
- **Novelty**: Jaccard distance from nearest real route
|
||||
- **Geometric plausibility**: Height, width, reach distances
|
||||
- **Grade consistency**: Uses the trained grade predictor as a critic
|
||||
|
||||
**Evaluation results:**
|
||||
Generated routes are evaluated by structural validity, novelty against real climbs, geometric features, and grade consistency using the trained grade predictor as a critic.
|
||||
|
||||
| Metric | TB2 | Kilter |
|
||||
|---|---|---|
|
||||
| Basic validity | 87.0% | 90.5% |
|
||||
| Mean novelty distance | 0.661 | 0.642 |
|
||||
| Exact V-grade match | 27.5% | 33.5% |
|
||||
| Within ±1 V-grade | 66.0% | 67.5% |
|
||||
| Within ±2 V-grades | 91.0% | 90.0% |
|
||||
|---|---:|---:|
|
||||
| Generated routes evaluated | 200 | 200 |
|
||||
| Basic validity | 89.0% | 94.0% |
|
||||
| Strict validity | 89.0% | 94.0% |
|
||||
| Mean novelty distance | 0.656 | 0.634 |
|
||||
| Median novelty distance | 0.667 | 0.652 |
|
||||
| Mean generated hold count | 11.11 | 12.90 |
|
||||
| Mean route height | 130.76 | 142.32 |
|
||||
| Mean route width | 61.66 | 74.94 |
|
||||
| Mean hand-reach distance | 50.41 | 57.53 |
|
||||
|
||||
Grade consistency of generated climbs, measured by the trained grade predictor:
|
||||
|
||||
| Metric | Overall | TB2 | Kilter |
|
||||
|---|---:|---:|---:|
|
||||
| Exact requested V-grade | 28.2% | 29.5% | 27.0% |
|
||||
| Within ±1 V-grade | 70.8% | 68.5% | 73.0% |
|
||||
| Within ±2 V-grades | 92.0% | 90.5% | 93.5% |
|
||||
| Mean V-grade error | -- | -0.18 | -0.30 |
|
||||
|
||||
Interpretation: the generator is usually structurally valid and usually close to the requested grade according to the critic, but exact grade control remains imperfect. That is expected: this is a small GPT-style model trained on symbolic route data, not a production setter.
|
||||
|
||||
---
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### Board Namespacing
|
||||
|
||||
Hold tokens include the board prefix (`TB2_p344_start` vs `KILTER_p1084_start`). Placement 344 on TB2 is a completely different physical hold than placement 344 on Kilter — the prefix prevents ID collisions.
|
||||
|
||||
### Semantic Role Mapping
|
||||
|
||||
Different boards use different numeric role IDs, but they all map to the same semantic roles:
|
||||
|
||||
| Role | TB2 | Kilter |
|
||||
|---|---|---|
|
||||
| Start | 5 | 12 |
|
||||
| Middle | 6 | 13 |
|
||||
| Finish | 7 | 14 |
|
||||
| Foot | 8 | 15 |
|
||||
|
||||
This shared vocabulary lets the model learn transferable patterns across boards.
|
||||
|
||||
### Coordinate Features
|
||||
|
||||
Each hold token carries physical (x, y) position information that gets projected and added to token embeddings. This gives the model direct spatial knowledge — similar to how some vision-language models inject spatial features.
|
||||
|
||||
### Conditioning Tokens
|
||||
|
||||
Routes are prefixed with board, angle, and grade tokens. This is analogous to how modern LLMs use system prompts to condition generation.
|
||||
|
||||
---
|
||||
|
||||
## Repository Structure
|
||||
## Repository layout
|
||||
|
||||
```text
|
||||
ClimbingBoardGPT/
|
||||
├── configs/
|
||||
│ ├── tb2.json # Tension Board 2 configuration
|
||||
│ └── kilter.json # Kilter Board configuration
|
||||
│ ├── tb2.json
|
||||
│ └── kilter.json
|
||||
├── data/
|
||||
│ ├── raw/ # SQLite databases (not in repo)
|
||||
│ ├── raw/
|
||||
│ │ ├── tb2.db
|
||||
│ │ └── kilter.db
|
||||
│ └── processed/
|
||||
│ ├── tokenized/ # Tokenized route data
|
||||
│ ├── grade_prediction/ # Grade predictor outputs
|
||||
│ ├── generation/ # Generated route data
|
||||
│ └── evaluation/ # Evaluation results
|
||||
├── models/ # Saved model checkpoints
|
||||
├── images/
|
||||
│ ├── tb2_board_12x12_composite.png
|
||||
│ └── kilter-original-16x12_compose.png
|
||||
├── models/
|
||||
│ ├── joint_transformer_grade_predictor.pth
|
||||
│ └── joint_route_gpt_generator.pth
|
||||
├── notebooks/
|
||||
│ ├── 01_unified_route_tokenization.ipynb
|
||||
│ ├── 02_joint_transformer_grade_prediction.ipynb
|
||||
│ ├── 03_joint_nanogpt_route_generator.ipynb
|
||||
│ └── 04_generated_route_evaluation.ipynb
|
||||
├── scripts/
|
||||
│ ├── 01_tokenize_routes.py
|
||||
│ ├── 02_train_grade_predictor.py
|
||||
│ ├── 03_train_route_generator.py
|
||||
│ └── 04_evaluate_generated_routes.py
|
||||
│ ├── 04_evaluate_generated_routes.py
|
||||
│ ├── demo_generate_and_visualize.py
|
||||
│ ├── demo_generate_tb2.py
|
||||
│ ├── demo_generate_kilter.py
|
||||
│ ├── demo_predict_grade.py
|
||||
│ ├── demo_predict_tb2.py
|
||||
│ └── demo_predict_kilter.py
|
||||
├── src/climbingboardgpt/
|
||||
│ ├── __init__.py
|
||||
│ ├── config.py # Board configuration loading
|
||||
│ ├── data.py # SQLite data loading
|
||||
│ ├── datasets.py # PyTorch Dataset classes
|
||||
│ ├── evaluation.py # Route evaluation functions
|
||||
│ ├── generation.py # Route generation logic
|
||||
│ ├── grades.py # Grade-to-V mapping
|
||||
│ ├── metrics.py # Evaluation metrics
|
||||
│ ├── models.py # Transformer architectures
|
||||
│ ├── paths.py # Project root detection
|
||||
│ ├── tokenization.py # Core tokenization logic
|
||||
│ └── utils.py # Utility functions
|
||||
├── README.md
|
||||
├── requirements.txt
|
||||
├── pyproject.toml
|
||||
└── README.md
|
||||
└── pyproject.toml
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Setup
|
||||
|
||||
Create and activate a virtual environment:
|
||||
|
||||
```bash
|
||||
# Clone the repo
|
||||
git clone https://github.com/yourusername/ClimbingBoardGPT.git
|
||||
cd ClimbingBoardGPT
|
||||
|
||||
# Create and activate virtual environment
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # Linux/Mac
|
||||
# .venv\Scripts\activate # Windows
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
# Install dependencies
|
||||
Install the package:
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Retrieving Raw Databases
|
||||
|
||||
The project expects SQLite databases at `data/raw/tb2.db` and `data/raw/kilter.db`.
|
||||
|
||||
Using [BoardLib](https://github.com/lemeryf/BoardLib):
|
||||
For CPU-only demo use on a small VPS, the scripts support:
|
||||
|
||||
```bash
|
||||
pip install boardlib
|
||||
boardlib database tension data/raw/tb2.db
|
||||
boardlib database kilter data/raw/kilter.db
|
||||
--torch-threads 1
|
||||
```
|
||||
|
||||
This caps PyTorch CPU thread usage.
|
||||
|
||||
---
|
||||
|
||||
## Running the Pipeline
|
||||
## Data expected by the full training pipeline
|
||||
|
||||
### 1. Tokenize both boards
|
||||
The full tokenization/training pipeline expects raw BoardLib databases at:
|
||||
|
||||
```text
|
||||
data/raw/tb2.db
|
||||
data/raw/kilter.db
|
||||
```
|
||||
|
||||
The project configs are:
|
||||
|
||||
```text
|
||||
configs/tb2.json
|
||||
configs/kilter.json
|
||||
```
|
||||
|
||||
They define board-specific details such as:
|
||||
|
||||
- database path,
|
||||
- layout ID,
|
||||
- role IDs,
|
||||
- token prefix,
|
||||
- angle cutoff,
|
||||
- optional date / placement filters.
|
||||
|
||||
The demo scripts do **not** need the raw databases if the processed tokenization artifacts and trained model checkpoints already exist.
|
||||
|
||||
---
|
||||
|
||||
## Full training pipeline
|
||||
|
||||
From the repository root:
|
||||
|
||||
```bash
|
||||
python scripts/01_tokenize_routes.py --boards tb2,kilter
|
||||
```
|
||||
|
||||
Creates `data/processed/tokenized/` with vocabulary, route sequences, and metadata.
|
||||
|
||||
### 2. Train the grade predictor
|
||||
|
||||
```bash
|
||||
python scripts/02_train_grade_predictor.py
|
||||
```
|
||||
|
||||
Trains a BERT-style transformer encoder and saves to `models/joint_transformer_grade_predictor.pth`.
|
||||
|
||||
### 3. Train the route generator
|
||||
|
||||
```bash
|
||||
python scripts/03_train_route_generator.py
|
||||
```
|
||||
|
||||
Trains a GPT-style causal transformer and saves to `models/joint_route_gpt_generator.pth`.
|
||||
|
||||
### 4. Evaluate generated routes
|
||||
|
||||
```bash
|
||||
python scripts/04_evaluate_generated_routes.py
|
||||
```
|
||||
|
||||
Evaluates validity, novelty, geometry, and grade consistency. Saves results to `data/processed/evaluation/`.
|
||||
This produces the main processed artifacts and trained checkpoints.
|
||||
|
||||
---
|
||||
### Tokenization outputs
|
||||
|
||||
## Model Architectures
|
||||
|
||||
### JointRouteTransformerRegressor (Grade Prediction)
|
||||
|
||||
```
|
||||
Input: [CLS] BOARD ANGLE HOLDS...
|
||||
↓
|
||||
Token Embedding + Position Embedding + Coordinate Features
|
||||
↓
|
||||
Transformer Encoder (4 layers, 4 heads, d_model=128)
|
||||
↓
|
||||
[CLS] token output → Regression Head → difficulty score
|
||||
```text
|
||||
data/processed/tokenized/
|
||||
├── route_sequences.csv
|
||||
├── routes_tokenized.jsonl
|
||||
├── token_vocab.json
|
||||
├── token_metadata.csv
|
||||
├── placement_metadata.csv
|
||||
└── board_summary.csv
|
||||
```
|
||||
|
||||
- ~1.17M parameters
|
||||
- MSE loss, AdamW optimizer
|
||||
- Early stopping on validation MAE
|
||||
### Grade-prediction outputs
|
||||
|
||||
### JointRouteGPT (Route Generation)
|
||||
```text
|
||||
data/processed/grade_prediction/
|
||||
├── training_history.csv
|
||||
├── test_predictions.csv
|
||||
├── board_metrics.csv
|
||||
└── overall_metrics.json
|
||||
|
||||
```
|
||||
Input: BOS BOARD ANGLE GRADE HOLDS...
|
||||
↓
|
||||
Token Embedding + Position Embedding
|
||||
↓
|
||||
Causal Transformer (4 layers, 4 heads, d_embd=128)
|
||||
↓
|
||||
Language Modeling Head → next token logits
|
||||
models/
|
||||
└── joint_transformer_grade_predictor.pth
|
||||
```
|
||||
|
||||
- ~1.41M parameters
|
||||
- Cross-entropy loss, AdamW optimizer
|
||||
- Weight tying between embedding and output layers
|
||||
### Route-generation outputs
|
||||
|
||||
```text
|
||||
data/processed/generation/
|
||||
├── training_history.csv
|
||||
└── generated_routes.csv
|
||||
|
||||
models/
|
||||
└── joint_route_gpt_generator.pth
|
||||
```
|
||||
|
||||
### Generated-route evaluation outputs
|
||||
|
||||
```text
|
||||
data/processed/evaluation/
|
||||
├── generated_route_evaluation.csv
|
||||
└── top_generated_candidates.csv
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Board Configuration
|
||||
## Generate routes and visualize them
|
||||
|
||||
| Setting | TB2 Mirror | Kilter Original |
|
||||
|---|---:|---|
|
||||
| `layout_id` | 10 | 1 |
|
||||
| `token_prefix` | TB2 | KILTER |
|
||||
| `max_angle` | 50 | 55 |
|
||||
| `role_definitions` | start=5, middle=6, finish=7, foot=8 | start=12, middle=13, finish=14, foot=15 |
|
||||
| `include_mirror_placement_id` | true | false |
|
||||
| `min_fa_date` | null | 2016-01-01 |
|
||||
After training the route generator, or after placing a trained checkpoint at:
|
||||
|
||||
To add a new board, create a JSON config in `configs/` following the same format.
|
||||
```text
|
||||
models/joint_route_gpt_generator.pth
|
||||
```
|
||||
|
||||
you can generate and visualize climbs.
|
||||
|
||||
### TB2
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_tb2.py --angle 40 --grade 6 --n 4
|
||||
```
|
||||
|
||||
### Kilter
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_kilter.py --angle 40 --grade 6 --n 4
|
||||
```
|
||||
|
||||
### Generic version
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_and_visualize.py \
|
||||
--board tb2 \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--n 4 \
|
||||
--temperature 0.9 \
|
||||
--top-k 50
|
||||
```
|
||||
|
||||
Outputs are written to:
|
||||
|
||||
```text
|
||||
outputs/demo_routes/<board>/angle_<angle>/V<grade>/
|
||||
├── generated_routes.csv
|
||||
├── generated_route_001.png
|
||||
├── generated_route_001.svg
|
||||
├── generated_route_002.png
|
||||
├── generated_route_002.svg
|
||||
└── ...
|
||||
```
|
||||
|
||||
### Generated-route visualization
|
||||
|
||||
The visualization uses calibrated board backgrounds:
|
||||
|
||||
```text
|
||||
images/tb2_board_12x12_composite.png
|
||||
images/kilter-original-16x12_compose.png
|
||||
```
|
||||
|
||||
These are overlaid using product-size coordinate windows:
|
||||
|
||||
```text
|
||||
TB2: x = [-68, 68], y = [0, 144]
|
||||
Kilter: x = [-24, 168], y = [0, 156]
|
||||
```
|
||||
|
||||
These extents match the old visualization notebooks better than simply using the min/max of observed hold coordinates, because the hold coordinates are inset from the product boundary.
|
||||
|
||||
The role markers are:
|
||||
|
||||
| Role | Marker |
|
||||
|---|---|
|
||||
| start | green circle |
|
||||
| middle | blue circle |
|
||||
| finish | red star |
|
||||
| foot | small yellow square |
|
||||
|
||||
|
||||
### Annotate holds
|
||||
|
||||
To label route holds by placement ID:
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_tb2.py \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--n 2 \
|
||||
--annotate
|
||||
```
|
||||
|
||||
### CPU/VPS-friendly run
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_tb2.py \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--n 2 \
|
||||
--torch-threads 1
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Comparison with Classical Approach
|
||||
## Temperature and sampling
|
||||
|
||||
The earlier TB2 project used hand-engineered features with Random Forest and neural networks. This project replaces feature engineering with transformer attention:
|
||||
The `--temperature` argument controls generation randomness.
|
||||
|
||||
| Aspect | Classical (TB2 Notebooks 01-06) | Transformer (This Project) |
|
||||
|---|---|---|
|
||||
| Input | 30+ engineered features | Raw token sequences |
|
||||
| Feature engineering | Manual (spatial, geometric) | Learned via attention |
|
||||
| Board handling | Single board (TB2) | Joint model with board token |
|
||||
| Grade prediction | Random Forest / MLP | Transformer encoder |
|
||||
| Route generation | Not supported | GPT-style decoder |
|
||||
| Interpretability | Feature importance | Attention weights |
|
||||
The model predicts probabilities for the next token. Temperature rescales those probabilities before sampling.
|
||||
|
||||
| Temperature | Effect |
|
||||
|---:|---|
|
||||
| `0.3`–`0.6` | conservative; picks safer/common tokens |
|
||||
| `0.9` | balanced default |
|
||||
| `1.0` | samples directly from the learned probabilities |
|
||||
| `1.1`–`1.3` | more exploratory; can produce weirder climbs |
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_kilter.py \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--n 4 \
|
||||
--temperature 0.6
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Future Extensions
|
||||
## Predict grade from board, angle, and frames string
|
||||
|
||||
- **Masked hold prediction**: Mask holds and predict them (like BERT's MLM)
|
||||
- **Stronger legality constraints**: Enforce valid start/finish positions in generation
|
||||
- **Board transfer experiments**: Train on TB2, evaluate on Kilter (zero-shot)
|
||||
- **GUI for route generation**: Interactive tool to generate and visualize climbs
|
||||
- **Integration with classical features**: Combine transformer embeddings with engineered features
|
||||
After training the grade predictor, or after placing a trained checkpoint at:
|
||||
|
||||
```text
|
||||
models/joint_transformer_grade_predictor.pth
|
||||
```
|
||||
|
||||
you can predict a grade directly from a BoardLib-style frames string.
|
||||
|
||||
### Generic
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_grade.py \
|
||||
--board tb2 \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7'
|
||||
```
|
||||
|
||||
### TB2 wrapper
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_tb2.py \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7'
|
||||
```
|
||||
|
||||
### Kilter wrapper
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_kilter.py \
|
||||
--angle 40 \
|
||||
--frames 'p1127r12p1196r13p1216r13p1388r14'
|
||||
```
|
||||
|
||||
Example output:
|
||||
|
||||
```text
|
||||
Board: Tension Board 2 Mirror (tb2)
|
||||
Angle: 40°
|
||||
Frames: p652r5p631r6p322r6p326r7
|
||||
Predicted: V6
|
||||
Difficulty: 22.400
|
||||
```
|
||||
|
||||
The `Predicted` line is the grouped V-grade. The `Difficulty` line is the model's continuous prediction in the underlying BoardLib difficulty scale.
|
||||
|
||||
### JSON output
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_grade.py \
|
||||
--board kilter \
|
||||
--angle 40 \
|
||||
--frames 'p1127r12p1196r13p1216r13p1388r14' \
|
||||
--json
|
||||
```
|
||||
|
||||
### Show model tokens
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_tb2.py \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7' \
|
||||
--show-tokens
|
||||
```
|
||||
|
||||
### Save a visualization of the input climb
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_tb2.py \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7' \
|
||||
--visualize
|
||||
```
|
||||
|
||||
This writes:
|
||||
|
||||
```text
|
||||
outputs/grade_predictions/<board>/angle_<angle>/
|
||||
├── <name>.png
|
||||
├── <name>.svg
|
||||
└── <name>.json
|
||||
```
|
||||
|
||||
Example with custom output name:
|
||||
|
||||
```bash
|
||||
python scripts/demo_predict_kilter.py \
|
||||
--angle 40 \
|
||||
--frames 'p1127r12p1196r13p1216r13p1388r14' \
|
||||
--visualize \
|
||||
--output-name my_kilter_climb
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Acknowledgments
|
||||
## Grade prediction in generated-route visualizations
|
||||
|
||||
- Board data from [Tension Climbing](https://tensionclimbing.com/) and [Kilter Board](https://kilterboard.com/)
|
||||
- Database access via [BoardLib](https://github.com/lemeryf/BoardLib)
|
||||
- Original TB2 analysis notebooks for foundational data exploration
|
||||
If both checkpoints exist:
|
||||
|
||||
```text
|
||||
models/joint_route_gpt_generator.pth
|
||||
models/joint_transformer_grade_predictor.pth
|
||||
```
|
||||
|
||||
then the generation demo automatically scores each generated climb with the grade predictor.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_tb2.py --angle 40 --grade 6 --n 4
|
||||
```
|
||||
|
||||
The terminal output includes something like:
|
||||
|
||||
```text
|
||||
predicted=V5 (difficulty=20.81, error=-1 V)
|
||||
```
|
||||
|
||||
The visualization subtitle also includes:
|
||||
|
||||
```text
|
||||
predicted V5 (20.81) | error -1V
|
||||
```
|
||||
|
||||
To disable this scoring:
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_tb2.py \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--n 4 \
|
||||
--no-grade-prediction
|
||||
```
|
||||
|
||||
To use a non-default grade predictor:
|
||||
|
||||
```bash
|
||||
python scripts/demo_generate_and_visualize.py \
|
||||
--board kilter \
|
||||
--angle 40 \
|
||||
--grade 6 \
|
||||
--grade-model-path models/joint_transformer_grade_predictor.pth
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Important caveats
|
||||
|
||||
Generated climbs are **machine-generated candidates**, not guaranteed to be safe, good, or fun.
|
||||
|
||||
The grade predictor is a model-based estimate, not ground truth. Climbing grades are noisy and subjective, and board climbs can be highly style-dependent.
|
||||
|
||||
The route sequence is a canonical ordering of holds, not necessarily actual beta order. This is fine for symbolic modeling, but it should not be interpreted as the intended movement sequence.
|
||||
|
||||
The visualizations are calibrated to match the existing board images, but any change in image file, crop, or coordinate convention may require adjusting board extents in:
|
||||
|
||||
```text
|
||||
src/climbingboardgpt/visualization.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Next step: webapp demo
|
||||
|
||||
The next planned layer is a simple webapp with:
|
||||
|
||||
1. grade prediction from board + angle + frames string,
|
||||
2. route generation from board + angle + target grade,
|
||||
3. rendered PNG output for both generated climbs and user-submitted climbs.
|
||||
|
||||
The webapp should use the same backend helpers already added here:
|
||||
|
||||
```text
|
||||
load_route_generator(...)
|
||||
generate_route(...)
|
||||
load_grade_predictor(...)
|
||||
predict_frames_grade(...)
|
||||
visualize_route_tokens(...)
|
||||
```
|
||||
|
||||
# License
|
||||
This project is licensed under the MIT License. See the [`LICENSE`](LICENSE) file for details.
|
||||
|
||||
The project is for educational purposes. Climb data belongs to Tension Climbing and Kilter respectively.
|
||||
@@ -19,7 +19,7 @@
|
||||
"notes": [
|
||||
"Matches the Kilter Original layout used in the earlier Kilter analysis.",
|
||||
"The modeling cutoff uses wall angles <= 55 degrees.",
|
||||
"The first-ascent date filter keeps rows after 2016-01-01.",
|
||||
"Placement metadata is restricted to y <= 156, following the earlier notebook pipeline."
|
||||
"First-ascent dates are filtered to after 2016-01-01.",
|
||||
"Placement metadata is restricted to y <= 156."
|
||||
]
|
||||
}
|
||||
@@ -80,7 +80,7 @@
|
||||
" make_placement_lookup,\n",
|
||||
" vocab_payload,\n",
|
||||
")\n",
|
||||
"from climbingboardgpt.utils import safe_train_test_split, write_json, json_safe"
|
||||
"from climbingboardgpt.utils import assign_group_splits, write_json, json_safe"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -314,30 +314,17 @@
|
||||
"source": [
|
||||
"df_routes[\"ids_with_grade\"] = df_routes[\"tokens_with_grade\"].apply(lambda tokens: encode(tokens, stoi))\n",
|
||||
"df_routes[\"ids_no_grade\"] = df_routes[\"tokens_no_grade\"].apply(lambda tokens: encode(tokens, stoi))\n",
|
||||
"\n",
|
||||
"df_routes[\"split_stratum\"] = df_routes[\"board_key\"].astype(str) + \"__V\" + df_routes[\"grouped_v\"].astype(str)\n",
|
||||
"\n",
|
||||
"train_df, temp_df = safe_train_test_split(\n",
|
||||
"df_routes[\"split\"] = assign_group_splits(\n",
|
||||
" df_routes,\n",
|
||||
" group_cols=[\"board_key\", \"uuid\"],\n",
|
||||
" test_size=0.20,\n",
|
||||
" random_state=42,\n",
|
||||
" stratify_col=\"split_stratum\",\n",
|
||||
")\n",
|
||||
"val_df, test_df = safe_train_test_split(\n",
|
||||
" temp_df,\n",
|
||||
" test_size=0.50,\n",
|
||||
" random_state=42,\n",
|
||||
" val_size_within_temp=0.50,\n",
|
||||
" random_state=3,\n",
|
||||
" stratify_col=\"split_stratum\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"split_map = {}\n",
|
||||
"split_map.update({uuid: \"train\" for uuid in train_df[\"uuid\"]})\n",
|
||||
"split_map.update({uuid: \"val\" for uuid in val_df[\"uuid\"]})\n",
|
||||
"split_map.update({uuid: \"test\" for uuid in test_df[\"uuid\"]})\n",
|
||||
"df_routes[\"split\"] = df_routes[\"uuid\"].map(split_map)\n",
|
||||
"\n",
|
||||
"print(\"Split counts by board:\")\n",
|
||||
"print(df_routes.groupby([\"board_key\", \"split\"]).size().unstack(fill_value=0))"
|
||||
"df_routes.groupby([\"board_key\", \"split\"]).size().unstack(fill_value=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -457,8 +444,16 @@
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"version": "3.11"
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "climbingboardgpt"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
description = "Unified TB2/Kilter transformer route modeling, grade prediction, and GPT-style route generation."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -60,7 +60,7 @@ from climbingboardgpt.tokenization import (
|
||||
make_placement_lookup,
|
||||
vocab_payload,
|
||||
)
|
||||
from climbingboardgpt.utils import json_safe, safe_train_test_split, set_seed, write_json
|
||||
from climbingboardgpt.utils import assign_group_splits, json_safe, set_seed, write_json
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
@@ -101,8 +101,8 @@ Examples:
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Random seed for reproducible splits (default: 42)",
|
||||
default=3,
|
||||
help="Random seed for reproducible splits (default: 3)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -244,41 +244,33 @@ def main() -> None:
|
||||
df_routes["ids_with_grade"] = df_routes["tokens_with_grade"].apply(lambda tokens: encode(tokens, stoi))
|
||||
df_routes["ids_no_grade"] = df_routes["tokens_no_grade"].apply(lambda tokens: encode(tokens, stoi))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 6: Train/val/test split (stratified)
|
||||
# Step 6: Train/val/test split (grouped by logical climb)
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# We split 80/10/10, stratified by board_key × grouped_v.
|
||||
# This ensures both boards and all difficulty levels are represented
|
||||
# in each split, which is critical for fair evaluation.
|
||||
# A single climb UUID can appear at multiple wall angles. We therefore
|
||||
# split by (board_key, uuid), not by individual rows. This avoids putting
|
||||
# one angle of a climb in train and another angle of the same climb in test.
|
||||
#
|
||||
# Stratification prevents scenarios like "all V14 climbs end up in
|
||||
# the test set while training has none."
|
||||
# The split is stratified by board_key × grouped_v at the group level when
|
||||
# possible. The row proportions may differ slightly from 80/10/10 because
|
||||
# some climbs have more angle entries than others, but this is preferable
|
||||
# to route-level leakage or brittle UUID-overwrite logic.
|
||||
df_routes["split_stratum"] = (
|
||||
df_routes["board_key"].astype(str)
|
||||
+ "__V"
|
||||
+ df_routes["grouped_v"].astype(str)
|
||||
)
|
||||
|
||||
train_df, temp_df = safe_train_test_split(
|
||||
df_routes["split"] = assign_group_splits(
|
||||
df_routes,
|
||||
group_cols=["board_key", "uuid"],
|
||||
test_size=0.20,
|
||||
random_state=args.seed,
|
||||
stratify_col="split_stratum",
|
||||
)
|
||||
val_df, test_df = safe_train_test_split(
|
||||
temp_df,
|
||||
test_size=0.50,
|
||||
val_size_within_temp=0.50,
|
||||
random_state=args.seed,
|
||||
stratify_col="split_stratum",
|
||||
)
|
||||
|
||||
split_map = {}
|
||||
split_map.update({uuid: "train" for uuid in train_df["uuid"]})
|
||||
split_map.update({uuid: "val" for uuid in val_df["uuid"]})
|
||||
split_map.update({uuid: "test" for uuid in test_df["uuid"]})
|
||||
df_routes["split"] = df_routes["uuid"].map(split_map)
|
||||
|
||||
print(f"\nSplit counts:")
|
||||
print("\nSplit counts:")
|
||||
print(df_routes.groupby(["board_key", "split"]).size().unstack(fill_value=0))
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
@@ -357,4 +349,4 @@ def main() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@@ -99,7 +99,7 @@ accuracy (within ±1 V-grade).
|
||||
parser.add_argument("--num-layers", type=int, default=4, help="Number of transformer layers")
|
||||
parser.add_argument("--dim-feedforward", type=int, default=256, help="Feedforward dimension")
|
||||
parser.add_argument("--dropout", type=float, default=0.10, help="Dropout probability")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
parser.add_argument("--seed", type=int, default=3, help="Random seed")
|
||||
parser.add_argument("--device", type=str, default=None, help="Device (cpu or cuda)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ specific board, or leave unset to generate for all boards.
|
||||
parser.add_argument("--generate-board", type=str, default=None, help="Board key: tb2 or kilter")
|
||||
parser.add_argument("--generate-angles", type=str, default=None, help="Comma-separated angles")
|
||||
parser.add_argument("--generate-grades", type=str, default=None, help="Comma-separated V-grades")
|
||||
parser.add_argument("--seed", type=int, default=42, help="Random seed")
|
||||
parser.add_argument("--seed", type=int, default=3, help="Random seed")
|
||||
parser.add_argument("--device", type=str, default=None, help="Device (cpu or cuda)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -202,7 +202,15 @@ def main() -> None:
|
||||
lambda records: frozenset(int(record["placement_id"]) for record in records)
|
||||
)
|
||||
|
||||
validity = pd.DataFrame(df_generated["hold_records"].apply(validity_from_records).tolist())
|
||||
validity = pd.DataFrame(
|
||||
df_generated.apply(
|
||||
lambda row: validity_from_records(
|
||||
row["hold_records"],
|
||||
requested_board_prefix=row.get("requested_board_prefix"),
|
||||
),
|
||||
axis=1,
|
||||
).tolist()
|
||||
)
|
||||
df_eval = pd.concat([df_generated.reset_index(drop=True), validity], axis=1)
|
||||
|
||||
print(f"Evaluated generated routes: {len(df_eval):,}")
|
||||
|
||||
198
scripts/demo_generate_and_visualize.py
Normal file
198
scripts/demo_generate_and_visualize.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate ClimbingBoardGPT routes and save board visualizations.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Generate four TB2 V6 climbs at 40 degrees:
|
||||
|
||||
python scripts/demo_generate_and_visualize.py --board tb2 --angle 40 --grade 6 --n 4
|
||||
|
||||
Generate Kilter climbs with placement labels:
|
||||
|
||||
python scripts/demo_generate_and_visualize.py --board kilter --angle 35 --grade 5 --annotate
|
||||
|
||||
The script writes:
|
||||
- generated_routes.csv
|
||||
- generated_route_001.png
|
||||
- generated_route_001.svg
|
||||
- ...
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(REPO_ROOT / "src"))
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
from climbingboardgpt.inference import (
|
||||
generate_route,
|
||||
load_board_for_demo,
|
||||
load_grade_predictor,
|
||||
load_route_generator,
|
||||
predict_route_grade,
|
||||
)
|
||||
from climbingboardgpt.visualization import load_token_metadata, visualize_route_result
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate ClimbingBoardGPT routes and save route visualizations.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--board", choices=["tb2", "kilter"], required=True)
|
||||
parser.add_argument("--angle", type=int, default=40)
|
||||
parser.add_argument("--grade", type=int, default=6, help="Target grouped V-grade.")
|
||||
parser.add_argument("--n", type=int, default=4, help="Number of routes to sample.")
|
||||
parser.add_argument("--temperature", type=float, default=0.9)
|
||||
parser.add_argument("--top-k", type=int, default=50)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=40)
|
||||
parser.add_argument("--annotate", action="store_true", help="Label route holds by placement ID.")
|
||||
parser.add_argument("--device", type=str, default=None, help="cpu, cuda, or omit for auto.")
|
||||
parser.add_argument("--torch-threads", type=int, default=None, help="Optional CPU thread cap for VPS demos.")
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "models" / "joint_route_gpt_generator.pth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grade-model-path",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "models" / "joint_transformer_grade_predictor.pth",
|
||||
help="Optional grade-predictor checkpoint used to score generated routes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-grade-prediction",
|
||||
action="store_true",
|
||||
help="Skip grade-predictor scoring even if the checkpoint exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenized-dir",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "data" / "processed" / "tokenized",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "outputs" / "demo_routes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--background-image",
|
||||
type=Path,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional board image to draw under the scatter plot. "
|
||||
"If omitted, the script automatically uses images/tb2_board_12x12_composite.png "
|
||||
"for TB2 and images/kilter-original-16x12_compose.png for Kilter when present."
|
||||
),
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
def default_background_for_board(board: str) -> Path | None:
|
||||
candidates = {
|
||||
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
|
||||
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_compose.png",
|
||||
}
|
||||
path = candidates.get(board)
|
||||
return path if path is not None and path.exists() else None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
board_config = load_board_for_demo(args.board, config_dir=REPO_ROOT / "configs")
|
||||
generator = load_route_generator(args.model_path, device=args.device, torch_threads=args.torch_threads)
|
||||
token_meta = load_token_metadata(args.tokenized_dir)
|
||||
background_image = args.background_image or default_background_for_board(args.board)
|
||||
|
||||
grade_predictor = None
|
||||
if not args.no_grade_prediction:
|
||||
if args.grade_model_path.exists():
|
||||
grade_predictor = load_grade_predictor(
|
||||
args.grade_model_path,
|
||||
device=args.device,
|
||||
torch_threads=args.torch_threads,
|
||||
)
|
||||
else:
|
||||
print(f"Grade predictor not found at {args.grade_model_path}; skipping grade prediction.")
|
||||
|
||||
run_dir = args.out_dir / args.board / f"angle_{args.angle}" / f"V{args.grade}"
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rows = []
|
||||
for i in range(1, args.n + 1):
|
||||
result = generate_route(
|
||||
generator=generator,
|
||||
board_config=board_config,
|
||||
angle=args.angle,
|
||||
grade=args.grade,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
)
|
||||
|
||||
if grade_predictor is not None:
|
||||
grade_result = predict_route_grade(grade_predictor, result["tokens"])
|
||||
result.update(grade_result)
|
||||
result["critic_v_error"] = (
|
||||
int(result["predicted_grouped_v"]) - int(result["requested_grouped_v"])
|
||||
)
|
||||
|
||||
rows.append(result)
|
||||
|
||||
stem = f"generated_route_{i:03d}"
|
||||
png_path = run_dir / f"{stem}.png"
|
||||
svg_path = run_dir / f"{stem}.svg"
|
||||
|
||||
fig, _, _ = visualize_route_result(
|
||||
result,
|
||||
df_token_meta=token_meta,
|
||||
output_path=png_path,
|
||||
annotate=args.annotate,
|
||||
background_image=background_image,
|
||||
)
|
||||
fig.savefig(svg_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
print(f"[{i}/{args.n}] {result['frames']}")
|
||||
print(f" valid={result['basic_valid']} holds={result['n_hold_tokens']}")
|
||||
if "predicted_grouped_v" in result:
|
||||
print(
|
||||
f" predicted=V{result['predicted_grouped_v']} "
|
||||
f"(difficulty={result['predicted_display_difficulty']:.2f}, "
|
||||
f"error={result['critic_v_error']:+d} V)"
|
||||
)
|
||||
try:
|
||||
png_display = png_path.resolve().relative_to(REPO_ROOT.resolve())
|
||||
except Exception:
|
||||
png_display = png_path
|
||||
print(f" saved {png_display}")
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
df["tokens_json"] = df["tokens"].apply(json.dumps)
|
||||
df.drop(columns=["tokens"]).to_csv(run_dir / "generated_routes.csv", index=False)
|
||||
|
||||
if background_image is not None:
|
||||
try:
|
||||
bg_display = background_image.relative_to(REPO_ROOT)
|
||||
except Exception:
|
||||
bg_display = background_image
|
||||
print(f"Using background image: {bg_display}")
|
||||
else:
|
||||
print("Using background image: none (coordinate-board style only)")
|
||||
|
||||
print("\nSaved route table:")
|
||||
print(run_dir / "generated_routes.csv")
|
||||
print("\nOutput directory:")
|
||||
print(run_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
scripts/demo_generate_kilter.py
Normal file
19
scripts/demo_generate_kilter.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convenience wrapper: generate and visualize Kilter routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(REPO_ROOT / "scripts" / "demo_generate_and_visualize.py"),
|
||||
"--board",
|
||||
"kilter",
|
||||
*sys.argv[1:],
|
||||
]
|
||||
raise SystemExit(subprocess.call(cmd))
|
||||
19
scripts/demo_generate_tb2.py
Normal file
19
scripts/demo_generate_tb2.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convenience wrapper: generate and visualize TB2 routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(REPO_ROOT / "scripts" / "demo_generate_and_visualize.py"),
|
||||
"--board",
|
||||
"tb2",
|
||||
*sys.argv[1:],
|
||||
]
|
||||
raise SystemExit(subprocess.call(cmd))
|
||||
180
scripts/demo_predict_grade.py
Normal file
180
scripts/demo_predict_grade.py
Normal file
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Predict a climb grade from board, angle, and BoardLib frames string.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Generic:
|
||||
|
||||
python scripts/demo_predict_grade.py \
|
||||
--board tb2 \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7'
|
||||
|
||||
TB2 wrapper:
|
||||
|
||||
python scripts/demo_predict_tb2.py \
|
||||
--angle 40 \
|
||||
--frames 'p652r5p631r6p322r6p326r7'
|
||||
|
||||
Kilter wrapper:
|
||||
|
||||
python scripts/demo_predict_kilter.py \
|
||||
--angle 40 \
|
||||
--frames 'p1127r12p1196r13p1388r14'
|
||||
|
||||
Add ``--visualize`` to save a PNG/SVG overlay using the board background.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(REPO_ROOT / "src"))
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from climbingboardgpt.inference import (
|
||||
frames_to_grade_model_tokens,
|
||||
load_board_for_demo,
|
||||
load_grade_predictor,
|
||||
predict_frames_grade,
|
||||
)
|
||||
from climbingboardgpt.visualization import load_token_metadata, visualize_route_tokens
|
||||
|
||||
|
||||
def default_background_for_board(board: str) -> Path | None:
|
||||
candidates = {
|
||||
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
|
||||
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_compose.png",
|
||||
}
|
||||
path = candidates.get(board)
|
||||
return path if path is not None and path.exists() else None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Predict climb grade from board, angle, and frames string.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--board", choices=["tb2", "kilter"], required=True)
|
||||
parser.add_argument("--angle", type=int, required=True)
|
||||
parser.add_argument("--frames", type=str, required=True)
|
||||
parser.add_argument("--device", type=str, default=None, help="cpu, cuda, or omit for auto.")
|
||||
parser.add_argument("--torch-threads", type=int, default=None, help="Optional CPU thread cap.")
|
||||
parser.add_argument(
|
||||
"--grade-model-path",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "models" / "joint_transformer_grade_predictor.pth",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenized-dir",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "data" / "processed" / "tokenized",
|
||||
)
|
||||
parser.add_argument("--json", action="store_true", help="Print JSON instead of human-readable text.")
|
||||
parser.add_argument("--show-tokens", action="store_true", help="Print the model token sequence.")
|
||||
parser.add_argument("--visualize", action="store_true", help="Save a board-background visualization.")
|
||||
parser.add_argument("--annotate", action="store_true", help="Label route holds by placement ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
type=Path,
|
||||
default=REPO_ROOT / "outputs" / "grade_predictions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Output image/table stem. Defaults to <board>_angle_<angle>_prediction.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--background-image",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Optional background image override.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
|
||||
board_config = load_board_for_demo(args.board, config_dir=REPO_ROOT / "configs")
|
||||
token_meta = load_token_metadata(args.tokenized_dir)
|
||||
predictor = load_grade_predictor(
|
||||
args.grade_model_path,
|
||||
device=args.device,
|
||||
torch_threads=args.torch_threads,
|
||||
)
|
||||
|
||||
result = predict_frames_grade(
|
||||
grade_predictor=predictor,
|
||||
frames=args.frames,
|
||||
angle=args.angle,
|
||||
board_config=board_config,
|
||||
df_token_meta=token_meta,
|
||||
)
|
||||
|
||||
if args.json:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
print(f"Board: {result['board_display_name']} ({result['board_key']})")
|
||||
print(f"Angle: {result['requested_angle']}°")
|
||||
print(f"Frames: {result['frames']}")
|
||||
print(f"Predicted: V{result['predicted_grouped_v']}")
|
||||
print(f"Difficulty: {result['predicted_display_difficulty']:.3f}")
|
||||
if args.show_tokens:
|
||||
print()
|
||||
print("Model tokens:")
|
||||
print(result["sequence"])
|
||||
|
||||
if args.visualize:
|
||||
out_dir = args.out_dir / args.board / f"angle_{args.angle}"
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stem = args.output_name or f"{args.board}_angle_{args.angle}_prediction"
|
||||
png_path = out_dir / f"{stem}.png"
|
||||
svg_path = out_dir / f"{stem}.svg"
|
||||
json_path = out_dir / f"{stem}.json"
|
||||
|
||||
background_image = args.background_image or default_background_for_board(args.board)
|
||||
title = (
|
||||
f"{result['board_display_name']} predicted "
|
||||
f"V{result['predicted_grouped_v']} @ {args.angle}°"
|
||||
)
|
||||
subtitle = (
|
||||
f"difficulty={result['predicted_display_difficulty']:.2f} | "
|
||||
f"frames={args.frames}"
|
||||
)
|
||||
|
||||
fig, _, _ = visualize_route_tokens(
|
||||
tokens=result["tokens"],
|
||||
df_token_meta=token_meta,
|
||||
board_key=args.board,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
output_path=png_path,
|
||||
annotate=args.annotate,
|
||||
background_image=background_image,
|
||||
)
|
||||
fig.savefig(svg_path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
json_path.write_text(json.dumps(result, indent=2), encoding="utf-8")
|
||||
|
||||
print()
|
||||
if background_image is not None:
|
||||
try:
|
||||
bg_display = background_image.relative_to(REPO_ROOT)
|
||||
except Exception:
|
||||
bg_display = background_image
|
||||
print(f"Using background image: {bg_display}")
|
||||
print(f"Saved PNG: {png_path}")
|
||||
print(f"Saved SVG: {svg_path}")
|
||||
print(f"Saved JSON: {json_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
scripts/demo_predict_kilter.py
Normal file
19
scripts/demo_predict_kilter.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convenience wrapper: predict grade for a Kilter frames string."""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(REPO_ROOT / "scripts" / "demo_predict_grade.py"),
|
||||
"--board",
|
||||
"kilter",
|
||||
*sys.argv[1:],
|
||||
]
|
||||
raise SystemExit(subprocess.call(cmd))
|
||||
19
scripts/demo_predict_tb2.py
Normal file
19
scripts/demo_predict_tb2.py
Normal file
@@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convenience wrapper: predict grade for a TB2 frames string."""
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(REPO_ROOT / "scripts" / "demo_predict_grade.py"),
|
||||
"--board",
|
||||
"tb2",
|
||||
*sys.argv[1:],
|
||||
]
|
||||
raise SystemExit(subprocess.call(cmd))
|
||||
@@ -42,17 +42,19 @@ def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
return rows
|
||||
|
||||
|
||||
def validity_from_records(records: list[dict[str, object]]) -> dict[str, object]:
|
||||
def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
placements = [int(record["placement_id"]) for record in records]
|
||||
roles = [str(record["role"]) for record in records]
|
||||
prefixes = [str(record["board_token_prefix"]) for record in records]
|
||||
one_board_only = len(set(prefixes)) <= 1
|
||||
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
|
||||
|
||||
out = {
|
||||
"n_holds_eval": len(records),
|
||||
"n_unique_placements_eval": len(set(placements)),
|
||||
"has_duplicate_placements_eval": len(records) != len(set(placements)),
|
||||
"one_board_only_eval": one_board_only,
|
||||
"matches_requested_board_eval": matches_requested_board,
|
||||
"n_start_eval": roles.count("start"),
|
||||
"n_middle_eval": roles.count("middle"),
|
||||
"n_foot_eval": roles.count("foot"),
|
||||
|
||||
@@ -77,13 +77,14 @@ def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
return rows
|
||||
|
||||
|
||||
def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
records = hold_records(tokens)
|
||||
placements = [record["placement_id"] for record in records]
|
||||
roles = [record["role"] for record in records]
|
||||
prefixes = [record["board_prefix"] for record in records]
|
||||
|
||||
one_board_only = len(set(prefixes)) <= 1
|
||||
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
|
||||
no_duplicates = len(placements) == len(set(placements))
|
||||
has_start = "start" in roles
|
||||
has_finish = "finish" in roles
|
||||
@@ -94,6 +95,7 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
"n_unique_placements": len(set(placements)),
|
||||
"has_duplicate_placements": not no_duplicates,
|
||||
"one_board_only": one_board_only,
|
||||
"matches_requested_board": matches_requested_board,
|
||||
"has_start": has_start,
|
||||
"has_middle": "middle" in roles,
|
||||
"has_finish": has_finish,
|
||||
@@ -101,14 +103,16 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
"n_middle": roles.count("middle"),
|
||||
"n_foot": roles.count("foot"),
|
||||
"n_finish": roles.count("finish"),
|
||||
"basic_valid": bool(one_board_only and no_duplicates and has_start and has_finish and enough_holds),
|
||||
"basic_valid": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),
|
||||
}
|
||||
|
||||
|
||||
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int]) -> str:
|
||||
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:
|
||||
pieces = []
|
||||
seen = set()
|
||||
for record in hold_records(tokens):
|
||||
if board_prefix is not None and str(record["board_prefix"]) != board_prefix:
|
||||
continue
|
||||
placement_id = int(record["placement_id"])
|
||||
role = str(record["role"])
|
||||
if placement_id in seen or role not in role_name_to_id:
|
||||
@@ -154,7 +158,7 @@ def generate_one(
|
||||
forbidden_ids=forbidden_ids,
|
||||
)
|
||||
tokens = [itos.get(int(idx), "<UNK>") for idx in token_ids]
|
||||
validity = validity_summary(tokens)
|
||||
validity = validity_summary(tokens, requested_board_prefix=board_prefix)
|
||||
|
||||
return {
|
||||
"requested_board_prefix": board_prefix,
|
||||
@@ -164,6 +168,6 @@ def generate_one(
|
||||
"top_k": None if top_k is None else int(top_k),
|
||||
"tokens": tokens,
|
||||
"sequence": " ".join(tokens),
|
||||
"frames": generated_tokens_to_frames(tokens, role_name_to_id),
|
||||
"frames": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),
|
||||
**validity,
|
||||
}
|
||||
|
||||
335
src/climbingboardgpt/inference.py
Normal file
335
src/climbingboardgpt/inference.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Inference helpers for ClimbingBoardGPT demos.
|
||||
|
||||
This module is intentionally small and webapp-friendly: it loads trained
|
||||
checkpoints once, keeps them in memory, and exposes route generation helpers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from .config import BoardConfig, load_board_config
|
||||
from .generation import generate_one
|
||||
from .grades import to_grouped_v
|
||||
from .models import JointRouteGPT, JointRouteTransformerRegressor
|
||||
from .tokenization import (
|
||||
angle_token,
|
||||
board_token,
|
||||
canonicalize_holds,
|
||||
hold_token,
|
||||
parse_frames,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedGenerator:
|
||||
"""A loaded GPT-style route generator plus vocabulary metadata."""
|
||||
|
||||
model: JointRouteGPT
|
||||
stoi: dict[str, int]
|
||||
itos: dict[int, str]
|
||||
device: torch.device
|
||||
checkpoint_path: Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedGradePredictor:
|
||||
"""A loaded transformer grade predictor plus vocabulary metadata."""
|
||||
|
||||
model: JointRouteTransformerRegressor
|
||||
stoi: dict[str, int]
|
||||
itos: dict[int, str]
|
||||
device: torch.device
|
||||
checkpoint_path: Path
|
||||
max_len: int
|
||||
pad_id: int
|
||||
unk_id: int
|
||||
|
||||
|
||||
def load_grade_predictor(
|
||||
checkpoint_path: str | Path,
|
||||
device: str | torch.device | None = None,
|
||||
torch_threads: int | None = None,
|
||||
) -> LoadedGradePredictor:
|
||||
"""Load a trained joint grade-prediction checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_path:
|
||||
Path to ``models/joint_transformer_grade_predictor.pth``.
|
||||
device:
|
||||
``"cpu"``, ``"cuda"``, or None for auto-detection.
|
||||
torch_threads:
|
||||
Optional CPU thread cap for small VPS demos.
|
||||
|
||||
Returns:
|
||||
LoadedGradePredictor containing the PyTorch model and tokenizer maps.
|
||||
"""
|
||||
if torch_threads is not None:
|
||||
torch.set_num_threads(int(torch_threads))
|
||||
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Could not find grade predictor checkpoint: {checkpoint_path}")
|
||||
|
||||
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}
|
||||
coord_features = checkpoint["coord_features"]
|
||||
if not isinstance(coord_features, torch.Tensor):
|
||||
coord_features = torch.tensor(coord_features, dtype=torch.float32)
|
||||
|
||||
model = JointRouteTransformerRegressor(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_len=cfg["max_len"],
|
||||
coord_features=coord_features,
|
||||
d_model=cfg.get("d_model", 128),
|
||||
nhead=cfg.get("nhead", 4),
|
||||
num_layers=cfg.get("num_layers", 4),
|
||||
dim_feedforward=cfg.get("dim_feedforward", 256),
|
||||
dropout=cfg.get("dropout", 0.10),
|
||||
pad_id=cfg.get("pad_id", stoi["<PAD>"]),
|
||||
).to(resolved_device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
return LoadedGradePredictor(
|
||||
model=model,
|
||||
stoi=stoi,
|
||||
itos=itos,
|
||||
device=resolved_device,
|
||||
checkpoint_path=checkpoint_path,
|
||||
max_len=int(cfg["max_len"]),
|
||||
pad_id=int(cfg.get("pad_id", stoi["<PAD>"])),
|
||||
unk_id=int(stoi["<UNK>"]),
|
||||
)
|
||||
|
||||
|
||||
def predict_route_grade(
|
||||
grade_predictor: LoadedGradePredictor,
|
||||
tokens: list[str],
|
||||
) -> dict[str, object]:
|
||||
"""Predict the grade of a route-token sequence.
|
||||
|
||||
The grade token is removed before scoring, because the predictor should
|
||||
infer the grade from the board, angle, and hold-role tokens rather than
|
||||
reading the requested grade.
|
||||
"""
|
||||
model_tokens = [token for token in tokens if not str(token).startswith("<GRADE_")]
|
||||
if model_tokens and model_tokens[0] == "<BOS>":
|
||||
model_tokens = ["<CLS>"] + model_tokens[1:]
|
||||
else:
|
||||
model_tokens = ["<CLS>"] + model_tokens
|
||||
|
||||
ids = [grade_predictor.stoi.get(token, grade_predictor.unk_id) for token in model_tokens]
|
||||
ids = ids[: grade_predictor.max_len]
|
||||
mask = [1] * len(ids)
|
||||
|
||||
if len(ids) < grade_predictor.max_len:
|
||||
pad_n = grade_predictor.max_len - len(ids)
|
||||
ids += [grade_predictor.pad_id] * pad_n
|
||||
mask += [0] * pad_n
|
||||
|
||||
with torch.no_grad():
|
||||
input_ids = torch.tensor([ids], dtype=torch.long, device=grade_predictor.device)
|
||||
attention_mask = torch.tensor([mask], dtype=torch.bool, device=grade_predictor.device)
|
||||
pred_display_difficulty = float(grade_predictor.model(input_ids, attention_mask).cpu().item())
|
||||
|
||||
return {
|
||||
"predicted_display_difficulty": pred_display_difficulty,
|
||||
"predicted_grouped_v": int(to_grouped_v(pred_display_difficulty)),
|
||||
}
|
||||
|
||||
|
||||
def load_route_generator(
|
||||
checkpoint_path: str | Path,
|
||||
device: str | torch.device | None = None,
|
||||
torch_threads: int | None = None,
|
||||
) -> LoadedGenerator:
|
||||
"""Load a trained joint route generator checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_path:
|
||||
Path to ``models/joint_route_gpt_generator.pth``.
|
||||
device:
|
||||
``"cpu"``, ``"cuda"``, or None for auto-detection.
|
||||
torch_threads:
|
||||
Optional CPU thread cap for small VPS demos.
|
||||
|
||||
Returns:
|
||||
LoadedGenerator containing the PyTorch model and tokenizer maps.
|
||||
"""
|
||||
if torch_threads is not None:
|
||||
torch.set_num_threads(int(torch_threads))
|
||||
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Could not find generator checkpoint: {checkpoint_path}")
|
||||
|
||||
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}
|
||||
|
||||
model = JointRouteGPT(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
block_size=cfg["block_size"],
|
||||
n_embd=cfg.get("n_embd", 128),
|
||||
n_head=cfg.get("n_head", 4),
|
||||
n_layer=cfg.get("n_layer", 4),
|
||||
dropout=cfg.get("dropout", 0.10),
|
||||
pad_id=cfg.get("pad_id", stoi["<PAD>"]),
|
||||
).to(resolved_device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
return LoadedGenerator(
|
||||
model=model,
|
||||
stoi=stoi,
|
||||
itos=itos,
|
||||
device=resolved_device,
|
||||
checkpoint_path=checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
def generate_route(
|
||||
generator: LoadedGenerator,
|
||||
board_config: BoardConfig,
|
||||
angle: int,
|
||||
grade: int,
|
||||
temperature: float = 0.9,
|
||||
top_k: int | None = 50,
|
||||
max_new_tokens: int = 40,
|
||||
) -> dict[str, object]:
|
||||
"""Generate a single route for a board/angle/grade condition."""
|
||||
return {
|
||||
"board_key": board_config.board_key,
|
||||
"board_display_name": board_config.display_name,
|
||||
**generate_one(
|
||||
model=generator.model,
|
||||
stoi=generator.stoi,
|
||||
itos=generator.itos,
|
||||
device=generator.device,
|
||||
board_prefix=board_config.token_prefix,
|
||||
angle=int(angle),
|
||||
grouped_v=int(grade),
|
||||
role_name_to_id=board_config.role_definitions,
|
||||
temperature=float(temperature),
|
||||
top_k=top_k,
|
||||
max_new_tokens=int(max_new_tokens),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def load_board_for_demo(board: str, config_dir: str | Path | None = None) -> BoardConfig:
|
||||
"""Load a board config by key, with a clearer demo error message."""
|
||||
try:
|
||||
return load_board_config(board, config_dir=config_dir)
|
||||
except FileNotFoundError as exc:
|
||||
raise FileNotFoundError(
|
||||
f"Unknown board '{board}'. Expected one of the JSON configs in configs/."
|
||||
) from exc
|
||||
|
||||
|
||||
def build_placement_lookup_from_token_metadata(df_token_meta) -> dict[tuple[str, int], dict]:
|
||||
"""Build the placement lookup expected by tokenization helpers.
|
||||
|
||||
The training-time tokenization code canonicalizes holds using a lookup keyed
|
||||
by ``(board_key, placement_id)``. At inference/demo time, we usually have
|
||||
``token_metadata.csv`` rather than the raw database, so this reconstructs
|
||||
the necessary coordinate lookup from token metadata.
|
||||
"""
|
||||
hold_meta = df_token_meta[df_token_meta["kind"] == "hold"].dropna(subset=["placement_id"]).copy()
|
||||
lookup: dict[tuple[str, int], dict] = {}
|
||||
|
||||
for _, row in hold_meta.drop_duplicates(["board_key", "placement_id"]).iterrows():
|
||||
key = (str(row["board_key"]), int(row["placement_id"]))
|
||||
lookup[key] = {
|
||||
"board_key": str(row["board_key"]),
|
||||
"board_token_prefix": str(row["board_token_prefix"]),
|
||||
"placement_id": int(row["placement_id"]),
|
||||
"x": float(row["x"]),
|
||||
"y": float(row["y"]),
|
||||
"x_norm": float(row.get("x_norm", 0.0)),
|
||||
"y_norm": float(row.get("y_norm", 0.0)),
|
||||
}
|
||||
|
||||
return lookup
|
||||
|
||||
|
||||
def frames_to_grade_model_tokens(
|
||||
frames: str,
|
||||
angle: int,
|
||||
board_config: BoardConfig,
|
||||
df_token_meta,
|
||||
) -> list[str]:
|
||||
"""Convert a user-provided frames string into grade-predictor tokens.
|
||||
|
||||
Output format matches training for the grade predictor:
|
||||
|
||||
``<CLS> <BOARD_...> <ANGLE_...> <BOARDPREFIX_p..._role> ... <EOS>``
|
||||
|
||||
The route is canonicalized using the same role/y/x ordering used during
|
||||
tokenization. No grade token is included.
|
||||
"""
|
||||
placement_lookup = build_placement_lookup_from_token_metadata(df_token_meta)
|
||||
holds = parse_frames(frames)
|
||||
holds = canonicalize_holds(holds, board_config, placement_lookup)
|
||||
|
||||
tokens = [
|
||||
"<CLS>",
|
||||
board_token(board_config),
|
||||
angle_token(angle),
|
||||
]
|
||||
tokens.extend(
|
||||
hold_token(placement_id, role_id, board_config)
|
||||
for placement_id, role_id in holds
|
||||
)
|
||||
tokens.append("<EOS>")
|
||||
return tokens
|
||||
|
||||
|
||||
def predict_frames_grade(
|
||||
grade_predictor: LoadedGradePredictor,
|
||||
frames: str,
|
||||
angle: int,
|
||||
board_config: BoardConfig,
|
||||
df_token_meta,
|
||||
) -> dict[str, object]:
|
||||
"""Predict grade from board, angle, and a BoardLib frames string."""
|
||||
tokens = frames_to_grade_model_tokens(
|
||||
frames=frames,
|
||||
angle=angle,
|
||||
board_config=board_config,
|
||||
df_token_meta=df_token_meta,
|
||||
)
|
||||
|
||||
# predict_route_grade accepts either <BOS>-style generated tokens or
|
||||
# already-prepared <CLS>-style model tokens. It will leave the leading
|
||||
# <CLS> intact through the fallback branch.
|
||||
pred = predict_route_grade(grade_predictor, tokens)
|
||||
|
||||
return {
|
||||
**pred,
|
||||
"tokens": tokens,
|
||||
"sequence": " ".join(tokens),
|
||||
"board_key": board_config.board_key,
|
||||
"board_display_name": board_config.display_name,
|
||||
"requested_angle": int(angle),
|
||||
"frames": frames,
|
||||
}
|
||||
|
||||
@@ -41,7 +41,11 @@ class JointRouteTransformerRegressor(nn.Module):
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer,
|
||||
num_layers=num_layers,
|
||||
enable_nested_tensor=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(d_model, d_model),
|
||||
@@ -96,7 +100,11 @@ class JointRouteGPT(nn.Module):
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.blocks = nn.TransformerEncoder(layer, num_layers=n_layer)
|
||||
self.blocks = nn.TransformerEncoder(
|
||||
layer,
|
||||
num_layers=n_layer,
|
||||
enable_nested_tensor=False,
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(n_embd)
|
||||
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
|
||||
self.lm_head.weight = self.token_emb.weight
|
||||
|
||||
@@ -75,3 +75,56 @@ def safe_train_test_split(
|
||||
random_state=random_state,
|
||||
stratify=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def assign_group_splits(
|
||||
df: pd.DataFrame,
|
||||
group_cols: list[str],
|
||||
test_size: float,
|
||||
val_size_within_temp: float,
|
||||
random_state: int,
|
||||
stratify_col: str | None = None,
|
||||
) -> pd.Series:
|
||||
"""Assign train/val/test splits at group level.
|
||||
|
||||
This prevents multiple rows for the same logical climb, for example the
|
||||
same UUID at several angles, from being distributed across different
|
||||
splits. The returned Series is indexed like ``df`` and contains
|
||||
``train``, ``val``, or ``test``.
|
||||
"""
|
||||
group_df = df[group_cols + ([stratify_col] if stratify_col else [])].copy()
|
||||
group_df["__row_index"] = range(len(group_df))
|
||||
group_df = group_df.drop_duplicates(group_cols).reset_index(drop=True)
|
||||
|
||||
train_groups, temp_groups = safe_train_test_split(
|
||||
group_df,
|
||||
test_size=test_size,
|
||||
random_state=random_state,
|
||||
stratify_col=stratify_col,
|
||||
)
|
||||
val_groups, test_groups = safe_train_test_split(
|
||||
temp_groups,
|
||||
test_size=val_size_within_temp,
|
||||
random_state=random_state,
|
||||
stratify_col=stratify_col,
|
||||
)
|
||||
|
||||
def key_frame(frame: pd.DataFrame) -> set[tuple]:
|
||||
return set(map(tuple, frame[group_cols].astype(str).values.tolist()))
|
||||
|
||||
train_keys = key_frame(train_groups)
|
||||
val_keys = key_frame(val_groups)
|
||||
test_keys = key_frame(test_groups)
|
||||
|
||||
def split_for_row(row) -> str:
|
||||
key = tuple(str(row[col]) for col in group_cols)
|
||||
if key in train_keys:
|
||||
return "train"
|
||||
if key in val_keys:
|
||||
return "val"
|
||||
if key in test_keys:
|
||||
return "test"
|
||||
raise KeyError(f"Could not assign split for group key {key}")
|
||||
|
||||
return df.apply(split_for_row, axis=1)
|
||||
|
||||
353
src/climbingboardgpt/visualization.py
Normal file
353
src/climbingboardgpt/visualization.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Visualization utilities for generated ClimbingBoardGPT routes.
|
||||
|
||||
The route-overlay functions here deliberately mimic the old TB2/Kilter
|
||||
notebook convention: draw the board composite image with the product-size
|
||||
coordinate extent, then scatter route holds in board coordinates.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
|
||||
# These are the same coordinate windows used in the earlier visualization
|
||||
# notebooks. They come from the product size geometry rather than from the
|
||||
# min/max of the actual holds. The hold coordinates are inset by about 4in,
|
||||
# so using hold min/max directly shifts/stretches the background image.
|
||||
BOARD_CANVAS = {
|
||||
"tb2": {
|
||||
"extent": [-68, 68, 0, 144],
|
||||
"figsize": (16, 14),
|
||||
"image_aspect": "auto",
|
||||
},
|
||||
"kilter": {
|
||||
"extent": [-24, 168, 0, 156],
|
||||
"figsize": (17, 12),
|
||||
"image_aspect": "equal",
|
||||
},
|
||||
}
|
||||
|
||||
ROLE_COLORS = {
|
||||
"start": "#2ecc71",
|
||||
"middle": "#3498db",
|
||||
"finish": "#e74c3c",
|
||||
"foot": "#f1c40f",
|
||||
"unknown": "#9ca3af",
|
||||
}
|
||||
|
||||
ROLE_MARKERS = {
|
||||
"start": "o",
|
||||
"middle": "o",
|
||||
"finish": "*",
|
||||
"foot": "s",
|
||||
"unknown": "o",
|
||||
}
|
||||
|
||||
ROLE_SIZES = {
|
||||
"start": 150,
|
||||
"middle": 150,
|
||||
"finish": 230,
|
||||
"foot": 95,
|
||||
"unknown": 150,
|
||||
}
|
||||
|
||||
|
||||
def parse_tokens(value) -> list[str]:
|
||||
"""Parse a generated token list from a list, repr string, or sequence string."""
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value]
|
||||
if not isinstance(value, str):
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed = ast.literal_eval(value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(v) for v in parsed]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return value.split()
|
||||
|
||||
|
||||
def tokens_to_route_records(tokens: Iterable[str]) -> pd.DataFrame:
|
||||
"""Extract generated hold records from model tokens."""
|
||||
rows = []
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_PATTERN.match(str(token))
|
||||
if match is None:
|
||||
continue
|
||||
rows.append(
|
||||
{
|
||||
"token": token,
|
||||
"board_token_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def load_token_metadata(tokenized_dir: str | Path) -> pd.DataFrame:
|
||||
"""Load token metadata produced by ``scripts/01_tokenize_routes.py``."""
|
||||
tokenized_dir = Path(tokenized_dir)
|
||||
path = tokenized_dir / "token_metadata.csv"
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Could not find {path}. Run scripts/01_tokenize_routes.py first."
|
||||
)
|
||||
return pd.read_csv(path)
|
||||
|
||||
|
||||
def board_canvas_settings(board_key: str, df_token_meta: pd.DataFrame | None = None) -> dict[str, object]:
|
||||
"""Return board canvas settings.
|
||||
|
||||
Known boards use hand-calibrated extents from the old notebooks. Unknown
|
||||
boards fall back to coordinate bounds from ``token_metadata.csv``.
|
||||
"""
|
||||
board_key = str(board_key)
|
||||
if board_key in BOARD_CANVAS:
|
||||
return dict(BOARD_CANVAS[board_key])
|
||||
|
||||
if df_token_meta is None:
|
||||
raise ValueError(f"No board canvas settings for board_key={board_key!r}.")
|
||||
|
||||
holds = _board_holds(df_token_meta, board_key)
|
||||
x_min, x_max = float(holds["x"].min()), float(holds["x"].max())
|
||||
y_min, y_max = float(holds["y"].min()), float(holds["y"].max())
|
||||
x_pad = max((x_max - x_min) * 0.06, 1.0)
|
||||
y_pad = max((y_max - y_min) * 0.06, 1.0)
|
||||
return {
|
||||
"extent": [x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad],
|
||||
"figsize": (8, 10),
|
||||
"image_aspect": "auto",
|
||||
}
|
||||
|
||||
|
||||
def _board_holds(df_token_meta: pd.DataFrame, board_key: str) -> pd.DataFrame:
|
||||
holds = df_token_meta[
|
||||
(df_token_meta["kind"] == "hold")
|
||||
& (df_token_meta["board_key"].astype(str) == str(board_key))
|
||||
].copy()
|
||||
|
||||
if holds.empty:
|
||||
raise ValueError(
|
||||
f"No hold metadata found for board_key={board_key!r}. "
|
||||
"Check token_metadata.csv and board config."
|
||||
)
|
||||
|
||||
holds = holds.drop_duplicates(["board_key", "placement_id"]).copy()
|
||||
return holds
|
||||
|
||||
|
||||
def _route_with_coords(
|
||||
route_records: pd.DataFrame,
|
||||
df_token_meta: pd.DataFrame,
|
||||
board_key: str,
|
||||
) -> pd.DataFrame:
|
||||
holds = _board_holds(df_token_meta, board_key)
|
||||
coords = holds[["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates(
|
||||
["board_key", "placement_id"]
|
||||
)
|
||||
|
||||
merged = route_records.merge(
|
||||
coords,
|
||||
on=["board_token_prefix", "placement_id"],
|
||||
how="left",
|
||||
)
|
||||
return merged
|
||||
|
||||
|
||||
def visualize_route_tokens(
|
||||
tokens: Iterable[str],
|
||||
df_token_meta: pd.DataFrame,
|
||||
board_key: str,
|
||||
title: str | None = None,
|
||||
subtitle: str | None = None,
|
||||
output_path: str | Path | None = None,
|
||||
annotate: bool = False,
|
||||
show_all_holds: bool | None = None,
|
||||
background_image: str | Path | None = None,
|
||||
figsize: tuple[float, float] | None = None,
|
||||
dpi: int = 160,
|
||||
):
|
||||
"""Visualize a generated route as a board overlay plot.
|
||||
|
||||
If a background image is supplied, the plot uses the calibrated canvas
|
||||
extent from the old project notebooks. If no image is supplied, it falls
|
||||
back to a clean coordinate-board style and shows available holds.
|
||||
"""
|
||||
route_records = tokens_to_route_records(tokens)
|
||||
if route_records.empty:
|
||||
raise ValueError("No hold tokens found in generated sequence.")
|
||||
|
||||
board_holds = _board_holds(df_token_meta, board_key)
|
||||
route_df = _route_with_coords(route_records, df_token_meta, board_key)
|
||||
route_df = route_df.dropna(subset=["x", "y"]).copy()
|
||||
|
||||
if route_df.empty:
|
||||
raise ValueError(
|
||||
"Generated route contained hold tokens, but none matched the board metadata."
|
||||
)
|
||||
|
||||
canvas = board_canvas_settings(board_key, df_token_meta)
|
||||
extent = [float(v) for v in canvas["extent"]]
|
||||
x_min, x_max, y_min, y_max = extent
|
||||
image_aspect = str(canvas.get("image_aspect", "auto"))
|
||||
figsize = figsize or canvas.get("figsize", (8, 10))
|
||||
|
||||
background_exists = background_image is not None and Path(background_image).exists()
|
||||
if show_all_holds is None:
|
||||
show_all_holds = not background_exists
|
||||
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
|
||||
if background_exists:
|
||||
img = plt.imread(Path(background_image))
|
||||
ax.imshow(
|
||||
img,
|
||||
extent=extent,
|
||||
aspect=image_aspect,
|
||||
alpha=1.0,
|
||||
zorder=0,
|
||||
)
|
||||
|
||||
if show_all_holds:
|
||||
ax.scatter(
|
||||
board_holds["x"],
|
||||
board_holds["y"],
|
||||
s=22,
|
||||
c="#d1d5db",
|
||||
alpha=0.45,
|
||||
linewidths=0,
|
||||
label="available holds",
|
||||
zorder=1,
|
||||
)
|
||||
|
||||
# Draw route holds role-by-role so the legend is meaningful.
|
||||
for role, frame in route_df.groupby("role", sort=False):
|
||||
ax.scatter(
|
||||
frame["x"],
|
||||
frame["y"],
|
||||
s=ROLE_SIZES.get(role, 150),
|
||||
c=ROLE_COLORS.get(role, ROLE_COLORS["unknown"]),
|
||||
marker=ROLE_MARKERS.get(role, "o"),
|
||||
edgecolors="#111827",
|
||||
linewidths=1.0,
|
||||
alpha=0.96,
|
||||
label=role,
|
||||
zorder=3,
|
||||
)
|
||||
|
||||
if annotate:
|
||||
for _, row in route_df.iterrows():
|
||||
ax.text(
|
||||
row["x"],
|
||||
row["y"],
|
||||
str(int(row["placement_id"])),
|
||||
ha="center",
|
||||
va="center",
|
||||
fontsize=7,
|
||||
fontweight="bold",
|
||||
color="white",
|
||||
bbox=dict(
|
||||
boxstyle="circle,pad=0.12",
|
||||
alpha=0.45,
|
||||
facecolor="#111827",
|
||||
edgecolor="white",
|
||||
linewidth=0.8,
|
||||
),
|
||||
zorder=4,
|
||||
)
|
||||
|
||||
ax.set_xlim(x_min, x_max)
|
||||
ax.set_ylim(y_min, y_max)
|
||||
if image_aspect == "equal":
|
||||
ax.set_aspect("equal", adjustable="box")
|
||||
ax.set_xlabel("X Position")
|
||||
ax.set_ylabel("Y Position")
|
||||
|
||||
# Put the title and subtitle at the figure level, not the axes level.
|
||||
# This avoids the old overlap where ax.set_title(...) and ax.text(y=1.01)
|
||||
# competed for the same narrow top margin.
|
||||
has_header = bool(title or subtitle)
|
||||
if title:
|
||||
fig.suptitle(
|
||||
title,
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
y=0.985,
|
||||
)
|
||||
if subtitle:
|
||||
fig.text(
|
||||
0.5,
|
||||
0.958,
|
||||
subtitle,
|
||||
ha="center",
|
||||
va="top",
|
||||
fontsize=9,
|
||||
color="#4b5563",
|
||||
)
|
||||
|
||||
if background_exists:
|
||||
ax.grid(False)
|
||||
else:
|
||||
ax.grid(True, alpha=0.18)
|
||||
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
|
||||
|
||||
# Reserve top space for the figure-level title/subtitle.
|
||||
if has_header:
|
||||
fig.tight_layout(rect=[0, 0, 1, 0.925])
|
||||
else:
|
||||
fig.tight_layout()
|
||||
|
||||
if output_path is not None:
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
|
||||
|
||||
return fig, ax, route_df
|
||||
|
||||
|
||||
def visualize_route_result(
|
||||
result: dict[str, object],
|
||||
df_token_meta: pd.DataFrame,
|
||||
output_path: str | Path | None = None,
|
||||
annotate: bool = False,
|
||||
background_image: str | Path | None = None,
|
||||
):
|
||||
"""Visualize a result dictionary returned by ``generate_route``."""
|
||||
board_key = str(result["board_key"])
|
||||
tokens = parse_tokens(result["tokens"])
|
||||
title = (
|
||||
f"{str(result.get('board_display_name', board_key))} "
|
||||
f"generated V{int(result['requested_grouped_v'])} @ {int(result['requested_angle'])}°"
|
||||
)
|
||||
subtitle_parts = [
|
||||
f"valid={result.get('basic_valid')}",
|
||||
f"holds={result.get('n_hold_tokens')}",
|
||||
]
|
||||
if "predicted_grouped_v" in result:
|
||||
subtitle_parts.append(
|
||||
f"predicted V{int(result['predicted_grouped_v'])}"
|
||||
f" ({float(result['predicted_display_difficulty']):.2f})"
|
||||
)
|
||||
if "critic_v_error" in result:
|
||||
subtitle_parts.append(f"error {int(result['critic_v_error']):+d}V")
|
||||
subtitle_parts.append(f"temperature={result.get('temperature')}")
|
||||
subtitle = " | ".join(subtitle_parts)
|
||||
return visualize_route_tokens(
|
||||
tokens=tokens,
|
||||
df_token_meta=df_token_meta,
|
||||
board_key=board_key,
|
||||
title=title,
|
||||
subtitle=subtitle,
|
||||
output_path=output_path,
|
||||
annotate=annotate,
|
||||
background_image=background_image,
|
||||
)
|
||||
Reference in New Issue
Block a user