Add web demo polish and smoke-test pipeline

This commit is contained in:
Pawel
2026-05-24 20:00:40 -04:00
parent 2391c80003
commit bbf276d642
22 changed files with 614 additions and 306 deletions

View File

@@ -14,9 +14,9 @@ from __future__ import annotations
import hashlib
import os
import re
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
@@ -39,6 +39,8 @@ from climbingboardgpt.inference import (
predict_frames_grade,
predict_route_grade,
)
from climbingboardgpt.tokenization import tokens_to_hold_records
from climbingboardgpt.utils import json_safe
from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records
@@ -59,7 +61,7 @@ GRADE_MODEL_PATH = MODEL_DIR / "joint_transformer_grade_predictor.pth"
BOARD_IMAGE_PATHS = {
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_compose.png",
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_composite.png",
}
@@ -80,26 +82,6 @@ class PredictRequest(BaseModel):
frames: str = Field(..., min_length=1, max_length=500)
def _json_safe(value: Any) -> Any:
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
if isinstance(value, list):
return [_json_safe(v) for v in value]
if isinstance(value, tuple):
return [_json_safe(v) for v in value]
if hasattr(value, "item"):
try:
return value.item()
except Exception:
pass
try:
if pd.isna(value):
return None
except Exception:
pass
return value
def _board_config(board: str):
try:
return app.state.board_configs[board]
@@ -240,13 +222,10 @@ def _role_limit_validity(tokens: list[str], requested_board_prefix: str) -> dict
"unknown": 0,
}
for token in tokens:
match = HOLD_TOKEN_RE.match(str(token))
if match is None:
continue
board_prefix, _placement_id, role = match.groups()
if str(board_prefix) != str(requested_board_prefix):
for record in tokens_to_hold_records(tokens):
if str(record["board_token_prefix"]) != str(requested_board_prefix):
continue
role = str(record["role"])
counts[role] = counts.get(role, 0) + 1
n_start = int(counts.get("start", 0))
@@ -305,10 +284,6 @@ def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]:
return reasons
HOLD_TOKEN_RE = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
def _angle_key(angle: Any) -> int:
try:
return int(round(float(angle)))
@@ -335,15 +310,12 @@ def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, An
def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]:
holds: list[dict[str, Any]] = []
for token in str(sequence).split():
match = HOLD_TOKEN_RE.match(token)
if match is None:
continue
for record in tokens_to_hold_records(str(sequence).split()):
holds.append(
{
"board_prefix": match.group(1),
"placement_id": int(match.group(2)),
"role": match.group(3),
"board_prefix": str(record["board_token_prefix"]),
"placement_id": int(record["placement_id"]),
"role": str(record["role"]),
}
)
return holds
@@ -497,7 +469,7 @@ def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[st
holds = _tokens_to_holds(board, tokens)
angle = result.get("requested_angle", result.get("angle", 0))
return _json_safe(
return json_safe(
{
**result,
"tokens": tokens,
@@ -517,13 +489,8 @@ def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[st
)
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0")
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
@app.on_event("startup")
def startup() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI):
if TORCH_THREADS_INT is not None:
torch.set_num_threads(TORCH_THREADS_INT)
@@ -553,6 +520,12 @@ def startup() -> None:
device=DEVICE,
torch_threads=TORCH_THREADS_INT,
)
yield
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0", lifespan=lifespan)
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
@app.get("/")
@@ -604,7 +577,7 @@ def boards():
@app.get("/api/board-holds/{board}")
def board_holds(board: str):
config = _board_config(board)
return _json_safe({
return json_safe({
"board_key": board,
"display_name": config.display_name,
"token_prefix": config.token_prefix,
@@ -729,7 +702,7 @@ def predict(req: PredictRequest):
reasons = _invalid_prediction_reasons(validity)
raise HTTPException(
status_code=400,
detail=_json_safe(
detail=json_safe(
{
"message": "Cannot predict grade for an invalid climb.",
"reasons": reasons,