Next version. Models + scripts updated. 2

This commit is contained in:
Pawel
2026-05-21 22:21:26 -04:00
parent 0002ef1545
commit 86d582a572
23 changed files with 1768 additions and 293 deletions

View File

@@ -42,17 +42,19 @@ def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
return rows
def validity_from_records(records: list[dict[str, object]]) -> dict[str, object]:
def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:
placements = [int(record["placement_id"]) for record in records]
roles = [str(record["role"]) for record in records]
prefixes = [str(record["board_token_prefix"]) for record in records]
one_board_only = len(set(prefixes)) <= 1
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
out = {
"n_holds_eval": len(records),
"n_unique_placements_eval": len(set(placements)),
"has_duplicate_placements_eval": len(records) != len(set(placements)),
"one_board_only_eval": one_board_only,
"matches_requested_board_eval": matches_requested_board,
"n_start_eval": roles.count("start"),
"n_middle_eval": roles.count("middle"),
"n_foot_eval": roles.count("foot"),

View File

@@ -77,13 +77,14 @@ def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
return rows
def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:
records = hold_records(tokens)
placements = [record["placement_id"] for record in records]
roles = [record["role"] for record in records]
prefixes = [record["board_prefix"] for record in records]
one_board_only = len(set(prefixes)) <= 1
matches_requested_board = requested_board_prefix is None or all(prefix == requested_board_prefix for prefix in prefixes)
no_duplicates = len(placements) == len(set(placements))
has_start = "start" in roles
has_finish = "finish" in roles
@@ -94,6 +95,7 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
"n_unique_placements": len(set(placements)),
"has_duplicate_placements": not no_duplicates,
"one_board_only": one_board_only,
"matches_requested_board": matches_requested_board,
"has_start": has_start,
"has_middle": "middle" in roles,
"has_finish": has_finish,
@@ -101,14 +103,16 @@ def validity_summary(tokens: Iterable[str]) -> dict[str, object]:
"n_middle": roles.count("middle"),
"n_foot": roles.count("foot"),
"n_finish": roles.count("finish"),
"basic_valid": bool(one_board_only and no_duplicates and has_start and has_finish and enough_holds),
"basic_valid": bool(one_board_only and matches_requested_board and no_duplicates and has_start and has_finish and enough_holds),
}
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int]) -> str:
def generated_tokens_to_frames(tokens: Iterable[str], role_name_to_id: dict[str, int], board_prefix: str | None = None) -> str:
pieces = []
seen = set()
for record in hold_records(tokens):
if board_prefix is not None and str(record["board_prefix"]) != board_prefix:
continue
placement_id = int(record["placement_id"])
role = str(record["role"])
if placement_id in seen or role not in role_name_to_id:
@@ -154,7 +158,7 @@ def generate_one(
forbidden_ids=forbidden_ids,
)
tokens = [itos.get(int(idx), "<UNK>") for idx in token_ids]
validity = validity_summary(tokens)
validity = validity_summary(tokens, requested_board_prefix=board_prefix)
return {
"requested_board_prefix": board_prefix,
@@ -164,6 +168,6 @@ def generate_one(
"top_k": None if top_k is None else int(top_k),
"tokens": tokens,
"sequence": " ".join(tokens),
"frames": generated_tokens_to_frames(tokens, role_name_to_id),
"frames": generated_tokens_to_frames(tokens, role_name_to_id, board_prefix=board_prefix),
**validity,
}

View File

@@ -0,0 +1,335 @@
"""Inference helpers for ClimbingBoardGPT demos.
This module is intentionally small and webapp-friendly: it loads trained
checkpoints once, keeps them in memory, and exposes route generation helpers.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import torch
from .config import BoardConfig, load_board_config
from .generation import generate_one
from .grades import to_grouped_v
from .models import JointRouteGPT, JointRouteTransformerRegressor
from .tokenization import (
angle_token,
board_token,
canonicalize_holds,
hold_token,
parse_frames,
)
@dataclass
class LoadedGenerator:
"""A loaded GPT-style route generator plus vocabulary metadata."""
model: JointRouteGPT
stoi: dict[str, int]
itos: dict[int, str]
device: torch.device
checkpoint_path: Path
@dataclass
class LoadedGradePredictor:
"""A loaded transformer grade predictor plus vocabulary metadata."""
model: JointRouteTransformerRegressor
stoi: dict[str, int]
itos: dict[int, str]
device: torch.device
checkpoint_path: Path
max_len: int
pad_id: int
unk_id: int
def load_grade_predictor(
checkpoint_path: str | Path,
device: str | torch.device | None = None,
torch_threads: int | None = None,
) -> LoadedGradePredictor:
"""Load a trained joint grade-prediction checkpoint.
Args:
checkpoint_path:
Path to ``models/joint_transformer_grade_predictor.pth``.
device:
``"cpu"``, ``"cuda"``, or None for auto-detection.
torch_threads:
Optional CPU thread cap for small VPS demos.
Returns:
LoadedGradePredictor containing the PyTorch model and tokenizer maps.
"""
if torch_threads is not None:
torch.set_num_threads(int(torch_threads))
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Could not find grade predictor checkpoint: {checkpoint_path}")
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
try:
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
cfg = checkpoint["config"]
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}
coord_features = checkpoint["coord_features"]
if not isinstance(coord_features, torch.Tensor):
coord_features = torch.tensor(coord_features, dtype=torch.float32)
model = JointRouteTransformerRegressor(
vocab_size=cfg["vocab_size"],
max_len=cfg["max_len"],
coord_features=coord_features,
d_model=cfg.get("d_model", 128),
nhead=cfg.get("nhead", 4),
num_layers=cfg.get("num_layers", 4),
dim_feedforward=cfg.get("dim_feedforward", 256),
dropout=cfg.get("dropout", 0.10),
pad_id=cfg.get("pad_id", stoi["<PAD>"]),
).to(resolved_device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return LoadedGradePredictor(
model=model,
stoi=stoi,
itos=itos,
device=resolved_device,
checkpoint_path=checkpoint_path,
max_len=int(cfg["max_len"]),
pad_id=int(cfg.get("pad_id", stoi["<PAD>"])),
unk_id=int(stoi["<UNK>"]),
)
def predict_route_grade(
grade_predictor: LoadedGradePredictor,
tokens: list[str],
) -> dict[str, object]:
"""Predict the grade of a route-token sequence.
The grade token is removed before scoring, because the predictor should
infer the grade from the board, angle, and hold-role tokens rather than
reading the requested grade.
"""
model_tokens = [token for token in tokens if not str(token).startswith("<GRADE_")]
if model_tokens and model_tokens[0] == "<BOS>":
model_tokens = ["<CLS>"] + model_tokens[1:]
else:
model_tokens = ["<CLS>"] + model_tokens
ids = [grade_predictor.stoi.get(token, grade_predictor.unk_id) for token in model_tokens]
ids = ids[: grade_predictor.max_len]
mask = [1] * len(ids)
if len(ids) < grade_predictor.max_len:
pad_n = grade_predictor.max_len - len(ids)
ids += [grade_predictor.pad_id] * pad_n
mask += [0] * pad_n
with torch.no_grad():
input_ids = torch.tensor([ids], dtype=torch.long, device=grade_predictor.device)
attention_mask = torch.tensor([mask], dtype=torch.bool, device=grade_predictor.device)
pred_display_difficulty = float(grade_predictor.model(input_ids, attention_mask).cpu().item())
return {
"predicted_display_difficulty": pred_display_difficulty,
"predicted_grouped_v": int(to_grouped_v(pred_display_difficulty)),
}
def load_route_generator(
checkpoint_path: str | Path,
device: str | torch.device | None = None,
torch_threads: int | None = None,
) -> LoadedGenerator:
"""Load a trained joint route generator checkpoint.
Args:
checkpoint_path:
Path to ``models/joint_route_gpt_generator.pth``.
device:
``"cpu"``, ``"cuda"``, or None for auto-detection.
torch_threads:
Optional CPU thread cap for small VPS demos.
Returns:
LoadedGenerator containing the PyTorch model and tokenizer maps.
"""
if torch_threads is not None:
torch.set_num_threads(int(torch_threads))
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Could not find generator checkpoint: {checkpoint_path}")
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
try:
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
except TypeError:
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
cfg = checkpoint["config"]
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
itos = {int(k): str(v) for k, v in checkpoint["itos"].items()}
model = JointRouteGPT(
vocab_size=cfg["vocab_size"],
block_size=cfg["block_size"],
n_embd=cfg.get("n_embd", 128),
n_head=cfg.get("n_head", 4),
n_layer=cfg.get("n_layer", 4),
dropout=cfg.get("dropout", 0.10),
pad_id=cfg.get("pad_id", stoi["<PAD>"]),
).to(resolved_device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return LoadedGenerator(
model=model,
stoi=stoi,
itos=itos,
device=resolved_device,
checkpoint_path=checkpoint_path,
)
def generate_route(
generator: LoadedGenerator,
board_config: BoardConfig,
angle: int,
grade: int,
temperature: float = 0.9,
top_k: int | None = 50,
max_new_tokens: int = 40,
) -> dict[str, object]:
"""Generate a single route for a board/angle/grade condition."""
return {
"board_key": board_config.board_key,
"board_display_name": board_config.display_name,
**generate_one(
model=generator.model,
stoi=generator.stoi,
itos=generator.itos,
device=generator.device,
board_prefix=board_config.token_prefix,
angle=int(angle),
grouped_v=int(grade),
role_name_to_id=board_config.role_definitions,
temperature=float(temperature),
top_k=top_k,
max_new_tokens=int(max_new_tokens),
),
}
def load_board_for_demo(board: str, config_dir: str | Path | None = None) -> BoardConfig:
"""Load a board config by key, with a clearer demo error message."""
try:
return load_board_config(board, config_dir=config_dir)
except FileNotFoundError as exc:
raise FileNotFoundError(
f"Unknown board '{board}'. Expected one of the JSON configs in configs/."
) from exc
def build_placement_lookup_from_token_metadata(df_token_meta) -> dict[tuple[str, int], dict]:
"""Build the placement lookup expected by tokenization helpers.
The training-time tokenization code canonicalizes holds using a lookup keyed
by ``(board_key, placement_id)``. At inference/demo time, we usually have
``token_metadata.csv`` rather than the raw database, so this reconstructs
the necessary coordinate lookup from token metadata.
"""
hold_meta = df_token_meta[df_token_meta["kind"] == "hold"].dropna(subset=["placement_id"]).copy()
lookup: dict[tuple[str, int], dict] = {}
for _, row in hold_meta.drop_duplicates(["board_key", "placement_id"]).iterrows():
key = (str(row["board_key"]), int(row["placement_id"]))
lookup[key] = {
"board_key": str(row["board_key"]),
"board_token_prefix": str(row["board_token_prefix"]),
"placement_id": int(row["placement_id"]),
"x": float(row["x"]),
"y": float(row["y"]),
"x_norm": float(row.get("x_norm", 0.0)),
"y_norm": float(row.get("y_norm", 0.0)),
}
return lookup
def frames_to_grade_model_tokens(
frames: str,
angle: int,
board_config: BoardConfig,
df_token_meta,
) -> list[str]:
"""Convert a user-provided frames string into grade-predictor tokens.
Output format matches training for the grade predictor:
``<CLS> <BOARD_...> <ANGLE_...> <BOARDPREFIX_p..._role> ... <EOS>``
The route is canonicalized using the same role/y/x ordering used during
tokenization. No grade token is included.
"""
placement_lookup = build_placement_lookup_from_token_metadata(df_token_meta)
holds = parse_frames(frames)
holds = canonicalize_holds(holds, board_config, placement_lookup)
tokens = [
"<CLS>",
board_token(board_config),
angle_token(angle),
]
tokens.extend(
hold_token(placement_id, role_id, board_config)
for placement_id, role_id in holds
)
tokens.append("<EOS>")
return tokens
def predict_frames_grade(
grade_predictor: LoadedGradePredictor,
frames: str,
angle: int,
board_config: BoardConfig,
df_token_meta,
) -> dict[str, object]:
"""Predict grade from board, angle, and a BoardLib frames string."""
tokens = frames_to_grade_model_tokens(
frames=frames,
angle=angle,
board_config=board_config,
df_token_meta=df_token_meta,
)
# predict_route_grade accepts either <BOS>-style generated tokens or
# already-prepared <CLS>-style model tokens. It will leave the leading
# <CLS> intact through the fallback branch.
pred = predict_route_grade(grade_predictor, tokens)
return {
**pred,
"tokens": tokens,
"sequence": " ".join(tokens),
"board_key": board_config.board_key,
"board_display_name": board_config.display_name,
"requested_angle": int(angle),
"frames": frames,
}

View File

@@ -41,7 +41,11 @@ class JointRouteTransformerRegressor(nn.Module):
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
self.encoder = nn.TransformerEncoder(
encoder_layer,
num_layers=num_layers,
enable_nested_tensor=False,
)
self.norm = nn.LayerNorm(d_model)
self.head = nn.Sequential(
nn.Linear(d_model, d_model),
@@ -96,7 +100,11 @@ class JointRouteGPT(nn.Module):
batch_first=True,
norm_first=True,
)
self.blocks = nn.TransformerEncoder(layer, num_layers=n_layer)
self.blocks = nn.TransformerEncoder(
layer,
num_layers=n_layer,
enable_nested_tensor=False,
)
self.ln_f = nn.LayerNorm(n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
self.lm_head.weight = self.token_emb.weight

View File

@@ -75,3 +75,56 @@ def safe_train_test_split(
random_state=random_state,
stratify=None,
)
def assign_group_splits(
df: pd.DataFrame,
group_cols: list[str],
test_size: float,
val_size_within_temp: float,
random_state: int,
stratify_col: str | None = None,
) -> pd.Series:
"""Assign train/val/test splits at group level.
This prevents multiple rows for the same logical climb, for example the
same UUID at several angles, from being distributed across different
splits. The returned Series is indexed like ``df`` and contains
``train``, ``val``, or ``test``.
"""
group_df = df[group_cols + ([stratify_col] if stratify_col else [])].copy()
group_df["__row_index"] = range(len(group_df))
group_df = group_df.drop_duplicates(group_cols).reset_index(drop=True)
train_groups, temp_groups = safe_train_test_split(
group_df,
test_size=test_size,
random_state=random_state,
stratify_col=stratify_col,
)
val_groups, test_groups = safe_train_test_split(
temp_groups,
test_size=val_size_within_temp,
random_state=random_state,
stratify_col=stratify_col,
)
def key_frame(frame: pd.DataFrame) -> set[tuple]:
return set(map(tuple, frame[group_cols].astype(str).values.tolist()))
train_keys = key_frame(train_groups)
val_keys = key_frame(val_groups)
test_keys = key_frame(test_groups)
def split_for_row(row) -> str:
key = tuple(str(row[col]) for col in group_cols)
if key in train_keys:
return "train"
if key in val_keys:
return "val"
if key in test_keys:
return "test"
raise KeyError(f"Could not assign split for group key {key}")
return df.apply(split_for_row, axis=1)

View File

@@ -0,0 +1,353 @@
"""Visualization utilities for generated ClimbingBoardGPT routes.
The route-overlay functions here deliberately mimic the old TB2/Kilter
notebook convention: draw the board composite image with the product-size
coordinate extent, then scatter route holds in board coordinates.
"""
from __future__ import annotations
import ast
import re
from pathlib import Path
from typing import Iterable
import matplotlib.pyplot as plt
import pandas as pd
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
# These are the same coordinate windows used in the earlier visualization
# notebooks. They come from the product size geometry rather than from the
# min/max of the actual holds. The hold coordinates are inset by about 4in,
# so using hold min/max directly shifts/stretches the background image.
BOARD_CANVAS = {
"tb2": {
"extent": [-68, 68, 0, 144],
"figsize": (16, 14),
"image_aspect": "auto",
},
"kilter": {
"extent": [-24, 168, 0, 156],
"figsize": (17, 12),
"image_aspect": "equal",
},
}
ROLE_COLORS = {
"start": "#2ecc71",
"middle": "#3498db",
"finish": "#e74c3c",
"foot": "#f1c40f",
"unknown": "#9ca3af",
}
ROLE_MARKERS = {
"start": "o",
"middle": "o",
"finish": "*",
"foot": "s",
"unknown": "o",
}
ROLE_SIZES = {
"start": 150,
"middle": 150,
"finish": 230,
"foot": 95,
"unknown": 150,
}
def parse_tokens(value) -> list[str]:
"""Parse a generated token list from a list, repr string, or sequence string."""
if isinstance(value, list):
return [str(v) for v in value]
if not isinstance(value, str):
return []
try:
parsed = ast.literal_eval(value)
if isinstance(parsed, list):
return [str(v) for v in parsed]
except Exception:
pass
return value.split()
def tokens_to_route_records(tokens: Iterable[str]) -> pd.DataFrame:
"""Extract generated hold records from model tokens."""
rows = []
for token in tokens:
match = HOLD_TOKEN_PATTERN.match(str(token))
if match is None:
continue
rows.append(
{
"token": token,
"board_token_prefix": match.group(1),
"placement_id": int(match.group(2)),
"role": match.group(3),
}
)
return pd.DataFrame(rows)
def load_token_metadata(tokenized_dir: str | Path) -> pd.DataFrame:
"""Load token metadata produced by ``scripts/01_tokenize_routes.py``."""
tokenized_dir = Path(tokenized_dir)
path = tokenized_dir / "token_metadata.csv"
if not path.exists():
raise FileNotFoundError(
f"Could not find {path}. Run scripts/01_tokenize_routes.py first."
)
return pd.read_csv(path)
def board_canvas_settings(board_key: str, df_token_meta: pd.DataFrame | None = None) -> dict[str, object]:
"""Return board canvas settings.
Known boards use hand-calibrated extents from the old notebooks. Unknown
boards fall back to coordinate bounds from ``token_metadata.csv``.
"""
board_key = str(board_key)
if board_key in BOARD_CANVAS:
return dict(BOARD_CANVAS[board_key])
if df_token_meta is None:
raise ValueError(f"No board canvas settings for board_key={board_key!r}.")
holds = _board_holds(df_token_meta, board_key)
x_min, x_max = float(holds["x"].min()), float(holds["x"].max())
y_min, y_max = float(holds["y"].min()), float(holds["y"].max())
x_pad = max((x_max - x_min) * 0.06, 1.0)
y_pad = max((y_max - y_min) * 0.06, 1.0)
return {
"extent": [x_min - x_pad, x_max + x_pad, y_min - y_pad, y_max + y_pad],
"figsize": (8, 10),
"image_aspect": "auto",
}
def _board_holds(df_token_meta: pd.DataFrame, board_key: str) -> pd.DataFrame:
holds = df_token_meta[
(df_token_meta["kind"] == "hold")
& (df_token_meta["board_key"].astype(str) == str(board_key))
].copy()
if holds.empty:
raise ValueError(
f"No hold metadata found for board_key={board_key!r}. "
"Check token_metadata.csv and board config."
)
holds = holds.drop_duplicates(["board_key", "placement_id"]).copy()
return holds
def _route_with_coords(
route_records: pd.DataFrame,
df_token_meta: pd.DataFrame,
board_key: str,
) -> pd.DataFrame:
holds = _board_holds(df_token_meta, board_key)
coords = holds[["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates(
["board_key", "placement_id"]
)
merged = route_records.merge(
coords,
on=["board_token_prefix", "placement_id"],
how="left",
)
return merged
def visualize_route_tokens(
tokens: Iterable[str],
df_token_meta: pd.DataFrame,
board_key: str,
title: str | None = None,
subtitle: str | None = None,
output_path: str | Path | None = None,
annotate: bool = False,
show_all_holds: bool | None = None,
background_image: str | Path | None = None,
figsize: tuple[float, float] | None = None,
dpi: int = 160,
):
"""Visualize a generated route as a board overlay plot.
If a background image is supplied, the plot uses the calibrated canvas
extent from the old project notebooks. If no image is supplied, it falls
back to a clean coordinate-board style and shows available holds.
"""
route_records = tokens_to_route_records(tokens)
if route_records.empty:
raise ValueError("No hold tokens found in generated sequence.")
board_holds = _board_holds(df_token_meta, board_key)
route_df = _route_with_coords(route_records, df_token_meta, board_key)
route_df = route_df.dropna(subset=["x", "y"]).copy()
if route_df.empty:
raise ValueError(
"Generated route contained hold tokens, but none matched the board metadata."
)
canvas = board_canvas_settings(board_key, df_token_meta)
extent = [float(v) for v in canvas["extent"]]
x_min, x_max, y_min, y_max = extent
image_aspect = str(canvas.get("image_aspect", "auto"))
figsize = figsize or canvas.get("figsize", (8, 10))
background_exists = background_image is not None and Path(background_image).exists()
if show_all_holds is None:
show_all_holds = not background_exists
fig, ax = plt.subplots(figsize=figsize)
if background_exists:
img = plt.imread(Path(background_image))
ax.imshow(
img,
extent=extent,
aspect=image_aspect,
alpha=1.0,
zorder=0,
)
if show_all_holds:
ax.scatter(
board_holds["x"],
board_holds["y"],
s=22,
c="#d1d5db",
alpha=0.45,
linewidths=0,
label="available holds",
zorder=1,
)
# Draw route holds role-by-role so the legend is meaningful.
for role, frame in route_df.groupby("role", sort=False):
ax.scatter(
frame["x"],
frame["y"],
s=ROLE_SIZES.get(role, 150),
c=ROLE_COLORS.get(role, ROLE_COLORS["unknown"]),
marker=ROLE_MARKERS.get(role, "o"),
edgecolors="#111827",
linewidths=1.0,
alpha=0.96,
label=role,
zorder=3,
)
if annotate:
for _, row in route_df.iterrows():
ax.text(
row["x"],
row["y"],
str(int(row["placement_id"])),
ha="center",
va="center",
fontsize=7,
fontweight="bold",
color="white",
bbox=dict(
boxstyle="circle,pad=0.12",
alpha=0.45,
facecolor="#111827",
edgecolor="white",
linewidth=0.8,
),
zorder=4,
)
ax.set_xlim(x_min, x_max)
ax.set_ylim(y_min, y_max)
if image_aspect == "equal":
ax.set_aspect("equal", adjustable="box")
ax.set_xlabel("X Position")
ax.set_ylabel("Y Position")
# Put the title and subtitle at the figure level, not the axes level.
# This avoids the old overlap where ax.set_title(...) and ax.text(y=1.01)
# competed for the same narrow top margin.
has_header = bool(title or subtitle)
if title:
fig.suptitle(
title,
fontsize=14,
fontweight="bold",
y=0.985,
)
if subtitle:
fig.text(
0.5,
0.958,
subtitle,
ha="center",
va="top",
fontsize=9,
color="#4b5563",
)
if background_exists:
ax.grid(False)
else:
ax.grid(True, alpha=0.18)
ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0), frameon=False)
# Reserve top space for the figure-level title/subtitle.
if has_header:
fig.tight_layout(rect=[0, 0, 1, 0.925])
else:
fig.tight_layout()
if output_path is not None:
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
return fig, ax, route_df
def visualize_route_result(
result: dict[str, object],
df_token_meta: pd.DataFrame,
output_path: str | Path | None = None,
annotate: bool = False,
background_image: str | Path | None = None,
):
"""Visualize a result dictionary returned by ``generate_route``."""
board_key = str(result["board_key"])
tokens = parse_tokens(result["tokens"])
title = (
f"{str(result.get('board_display_name', board_key))} "
f"generated V{int(result['requested_grouped_v'])} @ {int(result['requested_angle'])}°"
)
subtitle_parts = [
f"valid={result.get('basic_valid')}",
f"holds={result.get('n_hold_tokens')}",
]
if "predicted_grouped_v" in result:
subtitle_parts.append(
f"predicted V{int(result['predicted_grouped_v'])}"
f" ({float(result['predicted_display_difficulty']):.2f})"
)
if "critic_v_error" in result:
subtitle_parts.append(f"error {int(result['critic_v_error']):+d}V")
subtitle_parts.append(f"temperature={result.get('temperature')}")
subtitle = " | ".join(subtitle_parts)
return visualize_route_tokens(
tokens=tokens,
df_token_meta=df_token_meta,
board_key=board_key,
title=title,
subtitle=subtitle,
output_path=output_path,
annotate=annotate,
background_image=background_image,
)