# Source: ClimbingBoardGPT/webapp/app.py
# Snapshot: 2026-05-24 20:00:40 -04:00 - 715 lines - 24 KiB - Python
"""FastAPI web demo for ClimbingBoardGPT.

Inference-only design:
- model checkpoints and token metadata are loaded once at startup;
- board background images are served as static files;
- each request returns route hold coordinates/roles as JSON;
- the browser draws the overlay as SVG on top of the already-loaded image.

Local run:
    uvicorn webapp.app:app --host 127.0.0.1 --port 8055
"""
from __future__ import annotations
import hashlib
import os
import sys
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT / "src"))
from climbingboardgpt.config import load_board_config
from climbingboardgpt.generation import validity_summary
from climbingboardgpt.inference import (
generate_route,
load_grade_predictor,
load_route_generator,
predict_frames_grade,
predict_route_grade,
)
from climbingboardgpt.tokenization import tokens_to_hold_records
from climbingboardgpt.utils import json_safe
from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records
# --- Runtime configuration (all overridable through CBGPT_* environment variables) ---

# Torch device string: CBGPT_DEVICE wins when set and non-empty; otherwise
# "cuda" if torch reports CUDA as available, else "cpu".
DEVICE = os.getenv("CBGPT_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
# Optional thread count from CBGPT_TORCH_THREADS. TORCH_THREADS_INT is None when
# the variable is unset or empty; a non-integer value raises ValueError at import.
TORCH_THREADS = os.getenv("CBGPT_TORCH_THREADS")
TORCH_THREADS_INT = int(TORCH_THREADS) if TORCH_THREADS else None
# Directory containing the model checkpoints (default: <repo>/models).
MODEL_DIR = Path(os.getenv("CBGPT_MODEL_DIR", REPO_ROOT / "models"))
# Directory handed to load_token_metadata at startup
# (default: <repo>/data/processed/tokenized).
TOKENIZED_DIR = Path(
os.getenv("CBGPT_TOKENIZED_DIR", REPO_ROOT / "data" / "processed" / "tokenized")
)
# CSV of known routes; read at startup for available angles/grades and the
# exact-match known-route lookup (default: <TOKENIZED_DIR>/route_sequences.csv).
KNOWN_ROUTES_PATH = Path(
os.getenv("CBGPT_KNOWN_ROUTES_PATH", TOKENIZED_DIR / "route_sequences.csv")
)
# Checkpoint files; each model is loaded at startup only if its file exists.
GENERATOR_PATH = MODEL_DIR / "joint_route_gpt_generator.pth"
GRADE_MODEL_PATH = MODEL_DIR / "joint_transformer_grade_predictor.pth"
# Background image per board key; the file names are used to build
# /board-images/... URLs.
BOARD_IMAGE_PATHS = {
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_composite.png",
}
class GenerateRequest(BaseModel):
    # Request body for POST /api/generate.
    # Documented with comments rather than a docstring so the generated
    # OpenAPI schema description stays exactly as it was.

    # Board key; only the two supported boards pass validation.
    board: str = Field(default=..., pattern="^(tb2|kilter)$")
    # Wall angle and target grade forwarded to generate_route.
    angle: int = Field(default=40, ge=0, le=80)
    grade: int = Field(default=6, ge=0, le=16)
    # Sampling controls forwarded unchanged to generate_route.
    temperature: float = Field(default=0.9, ge=0.1, le=2.0)
    top_k: int = Field(default=50, ge=1, le=500)
    max_new_tokens: int = Field(default=40, ge=4, le=80)
    # Retry policy: when valid_only is true the handler samples up to
    # max_attempts times; otherwise it samples exactly once.
    valid_only: bool = Field(
        default=True,
        description="Retry generation until a basic-valid climb is sampled.",
    )
    max_attempts: int = Field(
        default=8,
        ge=1,
        le=25,
        description="Maximum attempts when valid_only is true.",
    )
class PredictRequest(BaseModel):
    # Request body for POST /api/predict.
    # Documented with comments rather than a docstring so the generated
    # OpenAPI schema description stays exactly as it was.

    # Board key; only the two supported boards pass validation.
    board: str = Field(default=..., pattern="^(tb2|kilter)$")
    # Wall angle forwarded to predict_frames_grade (required, 0-80).
    angle: int = Field(default=..., ge=0, le=80)
    # Frames string describing the climb; length is capped at 500 characters.
    frames: str = Field(default=..., min_length=1, max_length=500)
def _board_config(board: str):
    """Return the board config loaded at startup for ``board``.

    Raises:
        HTTPException: status 400 when ``board`` is not a key of
            ``app.state.board_configs``.
    """
    loaded_configs = app.state.board_configs
    try:
        config = loaded_configs[board]
    except KeyError as exc:
        raise HTTPException(status_code=400, detail=f"Unknown board: {board}") from exc
    return config
def _require_generator():
    """Return the loaded route generator.

    Raises:
        HTTPException: status 503 when no generator checkpoint was loaded
            at startup.
    """
    generator = app.state.generator
    if generator is not None:
        return generator
    raise HTTPException(
        status_code=503,
        detail=f"Route generator checkpoint not loaded. Expected {GENERATOR_PATH}.",
    )
def _require_grade_predictor():
    """Return the loaded grade predictor.

    Raises:
        HTTPException: status 503 when no grade-predictor checkpoint was
            loaded at startup.
    """
    predictor = app.state.grade_predictor
    if predictor is not None:
        return predictor
    raise HTTPException(
        status_code=503,
        detail=f"Grade predictor checkpoint not loaded. Expected {GRADE_MODEL_PATH}.",
    )
def _tokens_to_holds(board: str, tokens: list[str]) -> list[dict[str, Any]]:
    """Convert route tokens into hold dicts with coordinates for one board.

    The route records derived from ``tokens`` are left-joined against the
    hold rows of the token metadata for ``board``. Records that end up
    without an ``x`` or ``y`` value are dropped.

    Returns:
        A list of dicts with keys ``token``, ``placement_id``, ``role``,
        ``x`` and ``y``; empty when the tokens yield no route records.
    """
    records = tokens_to_route_records(tokens)
    if records.empty:
        return []

    meta = app.state.token_meta
    is_hold = meta["kind"] == "hold"
    on_board = meta["board_key"].astype(str) == str(board)
    coord_columns = ["board_key", "board_token_prefix", "placement_id", "x", "y"]
    # One coordinate row per placement on the requested board.
    coords = meta[is_hold & on_board][coord_columns].drop_duplicates(
        ["board_key", "placement_id"]
    )

    joined = records.merge(
        coords,
        on=["board_token_prefix", "placement_id"],
        how="left",
    )
    plottable = joined.dropna(subset=["x", "y"])

    return [
        {
            "token": str(row["token"]),
            "placement_id": int(row["placement_id"]),
            "role": str(row["role"]),
            "x": float(row["x"]),
            "y": float(row["y"]),
        }
        for _, row in plottable.iterrows()
    ]
def _file_version(path: Path) -> str:
    """Return a cache-busting tag ``"<mtime seconds>-<size bytes>"`` for ``path``.

    Returns the literal ``"missing"`` when the file does not exist.
    """
    try:
        info = path.stat()
    except FileNotFoundError:
        return "missing"
    whole_seconds = int(info.st_mtime)
    return f"{whole_seconds}-{info.st_size}"
def _static_image_url(board: str) -> str:
    """Build the versioned ``/board-images/...`` URL for a board's background.

    Raises ``KeyError`` when ``board`` has no entry in ``BOARD_IMAGE_PATHS``.
    """
    background = BOARD_IMAGE_PATHS[board]
    version = _file_version(background)
    return f"/board-images/{background.name}?board={board}&v={version}"
def _file_info(path: Path) -> dict[str, Any]:
    """Describe a file on disk for the debug endpoint.

    The file is hashed in fixed-size chunks from a single open handle, so
    large images are never held in memory at once, and size/mtime are read
    from that same handle so they describe the bytes that were hashed.

    Args:
        path: File to inspect.

    Returns:
        A dict with ``path``, ``exists``, ``size_bytes``, ``mtime`` and
        ``sha256_16`` (first 16 hex characters of the SHA-256 digest). When
        the file does not exist, ``exists`` is False and the remaining
        values are None.
    """
    chunk_size = 1 << 20  # 1 MiB per read
    digest = hashlib.sha256()
    try:
        with path.open("rb") as handle:
            stat = os.fstat(handle.fileno())
            for chunk in iter(lambda: handle.read(chunk_size), b""):
                digest.update(chunk)
    except FileNotFoundError:
        # Covers both "never existed" and "removed while we were looking".
        return {
            "path": str(path),
            "exists": False,
            "size_bytes": None,
            "mtime": None,
            "sha256_16": None,
        }
    return {
        "path": str(path),
        "exists": True,
        "size_bytes": stat.st_size,
        "mtime": stat.st_mtime,
        "sha256_16": digest.hexdigest()[:16],
    }
def _board_available_holds(board: str) -> list[dict[str, Any]]:
    """List the clickable holds of a board with their coordinates.

    Token-metadata rows for hold-role tokens may lack coordinates even though
    the token exists in the vocabulary. Such rows are filtered out before the
    result is returned, because FastAPI's JSONResponse rejects NaN.

    Returns:
        Dicts with ``placement_id``, ``x`` and ``y``, ordered by ``y``, then
        ``x``, then ``placement_id``.
    """
    meta = app.state.token_meta
    selector = (meta["kind"] == "hold") & (meta["board_key"].astype(str) == str(board))
    wanted_columns = ["board_key", "board_token_prefix", "placement_id", "x", "y"]
    table = meta[selector][wanted_columns].drop_duplicates(["board_key", "placement_id"]).copy()

    # Anything that is not a number becomes NaN and is discarded below.
    for axis in ("x", "y"):
        table[axis] = pd.to_numeric(table[axis], errors="coerce")
    table = table.dropna(subset=["x", "y"])

    ordered = table.sort_values(["y", "x", "placement_id"])
    available = []
    for _, row in ordered.iterrows():
        available.append(
            {
                "placement_id": int(row["placement_id"]),
                "x": float(row["x"]),
                "y": float(row["y"]),
            }
        )
    return available
def _role_limit_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]:
    """Count hold roles and apply the webapp's start/finish caps.

    The lower-level validity check only demands at least one start and one
    finish. BoardLib-style climbs should not carry arbitrarily many of either,
    so this demo additionally allows at most two starts and two finishes.
    Only holds whose board prefix equals ``requested_board_prefix`` count.
    """
    max_per_role = 2
    tally = dict.fromkeys(("start", "middle", "finish", "foot", "unknown"), 0)
    wanted_prefix = str(requested_board_prefix)

    for record in tokens_to_hold_records(tokens):
        if str(record["board_token_prefix"]) == wanted_prefix:
            role = str(record["role"])
            tally[role] = tally.get(role, 0) + 1

    starts = int(tally.get("start", 0))
    finishes = int(tally.get("finish", 0))
    starts_ok = starts <= max_per_role
    finishes_ok = finishes <= max_per_role
    return {
        "role_counts": tally,
        "n_start_holds": starts,
        "n_finish_holds": finishes,
        "has_at_most_two_starts": starts_ok,
        "has_at_most_two_finishes": finishes_ok,
        "role_limits_valid": starts_ok and finishes_ok,
    }
def _combined_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]:
    """Merge the library validity summary with the webapp role-limit checks.

    The returned dict contains every key of both summaries. ``basic_valid``
    is overwritten with the combined verdict (also exposed as
    ``webapp_basic_valid``); the library's own verdict is preserved under
    ``base_basic_valid``.
    """
    library_summary = validity_summary(tokens, requested_board_prefix=requested_board_prefix)
    limits = _role_limit_validity(tokens, requested_board_prefix=requested_board_prefix)

    library_ok = bool(library_summary.get("basic_valid", False))
    overall_ok = bool(library_ok and limits["role_limits_valid"])

    merged = dict(library_summary)
    merged.update(limits)
    merged["base_basic_valid"] = library_ok
    merged["basic_valid"] = overall_ok
    merged["webapp_basic_valid"] = overall_ok
    return merged
def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]:
    """Explain, in plain English, why a route should not be grade-predicted.

    Each check reads one key of ``validity`` with a default that applies when
    the key is absent. The messages of all failing checks are returned in a
    fixed order.
    """

    def flag(key: str, default: bool) -> bool:
        return bool(validity.get(key, default))

    def count(key: str) -> int:
        return int(validity.get(key, 0))

    hold_total = count("n_hold_tokens")
    checks = [
        (hold_total == 0, "no hold tokens were found in the frames string"),
        (not flag("one_board_only", True), "the route contains holds from more than one board"),
        (not flag("matches_requested_board", True), "the route contains holds from the wrong board"),
        (flag("has_duplicate_placements", False), "the route contains duplicate placements"),
        (not flag("has_start", False), "the route has no start hold"),
        (not flag("has_finish", False), "the route has no finish hold"),
        (count("n_start_holds") > 2, "the route has more than two start holds"),
        (count("n_finish_holds") > 2, "the route has more than two finish holds"),
        (hold_total < 3, "the route has fewer than 3 holds"),
    ]
    return [message for failed, message in checks if failed]
def _angle_key(angle: Any) -> int:
    """Normalise an angle to the nearest integer, or 0 when it is unusable.

    Args:
        angle: Anything ``float()`` accepts (number or numeric string).

    Returns:
        ``round(float(angle))`` as an int. Values that cannot be converted
        (None, non-numeric text, NaN, infinity) map to 0.
    """
    try:
        return int(round(float(angle)))
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported type; ValueError: bad text or NaN;
        # OverflowError: infinity. Anything else is a real bug and propagates.
        return 0
def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, Any]]) -> str | None:
"""Canonical exact-match signature: board + angle + hold-role multiset."""
parts = []
for hold in holds:
try:
placement_id = int(hold["placement_id"])
role = str(hold["role"])
except Exception:
continue
parts.append(f"{role}:{placement_id}")
if not parts:
return None
return f"{board}|{_angle_key(angle)}|" + "|".join(sorted(parts))
def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]:
    """Parse a whitespace-separated token sequence into minimal hold dicts.

    Returns:
        One dict per hold record with ``board_prefix``, ``placement_id`` and
        ``role``.
    """
    tokens = str(sequence).split()
    return [
        {
            "board_prefix": str(record["board_token_prefix"]),
            "placement_id": int(record["placement_id"]),
            "role": str(record["role"]),
        }
        for record in tokens_to_hold_records(tokens)
    ]
def _load_available_angles(path: Path) -> dict[str, list[int]]:
    """Collect the wall angles present per board in route_sequences.csv.

    A conservative set of common angles is used when the processed route
    table is missing or unreadable, and for any board the table does not
    mention. That keeps the webapp usable for demos even when only model
    checkpoints and token metadata are present.
    """
    defaults = {
        "tb2": [20, 25, 30, 35, 40, 45, 50],
        "kilter": [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70],
    }
    if not path.exists():
        return defaults
    try:
        table = pd.read_csv(path, usecols=["board_key", "angle"])
    except Exception:
        # Best effort: an unreadable table must not stop the demo from starting.
        return defaults

    complete_rows = table.dropna(subset=["board_key", "angle"])
    angles_by_board: dict[str, list[int]] = {}
    for board, frame in complete_rows.groupby("board_key"):
        distinct = {_angle_key(raw) for raw in frame["angle"].tolist()}
        if distinct:
            angles_by_board[str(board)] = sorted(distinct)

    # Boards absent from the table still get their default angles.
    for board in defaults:
        if board not in angles_by_board:
            angles_by_board[board] = defaults[board]
    return angles_by_board
def _load_available_grades(path: Path) -> dict[str, list[int]]:
    """Load available grouped V-grades by board from route_sequences.csv.

    Falls back to grades 0-15 for every board when the table is missing or
    unreadable, and for any board the table does not mention.

    Grade cells that are not finite numbers are ignored instead of aborting
    startup, matching the tolerant behaviour of the angle loader.

    Args:
        path: Location of the route table CSV.

    Returns:
        Mapping of board key to the sorted distinct integer grades.
    """
    fallback = {
        "tb2": list(range(0, 16)),
        "kilter": list(range(0, 16)),
    }
    if not path.exists():
        return fallback
    try:
        df = pd.read_csv(path, usecols=["board_key", "grouped_v"])
    except Exception:
        # Best effort: an unreadable table must not stop the demo from starting.
        return fallback

    # Coerce junk to NaN and blank out infinities so int() below cannot raise.
    grades = pd.to_numeric(df["grouped_v"], errors="coerce")
    grades = grades.where(grades.abs() != float("inf"))
    df = df.assign(grouped_v=grades).dropna(subset=["board_key", "grouped_v"])

    out: dict[str, list[int]] = {}
    for board, frame in df.groupby("board_key"):
        values = sorted({int(v) for v in frame["grouped_v"].tolist()})
        if values:
            out[str(board)] = values
    for board, values in fallback.items():
        out.setdefault(board, values)
    return out
def _load_known_route_lookup(path: Path) -> dict[str, dict[str, Any]]:
    """Build the exact-match table of known routes once at startup.

    This is deliberately a plain dictionary lookup, not a nearest-neighbour
    search. Keys are route signatures (board + angle + exact hold-role set).
    Several real climbs can share a signature, so each entry keeps a running
    ``count`` plus up to five ``examples``.

    Returns:
        The signature table; empty when the CSV is missing or unreadable.
    """
    known: dict[str, dict[str, Any]] = {}
    if not path.exists():
        return known

    wanted_columns = {
        "uuid",
        "board_key",
        "climb_name",
        "setter_username",
        "angle",
        "grouped_v",
        "boulder_grade",
        "quality_average",
        "ascensionist_count",
        "frames",
        "sequence_no_grade",
    }
    try:
        # A callable lets the load succeed even when some columns are absent.
        table = pd.read_csv(path, usecols=lambda column: column in wanted_columns)
    except Exception:
        return known

    def optional(row, column: str, convert):
        """Return ``convert(value)``, or None for a missing/NaN cell."""
        value = row.get(column, None)
        if pd.isna(value):
            return None
        return convert(value)

    max_examples = 5
    for _, row in table.iterrows():
        board = str(row.get("board_key", ""))
        angle = row.get("angle", 0)
        holds = _holds_from_sequence(str(row.get("sequence_no_grade", "")))
        signature = _route_signature_from_holds(board, angle, holds)
        if signature is None:
            continue

        entry = known.setdefault(signature, {"count": 0, "examples": []})
        entry["count"] += 1
        if len(entry["examples"]) >= max_examples:
            continue
        entry["examples"].append(
            {
                "uuid": str(row.get("uuid", "")),
                "climb_name": optional(row, "climb_name", str),
                "setter_username": optional(row, "setter_username", str),
                "angle": _angle_key(angle),
                "grouped_v": optional(row, "grouped_v", int),
                "boulder_grade": optional(row, "boulder_grade", str),
                "quality_average": optional(row, "quality_average", float),
                "ascensionist_count": optional(row, "ascensionist_count", int),
                "frames": optional(row, "frames", str),
            }
        )
    return known
def _known_route_status(board: str, angle: Any, holds: list[dict[str, Any]]) -> dict[str, Any]:
    """Report whether a route exactly matches a known climb.

    Returns:
        A dict with ``checked`` (True only when a non-empty lookup table is
        loaded), ``is_known``, ``match_count``, ``examples`` and the
        ``signature`` used (None when no signature could be built).
    """
    signature = _route_signature_from_holds(board, angle, holds)
    lookup = getattr(app.state, "known_route_lookup", {})

    if signature is None:
        # Nothing to look up: report "not known" without touching the table.
        return {
            "checked": bool(lookup),
            "is_known": False,
            "match_count": 0,
            "examples": [],
            "signature": None,
        }

    hit = lookup.get(signature)
    if hit:
        match_count, examples = int(hit["count"]), hit["examples"]
    else:
        match_count, examples = 0, []
    return {
        "checked": bool(lookup),
        "is_known": hit is not None,
        "match_count": match_count,
        "examples": examples,
        "signature": signature,
    }
def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[str, Any]:
    """Assemble the JSON-safe response body shared by generate and predict.

    Args:
        result: Model output; must contain ``board_key``.
        tokens: Route tokens to report; defaults to ``result["tokens"]``
            (or an empty list when that key is absent).

    Returns:
        ``result`` extended with ``tokens``, ``holds``, ``known_climb``,
        ``canvas`` geometry and ``background_url``, passed through
        ``json_safe``.
    """
    board = str(result["board_key"])
    if tokens is None:
        tokens = result.get("tokens", [])
    tokens = list(tokens)

    extent = [float(bound) for bound in BOARD_CANVAS[board]["extent"]]
    # Looked up here so an unknown board fails before any further work is done.
    BOARD_IMAGE_PATHS[board]

    holds = _tokens_to_holds(board, tokens)
    angle = result.get("requested_angle", result.get("angle", 0))

    canvas = {
        "extent": extent,
        "x_min": extent[0],
        "x_max": extent[1],
        "y_min": extent[2],
        "y_max": extent[3],
        "width": extent[1] - extent[0],
        "height": extent[3] - extent[2],
    }
    body = dict(result)
    body["tokens"] = tokens
    body["holds"] = holds
    body["known_climb"] = _known_route_status(board, angle, holds)
    body["canvas"] = canvas
    body["background_url"] = _static_image_url(board)
    return json_safe(body)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load everything the request handlers need, once, before serving.

    Populates ``app.state`` with the board configs, token metadata, available
    angles and grades, the known-route lookup (plus how long it took to
    build), and the two models. A model whose checkpoint file is absent is
    left as None; the handlers answer 503 in that case.
    """
    if TORCH_THREADS_INT is not None:
        torch.set_num_threads(TORCH_THREADS_INT)

    config_dir = REPO_ROOT / "configs"
    app.state.board_configs = {
        "tb2": load_board_config("tb2", config_dir=config_dir),
        "kilter": load_board_config("kilter", config_dir=config_dir),
    }
    app.state.token_meta = load_token_metadata(TOKENIZED_DIR)
    app.state.available_angles = _load_available_angles(KNOWN_ROUTES_PATH)
    app.state.available_grades = _load_available_grades(KNOWN_ROUTES_PATH)

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments the way a time.time() difference can.
    started = time.perf_counter()
    app.state.known_route_lookup = _load_known_route_lookup(KNOWN_ROUTES_PATH)
    app.state.known_route_lookup_seconds = round(time.perf_counter() - started, 3)

    app.state.generator = None
    if GENERATOR_PATH.exists():
        app.state.generator = load_route_generator(
            GENERATOR_PATH,
            device=DEVICE,
            torch_threads=TORCH_THREADS_INT,
        )

    app.state.grade_predictor = None
    if GRADE_MODEL_PATH.exists():
        app.state.grade_predictor = load_grade_predictor(
            GRADE_MODEL_PATH,
            device=DEVICE,
            torch_threads=TORCH_THREADS_INT,
        )

    yield
# Application object; `lifespan` runs the startup loading before requests are served.
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0", lifespan=lifespan)
# Front-end assets under /static and board background images under /board-images.
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
@app.get("/")
def index():
    # Serve the single-page front end.
    index_html = REPO_ROOT / "webapp" / "static" / "index.html"
    return FileResponse(index_html)
@app.get("/api/health")
def health():
    # Status snapshot: which models are loaded, how large the known-route
    # table is, and which angles/grades the UI may offer.
    state = app.state
    known_lookup = getattr(state, "known_route_lookup", {})
    report = {
        "ok": True,
        "device": DEVICE,
        "generator_loaded": state.generator is not None,
        "grade_predictor_loaded": state.grade_predictor is not None,
        "known_route_signatures": len(known_lookup),
        "known_route_lookup_seconds": getattr(state, "known_route_lookup_seconds", None),
        "available_angles": getattr(state, "available_angles", {}),
        "available_grades": getattr(state, "available_grades", {}),
        "torch_threads": torch.get_num_threads(),
    }
    return report
@app.get("/api/boards")
def boards():
    # Describe every configured board: names, role definitions, selectable
    # angles/grades, background image URL and canvas geometry.
    angles_by_board = getattr(app.state, "available_angles", {})
    grades_by_board = getattr(app.state, "available_grades", {})

    listing = {}
    for key, cfg in app.state.board_configs.items():
        extent = [float(bound) for bound in BOARD_CANVAS[key]["extent"]]
        entry = {
            "board_key": key,
            "display_name": cfg.display_name,
            "token_prefix": cfg.token_prefix,
            "role_definitions": cfg.role_definitions,
            "available_angles": angles_by_board.get(key, []),
            "available_grades": grades_by_board.get(key, []),
            "background_url": _static_image_url(key),
        }
        entry["canvas"] = {
            "extent": extent,
            "x_min": extent[0],
            "x_max": extent[1],
            "y_min": extent[2],
            "y_max": extent[3],
            "width": extent[1] - extent[0],
            "height": extent[3] - extent[2],
        }
        listing[key] = entry
    return listing
@app.get("/api/board-holds/{board}")
def board_holds(board: str):
    # Board description plus every hold that has plottable coordinates.
    # Unknown boards are rejected with HTTP 400 by _board_config.
    cfg = _board_config(board)
    body = {
        "board_key": board,
        "display_name": cfg.display_name,
        "token_prefix": cfg.token_prefix,
        "role_definitions": cfg.role_definitions,
        "holds": _board_available_holds(board),
    }
    return json_safe(body)
@app.get("/api/debug/images")
def debug_images():
    # For each board: the URL the browser is given and facts about the image
    # file behind it.
    return {
        board: {
            "background_url": _static_image_url(board),
            "file": _file_info(image_path),
        }
        for board, image_path in BOARD_IMAGE_PATHS.items()
    }
@app.post("/api/generate")
def generate(req: GenerateRequest):
    # Sample a route for the requested board/angle/grade.
    # With valid_only set, sampling repeats (up to max_attempts) until a sample
    # passes the combined validity check; the last sample is returned either
    # way, with warnings describing what happened. When a grade predictor is
    # loaded, its prediction for the returned route is attached as well.
    generator = _require_generator()
    board_config = _board_config(req.board)
    attempt_budget = req.max_attempts if req.valid_only else 1

    chosen = None
    attempt_log = []
    for attempt_no in range(1, attempt_budget + 1):
        sample = generate_route(
            generator=generator,
            board_config=board_config,
            angle=req.angle,
            grade=req.grade,
            temperature=req.temperature,
            top_k=req.top_k,
            max_new_tokens=req.max_new_tokens,
        )
        sample["generation_attempt"] = attempt_no

        validity = _combined_validity(
            list(sample.get("tokens", [])),
            requested_board_prefix=board_config.token_prefix,
        )
        # Validity keys are flattened into the sample and also kept together.
        sample.update(validity)
        sample["webapp_validity"] = validity

        attempt_log.append(
            {
                "attempt": attempt_no,
                "basic_valid": bool(sample.get("basic_valid")),
                "base_basic_valid": bool(sample.get("base_basic_valid")),
                "role_limits_valid": bool(sample.get("role_limits_valid")),
                "n_hold_tokens": int(sample.get("n_hold_tokens", 0)),
                "n_start_holds": int(sample.get("n_start_holds", 0)),
                "n_finish_holds": int(sample.get("n_finish_holds", 0)),
                "frames": sample.get("frames", ""),
            }
        )
        chosen = sample
        if bool(sample.get("basic_valid")) or not req.valid_only:
            break

    if chosen is None:
        raise HTTPException(status_code=500, detail="Generation failed unexpectedly.")

    chosen["requested_valid_only"] = bool(req.valid_only)
    chosen["max_attempts"] = int(attempt_budget)
    chosen["attempts_used"] = int(chosen.get("generation_attempt", 1))
    chosen["sampled_attempts"] = attempt_log

    is_valid = bool(chosen.get("basic_valid"))
    notes = []
    if req.valid_only:
        if not is_valid:
            notes.append(
                f"Could not sample a valid climb after {attempt_budget} attempts. Showing the last sample."
            )
            if not bool(chosen.get("role_limits_valid", True)):
                notes.append(
                    "Last sample violates start/finish role limits: at most two starts and at most two finishes."
                )
        elif chosen["attempts_used"] > 1:
            notes.append(
                f"Sampled {chosen['attempts_used']} climbs before finding a valid one."
            )
    elif not is_valid:
        notes.append("Generated climb is structurally invalid.")

    critic = app.state.grade_predictor
    if critic is not None:
        chosen.update(predict_route_grade(critic, chosen["tokens"]))
        predicted = int(chosen["predicted_grouped_v"])
        requested = int(chosen["requested_grouped_v"])
        chosen["critic_v_error"] = predicted - requested

    # Append to a warnings list that is already present; otherwise install ours.
    existing_warnings = chosen["warnings"] if "warnings" in chosen else None
    if isinstance(existing_warnings, list):
        existing_warnings.extend(notes)
    else:
        chosen["warnings"] = notes

    return _payload(chosen)
@app.post("/api/predict")
def predict(req: PredictRequest):
    # Predict the grade of a climb given as a frames string.
    # Climbs that fail the combined validity check are rejected with HTTP 400
    # and a list of human-readable reasons.
    grade_model = _require_grade_predictor()
    board_config = _board_config(req.board)

    result = predict_frames_grade(
        grade_predictor=grade_model,
        frames=req.frames,
        angle=req.angle,
        board_config=board_config,
        df_token_meta=app.state.token_meta,
    )
    tokens = list(result["tokens"])

    validity = _combined_validity(tokens, requested_board_prefix=board_config.token_prefix)
    result.update(validity)
    result["webapp_validity"] = validity

    if bool(validity.get("basic_valid", False)):
        return _payload(result, tokens=tokens)

    problem = {
        "message": "Cannot predict grade for an invalid climb.",
        "reasons": _invalid_prediction_reasons(validity),
        "validity": validity,
    }
    raise HTTPException(status_code=400, detail=json_safe(problem))