Next version. Models + scripts updated. 2
This commit is contained in:
@@ -42,17 +42,19 @@ def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
return rows
|
||||
|
||||
|
||||
def validity_from_records(records: list[dict[str, object]]) -> dict[str, object]:
|
||||
def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
placements = [int(record["placement_id"]) for record in records]
|
||||
roles = [str(record["role"]) for record in records]
|
||||
prefixes = [str(record["board_token_prefix"]) for record in records]
|
||||
one_board_only = len(set(prefixes)) <= 1
|
||||
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
|
||||
|
||||
out = {
|
||||
"n_holds_eval": len(records),
|
||||
"n_unique_placements_eval": len(set(placements)),
|
||||
"has_duplicate_placements_eval": len(records) != len(set(placements)),
|
||||
"one_board_only_eval": one_board_only,
|
||||
"matches_requested_board_eval": matches_requested_board,
|
||||
"n_start_eval": roles.count("start"),
|
||||
"n_middle_eval": roles.count("middle"),
|
||||
"n_foot_eval": roles.count("foot"),
|
||||
|
||||
@@ -77,13 +77,14 @@ def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
return rows
|
||||
|
||||
|
||||
def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
records = hold_records(tokens)
|
||||
placements = [record["placement_id"] for record in records]
|
||||
roles = [record["role"] for record in records]
|
||||
prefixes = [record["board_prefix"] for record in records]
|
||||
|
||||
one_board_only = len(set(prefixes)) <= 1
|
||||
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
|
||||
no_duplicates = len(placements) == len(set(placements))
|
||||
has_start = "start" in roles
|
||||
has_finish = "finish" in roles
|
||||
@@ -94,6 +95,7 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
"n_unique_placements": len(set(placements)),
|
||||
"has_duplicate_placements": not no_duplicates,
|
||||
"one_board_only": one_board_only,
|
||||
"matches_requested_board": matches_requested_board,
|
||||
"has_start": has_start,
|
||||
"has_middle": "middle" in roles,
|
||||
"has_finish": has_finish,
|
||||
@@ -101,14 +103,16 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
|
||||
"n_middle": roles.count("middle"),
|
||||
"n_foot": roles.count("foot"),
|
||||
"n_finish": roles.count("finish"),
|
||||
"basic_valid": bool(one_board_only and no_duplicates and has_start and has_finish and enough_holds),
|
||||
"basic_valid": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),
|
||||
}
|
||||
|
||||
|
||||
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int]) -> str:
|
||||
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:
|
||||
pieces = []
|
||||
seen = set()
|
||||
for record in hold_records(tokens):
|
||||
if board_prefix is not None and str(record["board_prefix"]) != board_prefix:
|
||||
continue
|
||||
placement_id = int(record["placement_id"])
|
||||
role = str(record["role"])
|
||||
if placement_id in seen or role not in role_name_to_id:
|
||||
@@ -154,7 +158,7 @@ def generate_one(
|
||||
forbidden_ids=forbidden_ids,
|
||||
)
|
||||
tokens = [itos.get(int(idx), "<UNK>") for idx in token_ids]
|
||||
validity = validity_summary(tokens)
|
||||
validity = validity_summary(tokens, requested_board_prefix=board_prefix)
|
||||
|
||||
return {
|
||||
"requested_board_prefix": board_prefix,
|
||||
@@ -164,6 +168,6 @@ def generate_one(
|
||||
"top_k": None if top_k is None else int(top_k),
|
||||
"tokens": tokens,
|
||||
"sequence": " ".join(tokens),
|
||||
"frames": generated_tokens_to_frames(tokens, role_name_to_id),
|
||||
"frames": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),
|
||||
**validity,
|
||||
}
|
||||
|
||||
335
src/climbingboardgpt/inference.py
Normal file
335
src/climbingboardgpt/inference.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Inference helpers for ClimbingBoardGPT demos.
|
||||
|
||||
This module is intentionally small and webapp-friendly: it loads trained
|
||||
checkpoints once, keeps them in memory, and exposes route generation helpers.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from .config import BoardConfig, load_board_config
|
||||
from .generation import generate_one
|
||||
from .grades import to_grouped_v
|
||||
from .models import JointRouteGPT, JointRouteTransformerRegressor
|
||||
from .tokenization import (
|
||||
angle_token,
|
||||
board_token,
|
||||
canonicalize_holds,
|
||||
hold_token,
|
||||
parse_frames,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class LoadedGenerator:
    """A loaded GPT-style route generator plus vocabulary metadata.

    Built by ``load_route_generator`` and consumed by ``generate_route``,
    which forwards every field except ``checkpoint_path`` to ``generate_one``.
    """

    # Generator network restored from the checkpoint.
    model: JointRouteGPT
    # Vocabulary map: token string -> integer id.
    stoi: dict[str, int]
    # Inverse vocabulary map: integer id -> token string.
    itos: dict[int, str]
    # Device the model was moved to when it was loaded.
    device: torch.device
    # Location of the checkpoint file the model was loaded from.
    checkpoint_path: Path
|
||||
|
||||
|
||||
@dataclass
class LoadedGradePredictor:
    """A loaded transformer grade predictor plus vocabulary metadata.

    Built by ``load_grade_predictor`` and consumed by ``predict_route_grade``.
    """

    # Regressor network restored from the checkpoint.
    model: JointRouteTransformerRegressor
    # Vocabulary map: token string -> integer id.
    stoi: dict[str, int]
    # Inverse vocabulary map: integer id -> token string.
    itos: dict[int, str]
    # Device the model was moved to when it was loaded.
    device: torch.device
    # Location of the checkpoint file the model was loaded from.
    checkpoint_path: Path
    # Fixed input length: sequences are truncated / padded to this many ids.
    max_len: int
    # Id used to fill the padded tail of a sequence.
    pad_id: int
    # Id substituted for tokens that are missing from ``stoi``.
    unk_id: int
|
||||
|
||||
|
||||
def load_grade_predictor(
    checkpoint_path: str | Path,
    device: str | torch.device | None = None,
    torch_threads: int | None = None,
) -> LoadedGradePredictor:
    """Load a trained joint grade-prediction checkpoint.

    Args:
        checkpoint_path:
            Path to ``models/joint_transformer_grade_predictor.pth``.
        device:
            ``"cpu"``, ``"cuda"``, or None for auto-detection.
        torch_threads:
            Optional CPU thread cap for small VPS demos.

    Returns:
        LoadedGradePredictor containing the PyTorch model and tokenizer maps.

    Raises:
        FileNotFoundError:
            If ``checkpoint_path`` does not exist.
        KeyError:
            If the checkpoint lacks one of the entries read here (``config``,
            ``stoi``, ``itos``, ``coord_features``, ``model_state_dict``), the
            vocabulary has no ``<UNK>`` token, or neither the config nor the
            vocabulary provides a padding id.
    """
    if torch_threads is not None:
        torch.set_num_threads(int(torch_threads))

    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Could not find grade predictor checkpoint: {checkpoint_path}")

    resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints from a trusted source.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
    except TypeError:
        # Presumably a torch version whose torch.load has no ``weights_only``
        # parameter; retry without it.
        checkpoint = torch.load(checkpoint_path, map_location=resolved_device)

    cfg = checkpoint["config"]
    # Normalise key/value types so lookups work regardless of how the
    # vocabulary was serialised.
    stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
    itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}
    coord_features = checkpoint["coord_features"]
    if not isinstance(coord_features, torch.Tensor):
        coord_features = torch.tensor(coord_features, dtype=torch.float32)

    # Resolve the padding id once. The vocabulary is only consulted when the
    # config does not carry the id, so a checkpoint that stores ``pad_id``
    # explicitly no longer also needs a "<PAD>" vocabulary entry.
    pad_id = cfg["pad_id"] if "pad_id" in cfg else stoi["<PAD>"]

    model = JointRouteTransformerRegressor(
        vocab_size=cfg["vocab_size"],
        max_len=cfg["max_len"],
        coord_features=coord_features,
        d_model=cfg.get("d_model", 128),
        nhead=cfg.get("nhead", 4),
        num_layers=cfg.get("num_layers", 4),
        dim_feedforward=cfg.get("dim_feedforward", 256),
        dropout=cfg.get("dropout", 0.10),
        pad_id=pad_id,
    ).to(resolved_device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return LoadedGradePredictor(
        model=model,
        stoi=stoi,
        itos=itos,
        device=resolved_device,
        checkpoint_path=checkpoint_path,
        max_len=int(cfg["max_len"]),
        pad_id=int(pad_id),
        unk_id=int(stoi["<UNK>"]),
    )
|
||||
|
||||
|
||||
def predict_route_grade(
    grade_predictor: LoadedGradePredictor,
    tokens: list[str],
) -> dict[str, object]:
    """Predict the grade of a route-token sequence.

    The grade token is removed before scoring, because the predictor should
    infer the grade from the board, angle, and hold-role tokens rather than
    reading the requested grade.

    The sequence handed to the model always starts with exactly one ``<CLS>``:
    a leading ``<BOS>`` is replaced by ``<CLS>``, an existing leading ``<CLS>``
    is kept as is, and any other sequence gets ``<CLS>`` prepended.

    Args:
        grade_predictor:
            Loaded predictor providing the model, vocabulary and padding
            settings.
        tokens:
            Route tokens, either generator output (``<BOS>`` first) or an
            already prepared ``<CLS>``-style sequence.

    Returns:
        Dict with ``predicted_display_difficulty`` (float model output) and
        ``predicted_grouped_v`` (that value mapped through ``to_grouped_v``).
    """
    model_tokens = [token for token in tokens if not str(token).startswith("<GRADE_")]
    if model_tokens and model_tokens[0] == "<BOS>":
        model_tokens = ["<CLS>"] + model_tokens[1:]
    elif not model_tokens or model_tokens[0] != "<CLS>":
        # Only prepend when the caller has not already supplied <CLS>;
        # otherwise the model would see a duplicated <CLS> token.
        model_tokens = ["<CLS>"] + model_tokens

    # Unknown tokens fall back to the <UNK> id; overly long routes are cut.
    ids = [grade_predictor.stoi.get(token, grade_predictor.unk_id) for token in model_tokens]
    ids = ids[: grade_predictor.max_len]
    mask = [1] * len(ids)

    if len(ids) < grade_predictor.max_len:
        pad_n = grade_predictor.max_len - len(ids)
        ids += [grade_predictor.pad_id] * pad_n
        mask += [0] * pad_n  # padded positions are flagged 0 / False

    with torch.no_grad():
        input_ids = torch.tensor([ids], dtype=torch.long, device=grade_predictor.device)
        attention_mask = torch.tensor([mask], dtype=torch.bool, device=grade_predictor.device)
        pred_display_difficulty = float(grade_predictor.model(input_ids, attention_mask).cpu().item())

    return {
        "predicted_display_difficulty": pred_display_difficulty,
        "predicted_grouped_v": int(to_grouped_v(pred_display_difficulty)),
    }
|
||||
|
||||
|
||||
def load_route_generator(
    checkpoint_path: str | Path,
    device: str | torch.device | None = None,
    torch_threads: int | None = None,
) -> LoadedGenerator:
    """Load a trained joint route generator checkpoint.

    Args:
        checkpoint_path:
            Path to ``models/joint_route_gpt_generator.pth``.
        device:
            ``"cpu"``, ``"cuda"``, or None for auto-detection.
        torch_threads:
            Optional CPU thread cap for small VPS demos.

    Returns:
        LoadedGenerator containing the PyTorch model and tokenizer maps.

    Raises:
        FileNotFoundError:
            If ``checkpoint_path`` does not exist.
        KeyError:
            If the checkpoint lacks one of the entries read here (``config``,
            ``stoi``, ``itos``, ``model_state_dict``) or neither the config
            nor the vocabulary provides a padding id.
    """
    if torch_threads is not None:
        torch.set_num_threads(int(torch_threads))

    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Could not find generator checkpoint: {checkpoint_path}")

    resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints from a trusted source.
    try:
        checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
    except TypeError:
        # Presumably a torch version whose torch.load has no ``weights_only``
        # parameter; retry without it.
        checkpoint = torch.load(checkpoint_path, map_location=resolved_device)

    cfg = checkpoint["config"]
    # Normalise key/value types so lookups work regardless of how the
    # vocabulary was serialised.
    stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
    itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}

    # Only fall back to the vocabulary when the config has no ``pad_id``, so a
    # checkpoint that stores the id explicitly does not need a "<PAD>" entry.
    pad_id = cfg["pad_id"] if "pad_id" in cfg else stoi["<PAD>"]

    model = JointRouteGPT(
        vocab_size=cfg["vocab_size"],
        block_size=cfg["block_size"],
        n_embd=cfg.get("n_embd", 128),
        n_head=cfg.get("n_head", 4),
        n_layer=cfg.get("n_layer", 4),
        dropout=cfg.get("dropout", 0.10),
        pad_id=pad_id,
    ).to(resolved_device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return LoadedGenerator(
        model=model,
        stoi=stoi,
        itos=itos,
        device=resolved_device,
        checkpoint_path=checkpoint_path,
    )
|
||||
|
||||
|
||||
def generate_route(
    generator: LoadedGenerator,
    board_config: BoardConfig,
    angle: int,
    grade: int,
    temperature: float = 0.9,
    top_k: int | None = 50,
    max_new_tokens: int = 40,
) -> dict[str, object]:
    """Sample one route for the given board, wall angle and grade.

    The returned dict starts with the board's key and display name and is then
    extended with everything ``generate_one`` returns (keys from
    ``generate_one`` win on a name clash).
    """
    route: dict[str, object] = {
        "board_key": board_config.board_key,
        "board_display_name": board_config.display_name,
    }
    sampled = generate_one(
        model=generator.model,
        stoi=generator.stoi,
        itos=generator.itos,
        device=generator.device,
        board_prefix=board_config.token_prefix,
        angle=int(angle),
        grouped_v=int(grade),
        role_name_to_id=board_config.role_definitions,
        temperature=float(temperature),
        top_k=top_k,
        max_new_tokens=int(max_new_tokens),
    )
    route.update(sampled)
    return route
|
||||
|
||||
|
||||
def load_board_for_demo(board: str, config_dir: str | Path | None = None) -> BoardConfig:
    """Look up a board config by key and re-word the "not found" error.

    A ``FileNotFoundError`` from ``load_board_config`` is re-raised as a
    ``FileNotFoundError`` naming the unknown board, chained to the original.
    """
    try:
        config = load_board_config(board, config_dir=config_dir)
    except FileNotFoundError as exc:
        message = f"Unknown board '{board}'. Expected one of the JSON configs in configs/."
        raise FileNotFoundError(message) from exc
    return config
|
||||
|
||||
|
||||
def build_placement_lookup_from_token_metadata(df_token_meta) -> dict[tuple[str, int], dict]:
    """Rebuild the placement lookup used by the tokenization helpers.

    Only rows whose ``kind`` is ``"hold"`` and whose ``placement_id`` is set
    are considered; the first row per ``(board_key, placement_id)`` wins.
    Each entry carries the board identifiers, the placement id and its
    ``x`` / ``y`` coordinates. ``x_norm`` / ``y_norm`` default to ``0.0`` when
    the metadata has no such columns.

    Returns:
        Dict keyed by ``(board_key, placement_id)``.
    """
    is_hold = df_token_meta["kind"] == "hold"
    hold_rows = df_token_meta[is_hold].dropna(subset=["placement_id"]).copy()
    unique_rows = hold_rows.drop_duplicates(["board_key", "placement_id"])

    lookup: dict[tuple[str, int], dict] = {}
    for _, record in unique_rows.iterrows():
        board_key = str(record["board_key"])
        placement_id = int(record["placement_id"])
        entry = {
            "board_key": board_key,
            "board_token_prefix": str(record["board_token_prefix"]),
            "placement_id": placement_id,
            "x": float(record["x"]),
            "y": float(record["y"]),
            "x_norm": float(record.get("x_norm", 0.0)),
            "y_norm": float(record.get("y_norm", 0.0)),
        }
        lookup[(board_key, placement_id)] = entry
    return lookup
|
||||
|
||||
|
||||
def frames_to_grade_model_tokens(
    frames: str,
    angle: int,
    board_config: BoardConfig,
    df_token_meta,
) -> list[str]:
    """Turn a user-supplied frames string into grade-predictor tokens.

    The result has the layout used when training the grade predictor:

    ``<CLS> <BOARD_...> <ANGLE_...> <BOARDPREFIX_p..._role> ... <EOS>``

    Holds are parsed from ``frames`` and passed through ``canonicalize_holds``
    together with a placement lookup rebuilt from ``df_token_meta``. The
    sequence never contains a grade token.
    """
    lookup = build_placement_lookup_from_token_metadata(df_token_meta)
    ordered_holds = canonicalize_holds(parse_frames(frames), board_config, lookup)

    header = ["<CLS>", board_token(board_config), angle_token(angle)]
    hold_tokens = [
        hold_token(placement_id, role_id, board_config)
        for placement_id, role_id in ordered_holds
    ]
    return header + hold_tokens + ["<EOS>"]
|
||||
|
||||
|
||||
def predict_frames_grade(
    grade_predictor: LoadedGradePredictor,
    frames: str,
    angle: int,
    board_config: BoardConfig,
    df_token_meta,
) -> dict[str, object]:
    """Predict grade from board, angle, and a BoardLib frames string.

    Args:
        grade_predictor:
            Loaded grade predictor.
        frames:
            Frames string describing the route's holds.
        angle:
            Wall angle the route is set at.
        board_config:
            Config of the board the frames belong to.
        df_token_meta:
            Token metadata table used to rebuild hold coordinates.

    Returns:
        The prediction from ``predict_route_grade`` plus the prepared
        ``tokens`` / ``sequence``, the board key and display name, the
        requested angle, and the original ``frames`` string.
    """
    tokens = frames_to_grade_model_tokens(
        frames=frames,
        angle=angle,
        board_config=board_config,
        df_token_meta=df_token_meta,
    )

    # ``tokens`` already starts with <CLS>, and predict_route_grade prepends
    # <CLS> to any sequence that does not start with <BOS>. Hand it the
    # sequence without our own leading <CLS> so the model sees exactly one.
    pred = predict_route_grade(grade_predictor, tokens[1:])

    return {
        **pred,
        "tokens": tokens,
        "sequence": " ".join(tokens),
        "board_key": board_config.board_key,
        "board_display_name": board_config.display_name,
        "requested_angle": int(angle),
        "frames": frames,
    }
|
||||
|
||||
@@ -41,7 +41,11 @@ class JointRouteTransformerRegressor(nn.Module):
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer,
|
||||
num_layers=num_layers,
|
||||
enable_nested_tensor=False,
|
||||
)
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(d_model, d_model),
|
||||
@@ -96,7 +100,11 @@ class JointRouteGPT(nn.Module):
|
||||
batch_first=True,
|
||||
norm_first=True,
|
||||
)
|
||||
self.blocks = nn.TransformerEncoder(layer, num_layers=n_layer)
|
||||
self.blocks = nn.TransformerEncoder(
|
||||
layer,
|
||||
num_layers=n_layer,
|
||||
enable_nested_tensor=False,
|
||||
)
|
||||
self.ln_f = nn.LayerNorm(n_embd)
|
||||
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
|
||||
self.lm_head.weight = self.token_emb.weight
|
||||
|
||||
@@ -75,3 +75,56 @@ def safe_train_test_split(
|
||||
random_state=random_state,
|
||||
stratify=None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
def assign_group_splits(
    df: pd.DataFrame,
    group_cols: list[str],
    test_size: float,
    val_size_within_temp: float,
    random_state: int,
    stratify_col: str | None = None,
) -> pd.Series:
    """Label every row ``train``, ``val`` or ``test``, splitting by group.

    Rows that share the same values in ``group_cols`` (for example one climb
    UUID recorded at several angles) always receive the same label, so a
    logical climb never straddles two splits. The result is indexed like
    ``df``.

    Raises:
        KeyError: If a row's group key was not placed in any split.
    """
    selected = list(group_cols)
    if stratify_col:
        selected.append(stratify_col)

    # One representative row per group; the split is drawn over these.
    group_df = df[selected].copy()
    group_df["__row_index"] = range(len(group_df))
    group_df = group_df.drop_duplicates(group_cols).reset_index(drop=True)

    train_groups, temp_groups = safe_train_test_split(
        group_df,
        test_size=test_size,
        random_state=random_state,
        stratify_col=stratify_col,
    )
    # NOTE(review): if safe_train_test_split returns (remainder, held-out
    # share) as the unpacking above suggests, ``val_size_within_temp`` sizes
    # the *test* part here rather than the val part — confirm this is intended.
    val_groups, test_groups = safe_train_test_split(
        temp_groups,
        test_size=val_size_within_temp,
        random_state=random_state,
        stratify_col=stratify_col,
    )

    def group_keys(frame: pd.DataFrame) -> set[tuple]:
        # Keys are compared as strings so dtype differences cannot break matches.
        return {tuple(values) for values in frame[group_cols].astype(str).values.tolist()}

    # Later entries override earlier ones, giving train > val > test precedence.
    label_by_key = {
        **dict.fromkeys(group_keys(test_groups), "test"),
        **dict.fromkeys(group_keys(val_groups), "val"),
        **dict.fromkeys(group_keys(train_groups), "train"),
    }

    def split_for_row(row) -> str:
        key = tuple(str(row[col]) for col in group_cols)
        label = label_by_key.get(key)
        if label is None:
            raise KeyError(f"Could not assign split for group key {key}")
        return label

    return df.apply(split_for_row, axis=1)
|
||||
|
||||
353
src/climbingboardgpt/visualization.py
Normal file
353
src/climbingboardgpt/visualization.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""Visualization utilities for generated ClimbingBoardGPT routes.
|
||||
|
||||
The route-overlay functions here deliberately mimic the old TB2/Kilter
|
||||
notebook convention: draw the board composite image with the product-size
|
||||
coordinate extent, then scatter route holds in board coordinates.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
# Matches a hold token of the form ``<PREFIX_p<digits>_<role>>``.
# Groups: (1) board token prefix, (2) placement id digits, (3) role name.
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")

# These are the same coordinate windows used in the earlier visualization
# notebooks. They come from the product size geometry rather than from the
# min/max of the actual holds. The hold coordinates are inset by about 4in,
# so using hold min/max directly shifts/stretches the background image.
#
# Per board key:
#   extent       -> [x_min, x_max, y_min, y_max] of the drawing canvas
#   figsize      -> default matplotlib figure size
#   image_aspect -> aspect passed to ``imshow`` for the background image
BOARD_CANVAS = {
    "tb2": {
        "extent": [-68, 68, 0, 144],
        "figsize": (16, 14),
        "image_aspect": "auto",
    },
    "kilter": {
        "extent": [-24, 168, 0, 156],
        "figsize": (17, 12),
        "image_aspect": "equal",
    },
}

# Marker fill colour per hold role when a route is scattered onto the board.
ROLE_COLORS = {
    "start": "#2ecc71",
    "middle": "#3498db",
    "finish": "#e74c3c",
    "foot": "#f1c40f",
    "unknown": "#9ca3af",
}

# Matplotlib marker shape per hold role.
ROLE_MARKERS = {
    "start": "o",
    "middle": "o",
    "finish": "*",
    "foot": "s",
    "unknown": "o",
}

# Scatter marker size (matplotlib ``s``) per hold role.
ROLE_SIZES = {
    "start": 150,
    "middle": 150,
    "finish": 230,
    "foot": 95,
    "unknown": 150,
}
|
||||
|
||||
|
||||
def parse_tokens(value) -> list[str]:
    """Parse a generated token list from a list, repr string, or sequence string.

    Args:
        value:
            A list of tokens, the ``repr`` of such a list (as found in a CSV
            cell), or a whitespace-separated token sequence.

    Returns:
        The tokens as strings. Anything that is neither a list nor a string
        yields an empty list.
    """
    if isinstance(value, list):
        return [str(v) for v in value]
    if not isinstance(value, str):
        return []

    try:
        parsed = ast.literal_eval(value)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal (the usual case for a plain token sequence);
        # these are the errors literal_eval is documented to raise.
        parsed = None

    if isinstance(parsed, list):
        return [str(v) for v in parsed]

    # Plain "<BOS> <TOKEN> ... <EOS>" style sequence.
    return value.split()
|
||||
|
||||
|
||||
def tokens_to_route_records(tokens: Iterable[str]) -> pd.DataFrame:
    """Extract generated hold records from model tokens.

    Tokens that do not match ``HOLD_TOKEN_PATTERN`` (board, angle, grade and
    control tokens, for instance) are skipped.

    Returns:
        DataFrame with one row per hold token and the columns ``token``,
        ``board_token_prefix``, ``placement_id`` (int) and ``role``. The
        columns are present even when no token matched, so callers can rely
        on the schema of an empty result.
    """
    columns = ["token", "board_token_prefix", "placement_id", "role"]
    rows = []
    for token in tokens:
        match = HOLD_TOKEN_PATTERN.match(str(token))
        if match is None:
            continue
        rows.append(
            {
                "token": token,
                "board_token_prefix": match.group(1),
                "placement_id": int(match.group(2)),
                "role": match.group(3),
            }
        )
    # Explicit columns keep the schema stable for an empty route.
    return pd.DataFrame(rows, columns=columns)
|
||||
|
||||
|
||||
def load_token_metadata(tokenized_dir: str | Path) -> pd.DataFrame:
    """Read ``token_metadata.csv`` (written by ``scripts/01_tokenize_routes.py``).

    Raises:
        FileNotFoundError: If the CSV is not present in ``tokenized_dir``.
    """
    metadata_csv = Path(tokenized_dir) / "token_metadata.csv"
    if metadata_csv.exists():
        return pd.read_csv(metadata_csv)
    raise FileNotFoundError(
        f"Could not find {metadata_csv}. Run scripts/01_tokenize_routes.py first."
    )
|
||||
|
||||
|
||||
def board_canvas_settings(board_key: str, df_token_meta: pd.DataFrame | None = None) -> dict[str, object]:
    """Return board canvas settings.

    Known boards use hand-calibrated extents from the old notebooks. Unknown
    boards fall back to coordinate bounds from ``token_metadata.csv``.

    Args:
        board_key:
            Board identifier; looked up in ``BOARD_CANVAS`` first.
        df_token_meta:
            Token metadata, only needed for boards without a calibrated entry.

    Returns:
        Dict with ``extent`` (``[x_min, x_max, y_min, y_max]``), ``figsize``
        and ``image_aspect``. The dict and its ``extent`` list are fresh
        objects, so callers may modify them freely.

    Raises:
        ValueError: If the board is not calibrated and no metadata was given,
            or the metadata holds no hold rows for this board.
    """
    board_key = str(board_key)
    if board_key in BOARD_CANVAS:
        # Copy list values too: a plain dict() copy would still share the
        # mutable ``extent`` list with the module-level constant.
        return {
            name: list(setting) if isinstance(setting, list) else setting
            for name, setting in BOARD_CANVAS[board_key].items()
        }

    if df_token_meta is None:
        raise ValueError(f"No board canvas settings for board_key={board_key!r}.")

    holds = _board_holds(df_token_meta, board_key)
    x_min, x_max = float(holds["x"].min()), float(holds["x"].max())
    y_min, y_max = float(holds["y"].min()), float(holds["y"].max())
    # Pad by 6% of the span, but never by less than one coordinate unit.
    x_pad = max((x_max - x_min) * 0.06, 1.0)
    y_pad = max((y_max - y_min) * 0.06, 1.0)
    return {
        "extent": [x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad],
        "figsize": (8, 10),
        "image_aspect": "auto",
    }
|
||||
|
||||
|
||||
def _board_holds(df_token_meta: pd.DataFrame, board_key: str) -> pd.DataFrame:
    """Return one metadata row per hold placement of ``board_key``.

    Raises:
        ValueError: If the metadata has no ``hold`` rows for this board.
    """
    is_hold = df_token_meta["kind"] == "hold"
    on_board = df_token_meta["board_key"].astype(str) == str(board_key)
    selected = df_token_meta[is_hold & on_board].copy()

    if selected.empty:
        raise ValueError(
            f"No hold metadata found for board_key={board_key!r}. "
            "Check token_metadata.csv and board config."
        )

    # The first row wins when a placement is listed more than once.
    return selected.drop_duplicates(["board_key", "placement_id"]).copy()
|
||||
|
||||
|
||||
def _route_with_coords(
    route_records: pd.DataFrame,
    df_token_meta: pd.DataFrame,
    board_key: str,
) -> pd.DataFrame:
    """Attach board coordinates to parsed route holds.

    Left-joins ``route_records`` with the board's hold metadata on
    ``board_token_prefix`` and ``placement_id``; holds without a metadata
    match keep missing ``x`` / ``y`` values.
    """
    coord_columns = ["board_key", "board_token_prefix", "placement_id", "x", "y"]
    coords = _board_holds(df_token_meta, board_key)[coord_columns]
    coords = coords.drop_duplicates(["board_key", "placement_id"])
    return route_records.merge(
        coords,
        on=["board_token_prefix", "placement_id"],
        how="left",
    )
|
||||
|
||||
|
||||
def visualize_route_tokens(
    tokens: Iterable[str],
    df_token_meta: pd.DataFrame,
    board_key: str,
    title: str | None = None,
    subtitle: str | None = None,
    output_path: str | Path | None = None,
    annotate: bool = False,
    show_all_holds: bool | None = None,
    background_image: str | Path | None = None,
    figsize: tuple[float, float] | None = None,
    dpi: int = 160,
):
    """Draw a generated route on top of its board.

    When ``background_image`` points to an existing file it is drawn using the
    board's canvas extent. Without a usable image the plot falls back to a
    plain coordinate grid and, unless ``show_all_holds`` says otherwise, shows
    every available hold as a faint dot.

    Returns:
        ``(fig, ax, route_df)`` where ``route_df`` holds the route's holds
        that could be matched to board coordinates.

    Raises:
        ValueError: If the sequence has no hold tokens, or none of them match
            the board metadata.
    """
    parsed_holds = tokens_to_route_records(tokens)
    if parsed_holds.empty:
        raise ValueError("No hold tokens found in generated sequence.")

    all_board_holds = _board_holds(df_token_meta, board_key)
    located = _route_with_coords(parsed_holds, df_token_meta, board_key)
    route_df = located.dropna(subset=["x", "y"]).copy()
    if route_df.empty:
        raise ValueError(
            "Generated route contained hold tokens, but none matched the board metadata."
        )

    settings = board_canvas_settings(board_key, df_token_meta)
    extent = [float(bound) for bound in settings["extent"]]
    left, right, bottom, top = extent
    image_aspect = str(settings.get("image_aspect", "auto"))
    figsize = figsize or settings.get("figsize", (8, 10))

    background_exists = background_image is not None and Path(background_image).exists()
    if show_all_holds is None:
        # Without a board photo the faint hold dots are the only context.
        show_all_holds = not background_exists

    fig, ax = plt.subplots(figsize=figsize)

    if background_exists:
        board_image = plt.imread(Path(background_image))
        ax.imshow(board_image, extent=extent, aspect=image_aspect, alpha=1.0, zorder=0)

    if show_all_holds:
        ax.scatter(
            all_board_holds["x"],
            all_board_holds["y"],
            s=22,
            c="#d1d5db",
            alpha=0.45,
            linewidths=0,
            label="available holds",
            zorder=1,
        )

    # One scatter call per role gives each role its own legend entry.
    for role, role_holds in route_df.groupby("role", sort=False):
        ax.scatter(
            role_holds["x"],
            role_holds["y"],
            s=ROLE_SIZES.get(role, 150),
            c=ROLE_COLORS.get(role, ROLE_COLORS["unknown"]),
            marker=ROLE_MARKERS.get(role, "o"),
            edgecolors="#111827",
            linewidths=1.0,
            alpha=0.96,
            label=role,
            zorder=3,
        )

    if annotate:
        label_box = dict(
            boxstyle="circle,pad=0.12",
            alpha=0.45,
            facecolor="#111827",
            edgecolor="white",
            linewidth=0.8,
        )
        for x_pos, y_pos, placement in zip(route_df["x"], route_df["y"], route_df["placement_id"]):
            ax.text(
                x_pos,
                y_pos,
                str(int(placement)),
                ha="center",
                va="center",
                fontsize=7,
                fontweight="bold",
                color="white",
                bbox=dict(label_box),
                zorder=4,
            )

    ax.set_xlim(left, right)
    ax.set_ylim(bottom, top)
    if image_aspect == "equal":
        ax.set_aspect("equal", adjustable="box")
    ax.set_xlabel("X Position")
    ax.set_ylabel("Y Position")

    # Title and subtitle live on the figure rather than the axes so they do
    # not compete for the same narrow strip above the plot.
    has_header = bool(title or subtitle)
    if title:
        fig.suptitle(title, fontsize=14, fontweight="bold", y=0.985)
    if subtitle:
        fig.text(0.5, 0.958, subtitle, ha="center", va="top", fontsize=9, color="#4b5563")

    if not background_exists:
        ax.grid(True, alpha=0.18)
    else:
        ax.grid(False)
    ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)

    # Leave head-room for the figure-level header when there is one.
    if has_header:
        fig.tight_layout(rect=[0, 0, 1, 0.925])
    else:
        fig.tight_layout()

    if output_path is not None:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=dpi, bbox_inches="tight")

    return fig, ax, route_df
|
||||
|
||||
|
||||
def visualize_route_result(
    result: dict[str, object],
    df_token_meta: pd.DataFrame,
    output_path: str | Path | None = None,
    annotate: bool = False,
    background_image: str | Path | None = None,
):
    """Plot a result dict as produced by ``generate_route``.

    The title names the board plus the requested grade and angle. The
    subtitle lists validity, hold count, the predicted grade and critic error
    when those keys are present, and the sampling temperature. Returns
    whatever ``visualize_route_tokens`` returns.
    """
    board_key = str(result["board_key"])
    tokens = parse_tokens(result["tokens"])

    display_name = str(result.get("board_display_name", board_key))
    requested_v = int(result["requested_grouped_v"])
    requested_angle = int(result["requested_angle"])
    title = f"{display_name} generated V{requested_v} @ {requested_angle}°"

    details = [
        f"valid={result.get('basic_valid')}",
        f"holds={result.get('n_hold_tokens')}",
    ]
    if "predicted_grouped_v" in result:
        predicted_v = int(result["predicted_grouped_v"])
        predicted_raw = float(result["predicted_display_difficulty"])
        details.append(f"predicted V{predicted_v} ({predicted_raw:.2f})")
    if "critic_v_error" in result:
        details.append(f"error {int(result['critic_v_error']):+d}V")
    details.append(f"temperature={result.get('temperature')}")

    return visualize_route_tokens(
        tokens=tokens,
        df_token_meta=df_token_meta,
        board_key=board_key,
        title=title,
        subtitle=" | ".join(details),
        output_path=output_path,
        annotate=annotate,
        background_image=background_image,
    )
|
||||
Reference in New Issue
Block a user