Add web demo polish and smoke-test pipeline
This commit is contained in:
@@ -14,9 +14,9 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -39,6 +39,8 @@ from climbingboardgpt.inference import (
|
||||
predict_frames_grade,
|
||||
predict_route_grade,
|
||||
)
|
||||
from climbingboardgpt.tokenization import tokens_to_hold_records
|
||||
from climbingboardgpt.utils import json_safe
|
||||
from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records
|
||||
|
||||
|
||||
@@ -59,7 +61,7 @@ GRADE_MODEL_PATH = MODEL_DIR / "joint_transformer_grade_predictor.pth"
|
||||
|
||||
BOARD_IMAGE_PATHS = {
|
||||
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
|
||||
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_compose.png",
|
||||
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_composite.png",
|
||||
}
|
||||
|
||||
|
||||
@@ -80,26 +82,6 @@ class PredictRequest(BaseModel):
|
||||
frames: str = Field(..., min_length=1, max_length=500)
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _json_safe(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_json_safe(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_json_safe(v) for v in value]
|
||||
if hasattr(value, "item"):
|
||||
try:
|
||||
return value.item()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if pd.isna(value):
|
||||
return None
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
def _board_config(board: str):
|
||||
try:
|
||||
return app.state.board_configs[board]
|
||||
@@ -240,13 +222,10 @@ def _role_limit_validity(tokens: list[str], requested_board_prefix: str) -> dict
|
||||
"unknown": 0,
|
||||
}
|
||||
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_RE.match(str(token))
|
||||
if match is None:
|
||||
continue
|
||||
board_prefix, _placement_id, role = match.groups()
|
||||
if str(board_prefix) != str(requested_board_prefix):
|
||||
for record in tokens_to_hold_records(tokens):
|
||||
if str(record["board_token_prefix"]) != str(requested_board_prefix):
|
||||
continue
|
||||
role = str(record["role"])
|
||||
counts[role] = counts.get(role, 0) + 1
|
||||
|
||||
n_start = int(counts.get("start", 0))
|
||||
@@ -305,10 +284,6 @@ def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]:
|
||||
return reasons
|
||||
|
||||
|
||||
|
||||
HOLD_TOKEN_RE = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
|
||||
|
||||
def _angle_key(angle: Any) -> int:
|
||||
try:
|
||||
return int(round(float(angle)))
|
||||
@@ -335,15 +310,12 @@ def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, An
|
||||
|
||||
def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]:
|
||||
holds: list[dict[str, Any]] = []
|
||||
for token in str(sequence).split():
|
||||
match = HOLD_TOKEN_RE.match(token)
|
||||
if match is None:
|
||||
continue
|
||||
for record in tokens_to_hold_records(str(sequence).split()):
|
||||
holds.append(
|
||||
{
|
||||
"board_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
"board_prefix": str(record["board_token_prefix"]),
|
||||
"placement_id": int(record["placement_id"]),
|
||||
"role": str(record["role"]),
|
||||
}
|
||||
)
|
||||
return holds
|
||||
@@ -497,7 +469,7 @@ def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[st
|
||||
holds = _tokens_to_holds(board, tokens)
|
||||
angle = result.get("requested_angle", result.get("angle", 0))
|
||||
|
||||
return _json_safe(
|
||||
return json_safe(
|
||||
{
|
||||
**result,
|
||||
"tokens": tokens,
|
||||
@@ -517,13 +489,8 @@ def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[st
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0")
|
||||
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
|
||||
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
if TORCH_THREADS_INT is not None:
|
||||
torch.set_num_threads(TORCH_THREADS_INT)
|
||||
|
||||
@@ -553,6 +520,12 @@ def startup() -> None:
|
||||
device=DEVICE,
|
||||
torch_threads=TORCH_THREADS_INT,
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0", lifespan=lifespan)
|
||||
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
|
||||
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@@ -604,7 +577,7 @@ def boards():
|
||||
@app.get("/api/board-holds/{board}")
|
||||
def board_holds(board: str):
|
||||
config = _board_config(board)
|
||||
return _json_safe({
|
||||
return json_safe({
|
||||
"board_key": board,
|
||||
"display_name": config.display_name,
|
||||
"token_prefix": config.token_prefix,
|
||||
@@ -729,7 +702,7 @@ def predict(req: PredictRequest):
|
||||
reasons = _invalid_prediction_reasons(validity)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=_json_safe(
|
||||
detail=json_safe(
|
||||
{
|
||||
"message": "Cannot predict grade for an invalid climb.",
|
||||
"reasons": reasons,
|
||||
|
||||
Reference in New Issue
Block a user