Add web demo polish and smoke-test pipeline
This commit is contained in:
33
src/climbingboardgpt/checkpoints.py
Normal file
33
src/climbingboardgpt/checkpoints.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def load_checkpoint(
    checkpoint_path: str | Path,
    map_location: str | torch.device,
    *,
    trusted: bool = False,
) -> dict[str, Any]:
    """Load a PyTorch checkpoint, preferring safer weights-only loading.

    The checkpoint is first loaded with ``weights_only=True``. Every fallback
    to the unrestricted (pickle-based) loader is gated on ``trusted``, because
    unpickling can execute code embedded in the file. That includes older
    PyTorch versions that reject the ``weights_only`` keyword: untrusted
    checkpoints are refused there instead of being silently unpickled.

    Args:
        checkpoint_path: Location of the checkpoint file.
        map_location: Device (or device string) forwarded to ``torch.load``.
        trusted: Set to ``True`` only for checkpoints produced by this
            project or an otherwise trusted source.

    Returns:
        The object stored in the checkpoint, as returned by ``torch.load``.

    Raises:
        RuntimeError: If weights-only loading fails and ``trusted`` is false.
    """
    checkpoint_path = Path(checkpoint_path)

    try:
        return torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    except Exception as exc:
        # NOTE(security): everything below this point uses the unrestricted
        # pickle loader, so it must never run for an untrusted checkpoint.
        if not trusted:
            raise RuntimeError(
                "Could not load checkpoint with weights_only=True. "
                "Only retry with trusted=True for checkpoints from a trusted source."
            ) from exc

    try:
        return torch.load(checkpoint_path, map_location=map_location, weights_only=False)
    except TypeError:
        # Older PyTorch versions do not accept the ``weights_only`` keyword;
        # use the legacy call signature for compatibility.
        return torch.load(checkpoint_path, map_location=map_location)
|
||||
@@ -135,12 +135,14 @@ def build_placements_query(config: BoardConfig) -> tuple[str, list]:
|
||||
def load_board_data(
|
||||
config: BoardConfig,
|
||||
project_root: str | Path | None = None,
|
||||
max_climbs: int | None = None,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Load climbs and placements data for a single board.
|
||||
|
||||
Args:
|
||||
config: Board configuration
|
||||
project_root: Path to project root (for resolving db_path)
|
||||
max_climbs: Optional row limit for fast smoke-test loads.
|
||||
|
||||
Returns:
|
||||
Tuple of (climbs DataFrame, placements DataFrame)
|
||||
@@ -154,6 +156,11 @@ def load_board_data(
|
||||
|
||||
climbs_query, climbs_params = build_climbs_query(config)
|
||||
placements_query, placements_params = build_placements_query(config)
|
||||
if max_climbs is not None:
|
||||
if max_climbs < 1:
|
||||
raise ValueError("max_climbs must be at least 1.")
|
||||
climbs_query = f"{climbs_query}\nORDER BY c.uuid, cs.angle\nLIMIT ?"
|
||||
climbs_params = [*climbs_params, int(max_climbs)]
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
df_climbs = pd.read_sql_query(climbs_query, conn, params=climbs_params)
|
||||
@@ -174,6 +181,7 @@ def load_board_data(
|
||||
def load_multi_board_data(
|
||||
configs: list[BoardConfig],
|
||||
project_root: str | Path | None = None,
|
||||
max_climbs_per_board: int | None = None,
|
||||
) -> tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""Load and concatenate data from multiple boards.
|
||||
|
||||
@@ -184,6 +192,7 @@ def load_multi_board_data(
|
||||
Args:
|
||||
configs: List of board configurations
|
||||
project_root: Path to project root
|
||||
max_climbs_per_board: Optional row limit per board for smoke tests.
|
||||
|
||||
Returns:
|
||||
Tuple of (combined climbs DataFrame, combined placements DataFrame)
|
||||
@@ -192,11 +201,15 @@ def load_multi_board_data(
|
||||
placement_frames = []
|
||||
|
||||
for config in configs:
|
||||
climbs, placements = load_board_data(config, project_root=project_root)
|
||||
climbs, placements = load_board_data(
|
||||
config,
|
||||
project_root=project_root,
|
||||
max_climbs=max_climbs_per_board,
|
||||
)
|
||||
climb_frames.append(climbs)
|
||||
placement_frames.append(placements)
|
||||
|
||||
return (
|
||||
pd.concat(climb_frames, ignore_index=True),
|
||||
pd.concat(placement_frames, ignore_index=True),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from torch.utils.data import Dataset
|
||||
|
||||
class RouteGradeDataset(Dataset):
|
||||
def __init__(self, df, max_len: int, pad_id: int):
|
||||
self.row_ids = df["row_id"].tolist() if "row_id" in df.columns else df.index.tolist()
|
||||
self.ids = df["model_ids"].tolist()
|
||||
self.targets = df["display_difficulty"].astype(float).values
|
||||
self.uuids = df["uuid"].tolist()
|
||||
@@ -28,6 +29,7 @@ class RouteGradeDataset(Dataset):
|
||||
"input_ids": torch.tensor(ids, dtype=torch.long),
|
||||
"attention_mask": torch.tensor(mask, dtype=torch.bool),
|
||||
"target": torch.tensor(self.targets[idx], dtype=torch.float32),
|
||||
"row_id": int(self.row_ids[idx]),
|
||||
"uuid": self.uuids[idx],
|
||||
"board_key": self.boards[idx],
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from typing import Iterable
|
||||
|
||||
@@ -8,38 +7,11 @@ import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.spatial.distance import pdist
|
||||
|
||||
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
from .tokenization import parse_tokens, tokens_to_hold_records
|
||||
|
||||
|
||||
def parse_token_list(value) -> list[str]:
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
return []
|
||||
try:
|
||||
parsed = ast.literal_eval(value)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
except Exception:
|
||||
pass
|
||||
return value.split()
|
||||
|
||||
|
||||
def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
rows = []
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_PATTERN.match(token)
|
||||
if match is None:
|
||||
continue
|
||||
rows.append(
|
||||
{
|
||||
"token": token,
|
||||
"board_token_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
}
|
||||
)
|
||||
return rows
|
||||
return parse_tokens(value)
|
||||
|
||||
|
||||
def validity_from_records(records: list[dict[str, object]], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
@@ -102,30 +74,30 @@ def nearest_real_route_same_board(
|
||||
real_df: pd.DataFrame,
|
||||
) -> dict[str, object]:
|
||||
board_frame = real_df[real_df["board_key"] == generated_board_key]
|
||||
best = {
|
||||
"nearest_real_jaccard": -1.0,
|
||||
"nearest_real_uuid": None,
|
||||
"nearest_real_name": None,
|
||||
"nearest_real_grouped_v": None,
|
||||
"nearest_real_angle": None,
|
||||
if board_frame.empty:
|
||||
return {
|
||||
"nearest_real_jaccard": np.nan,
|
||||
"nearest_real_uuid": None,
|
||||
"nearest_real_name": None,
|
||||
"nearest_real_grouped_v": None,
|
||||
"nearest_real_angle": None,
|
||||
"novelty_distance": np.nan,
|
||||
}
|
||||
|
||||
similarities = board_frame["hold_set"].map(lambda hold_set: jaccard(generated_set, hold_set))
|
||||
best_idx = similarities.idxmax()
|
||||
row = board_frame.loc[best_idx]
|
||||
|
||||
nearest_real_jaccard = float(similarities.loc[best_idx])
|
||||
return {
|
||||
"nearest_real_jaccard": nearest_real_jaccard,
|
||||
"nearest_real_uuid": row["uuid"],
|
||||
"nearest_real_name": row["climb_name"],
|
||||
"nearest_real_grouped_v": row["grouped_v"],
|
||||
"nearest_real_angle": row["angle"],
|
||||
"novelty_distance": 1.0 - nearest_real_jaccard,
|
||||
}
|
||||
|
||||
for _, row in board_frame.iterrows():
|
||||
similarity = jaccard(generated_set, row["hold_set"])
|
||||
if similarity > best["nearest_real_jaccard"]:
|
||||
best.update(
|
||||
{
|
||||
"nearest_real_jaccard": similarity,
|
||||
"nearest_real_uuid": row["uuid"],
|
||||
"nearest_real_name": row["climb_name"],
|
||||
"nearest_real_grouped_v": row["grouped_v"],
|
||||
"nearest_real_angle": row["angle"],
|
||||
}
|
||||
)
|
||||
|
||||
best["novelty_distance"] = 1.0 - float(best["nearest_real_jaccard"])
|
||||
return best
|
||||
|
||||
|
||||
def build_placement_coords(df_token_meta: pd.DataFrame) -> dict[tuple[str, int], dict[str, float]]:
|
||||
hold_meta = df_token_meta[df_token_meta["kind"] == "hold"].dropna(subset=["placement_id"]).copy()
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Iterable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
from .tokenization import tokens_to_hold_records
|
||||
|
||||
|
||||
def top_k_filter(logits: torch.Tensor, k: int | None) -> torch.Tensor:
|
||||
@@ -61,20 +60,7 @@ def prompt_tokens(board_prefix: str, angle: int, grouped_v: int) -> list[str]:
|
||||
|
||||
|
||||
def hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
|
||||
rows = []
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_PATTERN.match(token)
|
||||
if match is None:
|
||||
continue
|
||||
rows.append(
|
||||
{
|
||||
"board_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
"token": token,
|
||||
}
|
||||
)
|
||||
return rows
|
||||
return tokens_to_hold_records(tokens)
|
||||
|
||||
|
||||
def validity_summary(tokens: Iterable[str], requested_board_prefix: str | None = None) -> dict[str, object]:
|
||||
|
||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from .checkpoints import load_checkpoint
|
||||
from .config import BoardConfig, load_board_config
|
||||
from .generation import generate_one
|
||||
from .grades import to_grouped_v
|
||||
@@ -75,10 +76,7 @@ def load_grade_predictor(
|
||||
|
||||
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
|
||||
checkpoint = load_checkpoint(checkpoint_path, map_location=resolved_device, trusted=True)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
@@ -176,10 +174,7 @@ def load_route_generator(
|
||||
|
||||
resolved_device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=resolved_device)
|
||||
checkpoint = load_checkpoint(checkpoint_path, map_location=resolved_device, trusted=True)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
@@ -332,4 +327,3 @@ def predict_frames_grade(
|
||||
"requested_angle": int(angle),
|
||||
"frames": frames,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import ast
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
@@ -39,6 +40,43 @@ def parse_frames(frames_str: str | None) -> list[tuple[int, int]]:
|
||||
return [(int(placement_id), int(role_id)) for placement_id, role_id in matches]
|
||||
|
||||
|
||||
def parse_tokens(value) -> list[str]:
    """Normalise a token container into a list of strings.

    Accepted inputs:
      * a ``list`` - every element is converted with ``str``;
      * a ``str`` holding the repr of a list - evaluated with
        ``ast.literal_eval`` and each element converted with ``str``;
      * any other ``str`` - split on whitespace.

    Anything else yields an empty list.
    """
    if not isinstance(value, str):
        if isinstance(value, list):
            return list(map(str, value))
        return []

    # Best effort: a string that is not a list literal falls through to the
    # whitespace split below.
    try:
        literal = ast.literal_eval(value)
        if isinstance(literal, list):
            return list(map(str, literal))
    except Exception:
        pass

    return value.split()
|
||||
|
||||
|
||||
def tokens_to_hold_records(tokens: Iterable[str]) -> list[dict[str, object]]:
    """Build one record per token that matches ``HOLD_TOKEN_PATTERN``.

    Each token is converted with ``str`` before matching; tokens that do not
    match are skipped. The first capture group is stored under both
    ``board_token_prefix`` and ``board_prefix``, the second is stored as an
    integer ``placement_id`` and the third as ``role``.
    """
    records: list[dict[str, object]] = []
    for raw_token in tokens:
        text = str(raw_token)
        hit = HOLD_TOKEN_PATTERN.match(text)
        if hit is not None:
            prefix = hit.group(1)
            record: dict[str, object] = {
                "token": text,
                "board_token_prefix": prefix,
                "board_prefix": prefix,
                "placement_id": int(hit.group(2)),
                "role": hit.group(3),
            }
            records.append(record)
    return records
|
||||
|
||||
|
||||
def make_placement_lookup(df_placements: pd.DataFrame) -> dict[tuple[str, int], dict]:
|
||||
rows = {}
|
||||
for _, row in df_placements.iterrows():
|
||||
|
||||
@@ -94,7 +94,6 @@ def assign_group_splits(
|
||||
``train``, ``val``, or ``test``.
|
||||
"""
|
||||
group_df = df[group_cols + ([stratify_col] if stratify_col else [])].copy()
|
||||
group_df["__row_index"] = range(len(group_df))
|
||||
group_df = group_df.drop_duplicates(group_cols).reset_index(drop=True)
|
||||
|
||||
train_groups, temp_groups = safe_train_test_split(
|
||||
|
||||
@@ -6,15 +6,13 @@ coordinate extent, then scatter route holds in board coordinates.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
|
||||
HOLD_TOKEN_PATTERN = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
from .tokenization import parse_tokens, tokens_to_hold_records
|
||||
|
||||
# These are the same coordinate windows used in the earlier visualization
|
||||
# notebooks. They come from the product size geometry rather than from the
|
||||
@@ -58,39 +56,9 @@ ROLE_SIZES = {
|
||||
}
|
||||
|
||||
|
||||
def parse_tokens(value) -> list[str]:
|
||||
"""Parse a generated token list from a list, repr string, or sequence string."""
|
||||
if isinstance(value, list):
|
||||
return [str(v) for v in value]
|
||||
if not isinstance(value, str):
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed = ast.literal_eval(value)
|
||||
if isinstance(parsed, list):
|
||||
return [str(v) for v in parsed]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return value.split()
|
||||
|
||||
|
||||
def tokens_to_route_records(tokens: Iterable[str]) -> pd.DataFrame:
|
||||
"""Extract generated hold records from model tokens."""
|
||||
rows = []
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_PATTERN.match(str(token))
|
||||
if match is None:
|
||||
continue
|
||||
rows.append(
|
||||
{
|
||||
"token": token,
|
||||
"board_token_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
}
|
||||
)
|
||||
return pd.DataFrame(rows)
|
||||
return pd.DataFrame(tokens_to_hold_records(tokens))
|
||||
|
||||
|
||||
def load_token_metadata(tokenized_dir: str | Path) -> pd.DataFrame:
|
||||
|
||||
Reference in New Issue
Block a user