"""FastAPI web demo for ClimbingBoardGPT. Inference-only design: - model checkpoints and token metadata are loaded once at startup; - board background images are served as static files; - each request returns route hold coordinates/roles as JSON; - the browser draws the overlay as SVG on top of the already-loaded image. Local run: uvicorn webapp.app:app --host 127.0.0.1 --port 8055 """ from __future__ import annotations import hashlib import os import sys import time from contextlib import asynccontextmanager from pathlib import Path from typing import Any import pandas as pd import torch from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field REPO_ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(REPO_ROOT / "src")) from climbingboardgpt.config import load_board_config from climbingboardgpt.generation import validity_summary from climbingboardgpt.inference import ( generate_route, load_grade_predictor, load_route_generator, predict_frames_grade, predict_route_grade, ) from climbingboardgpt.tokenization import tokens_to_hold_records from climbingboardgpt.utils import json_safe from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records DEVICE = os.getenv("CBGPT_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu") TORCH_THREADS = os.getenv("CBGPT_TORCH_THREADS") TORCH_THREADS_INT = int(TORCH_THREADS) if TORCH_THREADS else None MODEL_DIR = Path(os.getenv("CBGPT_MODEL_DIR", REPO_ROOT / "models")) TOKENIZED_DIR = Path( os.getenv("CBGPT_TOKENIZED_DIR", REPO_ROOT / "data" / "processed" / "tokenized") ) KNOWN_ROUTES_PATH = Path( os.getenv("CBGPT_KNOWN_ROUTES_PATH", TOKENIZED_DIR / "route_sequences.csv") ) GENERATOR_PATH = MODEL_DIR / "joint_route_gpt_generator.pth" GRADE_MODEL_PATH = MODEL_DIR / "joint_transformer_grade_predictor.pth" BOARD_IMAGE_PATHS = { "tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png", "kilter": REPO_ROOT / "images" / "kilter-original-16x12_composite.png", } class GenerateRequest(BaseModel): board: str = Field(..., pattern="^(tb2|kilter)$") angle: int = Field(40, ge=0, le=80) grade: int = Field(6, ge=0, le=16) temperature: float = Field(0.9, ge=0.1, le=2.0) top_k: int = Field(50, ge=1, le=500) max_new_tokens: int = Field(40, ge=4, le=80) valid_only: bool = Field(True, description="Retry generation until a basic-valid climb is sampled.") max_attempts: int = Field(8, ge=1, le=25, description="Maximum attempts when valid_only is true.") class PredictRequest(BaseModel): board: str = Field(..., pattern="^(tb2|kilter)$") angle: int = Field(..., ge=0, le=80) frames: str = Field(..., min_length=1, max_length=500) def _board_config(board: str): try: return app.state.board_configs[board] except KeyError as exc: raise HTTPException(status_code=400, detail=f"Unknown board: {board}") from exc def _require_generator(): if app.state.generator is None: raise HTTPException( status_code=503, detail=f"Route generator checkpoint not loaded. Expected {GENERATOR_PATH}.", ) return app.state.generator def _require_grade_predictor(): if app.state.grade_predictor is None: raise HTTPException( status_code=503, detail=f"Grade predictor checkpoint not loaded. Expected {GRADE_MODEL_PATH}.", ) return app.state.grade_predictor def _tokens_to_holds(board: str, tokens: list[str]) -> list[dict[str, Any]]: route_records = tokens_to_route_records(tokens) if route_records.empty: return [] token_meta = app.state.token_meta coords = token_meta[ (token_meta["kind"] == "hold") & (token_meta["board_key"].astype(str) == str(board)) ][["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates( ["board_key", "placement_id"] ) merged = route_records.merge( coords, on=["board_token_prefix", "placement_id"], how="left", ).dropna(subset=["x", "y"]) holds = [] for _, row in merged.iterrows(): holds.append( { "token": str(row["token"]), "placement_id": int(row["placement_id"]), "role": str(row["role"]), "x": float(row["x"]), "y": float(row["y"]), } ) return holds def _file_version(path: Path) -> str: """Small cache-busting version string based on mtime and file size.""" try: stat = path.stat() return f"{int(stat.st_mtime)}-{stat.st_size}" except FileNotFoundError: return "missing" def _static_image_url(board: str) -> str: image_path = BOARD_IMAGE_PATHS[board] return f"/board-images/{image_path.name}?board={board}&v={_file_version(image_path)}" def _file_info(path: Path) -> dict[str, Any]: if not path.exists(): return { "path": str(path), "exists": False, "size_bytes": None, "mtime": None, "sha256_16": None, } data = path.read_bytes() stat = path.stat() return { "path": str(path), "exists": True, "size_bytes": stat.st_size, "mtime": stat.st_mtime, "sha256_16": hashlib.sha256(data).hexdigest()[:16], } def _board_available_holds(board: str) -> list[dict[str, Any]]: """Return clickable hold coordinates for a board. Some token-metadata rows can contain missing coordinates for hold-role tokens that exist in the vocabulary but cannot be plotted directly. Those rows must be removed before returning JSON, because FastAPI's JSONResponse rejects NaN. """ token_meta = app.state.token_meta holds = token_meta[ (token_meta["kind"] == "hold") & (token_meta["board_key"].astype(str) == str(board)) ][["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates( ["board_key", "placement_id"] ) holds = holds.copy() holds["x"] = pd.to_numeric(holds["x"], errors="coerce") holds["y"] = pd.to_numeric(holds["y"], errors="coerce") holds = holds.dropna(subset=["x", "y"]) return [ { "placement_id": int(row["placement_id"]), "x": float(row["x"]), "y": float(row["y"]), } for _, row in holds.sort_values(["y", "x", "placement_id"]).iterrows() ] def _role_limit_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]: """Extra webapp validity checks for start/finish counts. The lower-level validity check requires at least one start and at least one finish, but BoardLib-style climbs should not have arbitrarily many starts or finishes. For this demo we enforce at most two starts and at most two finishes. """ counts = { "start": 0, "middle": 0, "finish": 0, "foot": 0, "unknown": 0, } for record in tokens_to_hold_records(tokens): if str(record["board_token_prefix"]) != str(requested_board_prefix): continue role = str(record["role"]) counts[role] = counts.get(role, 0) + 1 n_start = int(counts.get("start", 0)) n_finish = int(counts.get("finish", 0)) return { "role_counts": counts, "n_start_holds": n_start, "n_finish_holds": n_finish, "has_at_most_two_starts": n_start <= 2, "has_at_most_two_finishes": n_finish <= 2, "role_limits_valid": n_start <= 2 and n_finish <= 2, } def _combined_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]: """Combined structural validity used by webapp generation and prediction.""" base = validity_summary(tokens, requested_board_prefix=requested_board_prefix) role_limits = _role_limit_validity(tokens, requested_board_prefix=requested_board_prefix) base_basic = bool(base.get("basic_valid", False)) combined_basic = bool(base_basic and role_limits["role_limits_valid"]) return { **base, **role_limits, "base_basic_valid": base_basic, "basic_valid": combined_basic, "webapp_basic_valid": combined_basic, } def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]: """Human-readable reasons a route should not be grade-predicted.""" reasons: list[str] = [] if int(validity.get("n_hold_tokens", 0)) == 0: reasons.append("no hold tokens were found in the frames string") if not bool(validity.get("one_board_only", True)): reasons.append("the route contains holds from more than one board") if not bool(validity.get("matches_requested_board", True)): reasons.append("the route contains holds from the wrong board") if bool(validity.get("has_duplicate_placements", False)): reasons.append("the route contains duplicate placements") if not bool(validity.get("has_start", False)): reasons.append("the route has no start hold") if not bool(validity.get("has_finish", False)): reasons.append("the route has no finish hold") if int(validity.get("n_start_holds", 0)) > 2: reasons.append("the route has more than two start holds") if int(validity.get("n_finish_holds", 0)) > 2: reasons.append("the route has more than two finish holds") if int(validity.get("n_hold_tokens", 0)) < 3: reasons.append("the route has fewer than 3 holds") return reasons def _angle_key(angle: Any) -> int: try: return int(round(float(angle))) except Exception: return 0 def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, Any]]) -> str | None: """Canonical exact-match signature: board + angle + hold-role multiset.""" parts = [] for hold in holds: try: placement_id = int(hold["placement_id"]) role = str(hold["role"]) except Exception: continue parts.append(f"{role}:{placement_id}") if not parts: return None return f"{board}|{_angle_key(angle)}|" + "|".join(sorted(parts)) def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]: holds: list[dict[str, Any]] = [] for record in tokens_to_hold_records(str(sequence).split()): holds.append( { "board_prefix": str(record["board_token_prefix"]), "placement_id": int(record["placement_id"]), "role": str(record["role"]), } ) return holds def _load_available_angles(path: Path) -> dict[str, list[int]]: """Load available wall angles by board from route_sequences.csv. Falls back to a conservative common angle set if the processed route table is unavailable. This keeps the webapp usable for demos even if only model checkpoints/token metadata are present. """ fallback = { "tb2": [20, 25, 30, 35, 40, 45, 50], "kilter": [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], } if not path.exists(): return fallback try: df = pd.read_csv(path, usecols=["board_key", "angle"]) except Exception: return fallback out: dict[str, list[int]] = {} for board, frame in df.dropna(subset=["board_key", "angle"]).groupby("board_key"): values = sorted({_angle_key(angle) for angle in frame["angle"].tolist()}) if values: out[str(board)] = values for board, values in fallback.items(): out.setdefault(board, values) return out def _load_available_grades(path: Path) -> dict[str, list[int]]: """Load available grouped V-grades by board from route_sequences.csv.""" fallback = { "tb2": list(range(0, 16)), "kilter": list(range(0, 16)), } if not path.exists(): return fallback try: df = pd.read_csv(path, usecols=["board_key", "grouped_v"]) except Exception: return fallback out: dict[str, list[int]] = {} for board, frame in df.dropna(subset=["board_key", "grouped_v"]).groupby("board_key"): values = sorted({int(v) for v in frame["grouped_v"].tolist()}) if values: out[str(board)] = values for board, values in fallback.items(): out.setdefault(board, values) return out def _load_known_route_lookup(path: Path) -> dict[str, dict[str, Any]]: """Load exact known-route signatures once at startup. This is intentionally an exact O(1) lookup, not a nearest-neighbor search. The key is board + angle + exact hold-role set. Multiple real climbs can share the same signature, so each signature stores a count and a few examples. """ lookup: dict[str, dict[str, Any]] = {} if not path.exists(): return lookup usecols = [ "uuid", "board_key", "climb_name", "setter_username", "angle", "grouped_v", "boulder_grade", "quality_average", "ascensionist_count", "frames", "sequence_no_grade", ] try: df = pd.read_csv(path, usecols=lambda col: col in set(usecols)) except Exception: return lookup for _, row in df.iterrows(): board = str(row.get("board_key", "")) angle = row.get("angle", 0) holds = _holds_from_sequence(str(row.get("sequence_no_grade", ""))) signature = _route_signature_from_holds(board, angle, holds) if signature is None: continue entry = lookup.setdefault(signature, {"count": 0, "examples": []}) entry["count"] += 1 if len(entry["examples"]) < 5: example = { "uuid": str(row.get("uuid", "")), "climb_name": None if pd.isna(row.get("climb_name", None)) else str(row.get("climb_name", "")), "setter_username": None if pd.isna(row.get("setter_username", None)) else str(row.get("setter_username", "")), "angle": _angle_key(angle), "grouped_v": None if pd.isna(row.get("grouped_v", None)) else int(row.get("grouped_v")), "boulder_grade": None if pd.isna(row.get("boulder_grade", None)) else str(row.get("boulder_grade", "")), "quality_average": None if pd.isna(row.get("quality_average", None)) else float(row.get("quality_average")), "ascensionist_count": None if pd.isna(row.get("ascensionist_count", None)) else int(row.get("ascensionist_count")), "frames": None if pd.isna(row.get("frames", None)) else str(row.get("frames", "")), } entry["examples"].append(example) return lookup def _known_route_status(board: str, angle: Any, holds: list[dict[str, Any]]) -> dict[str, Any]: signature = _route_signature_from_holds(board, angle, holds) if signature is None: return { "checked": bool(getattr(app.state, "known_route_lookup", None)), "is_known": False, "match_count": 0, "examples": [], "signature": None, } lookup = getattr(app.state, "known_route_lookup", {}) match = lookup.get(signature) return { "checked": bool(lookup), "is_known": match is not None, "match_count": int(match["count"]) if match else 0, "examples": match["examples"] if match else [], "signature": signature, } def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[str, Any]: board = str(result["board_key"]) tokens = list(tokens if tokens is not None else result.get("tokens", [])) extent = [float(v) for v in BOARD_CANVAS[board]["extent"]] image_path = BOARD_IMAGE_PATHS[board] holds = _tokens_to_holds(board, tokens) angle = result.get("requested_angle", result.get("angle", 0)) return json_safe( { **result, "tokens": tokens, "holds": holds, "known_climb": _known_route_status(board, angle, holds), "canvas": { "extent": extent, "x_min": extent[0], "x_max": extent[1], "y_min": extent[2], "y_max": extent[3], "width": extent[1] - extent[0], "height": extent[3] - extent[2], }, "background_url": _static_image_url(board), } ) @asynccontextmanager async def lifespan(app: FastAPI): if TORCH_THREADS_INT is not None: torch.set_num_threads(TORCH_THREADS_INT) app.state.board_configs = { "tb2": load_board_config("tb2", config_dir=REPO_ROOT / "configs"), "kilter": load_board_config("kilter", config_dir=REPO_ROOT / "configs"), } app.state.token_meta = load_token_metadata(TOKENIZED_DIR) app.state.available_angles = _load_available_angles(KNOWN_ROUTES_PATH) app.state.available_grades = _load_available_grades(KNOWN_ROUTES_PATH) started = time.time() app.state.known_route_lookup = _load_known_route_lookup(KNOWN_ROUTES_PATH) app.state.known_route_lookup_seconds = round(time.time() - started, 3) app.state.generator = None if GENERATOR_PATH.exists(): app.state.generator = load_route_generator( GENERATOR_PATH, device=DEVICE, torch_threads=TORCH_THREADS_INT, ) app.state.grade_predictor = None if GRADE_MODEL_PATH.exists(): app.state.grade_predictor = load_grade_predictor( GRADE_MODEL_PATH, device=DEVICE, torch_threads=TORCH_THREADS_INT, ) yield app = FastAPI(title="ClimbingBoardGPT", version="0.1.0", lifespan=lifespan) app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static") app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images") @app.get("/") def index(): return FileResponse(REPO_ROOT / "webapp" / "static" / "index.html") @app.get("/api/health") def health(): return { "ok": True, "device": DEVICE, "generator_loaded": app.state.generator is not None, "grade_predictor_loaded": app.state.grade_predictor is not None, "known_route_signatures": len(getattr(app.state, "known_route_lookup", {})), "known_route_lookup_seconds": getattr(app.state, "known_route_lookup_seconds", None), "available_angles": getattr(app.state, "available_angles", {}), "available_grades": getattr(app.state, "available_grades", {}), "torch_threads": torch.get_num_threads(), } @app.get("/api/boards") def boards(): payload = {} for board, config in app.state.board_configs.items(): extent = [float(v) for v in BOARD_CANVAS[board]["extent"]] payload[board] = { "board_key": board, "display_name": config.display_name, "token_prefix": config.token_prefix, "role_definitions": config.role_definitions, "available_angles": getattr(app.state, "available_angles", {}).get(board, []), "available_grades": getattr(app.state, "available_grades", {}).get(board, []), "background_url": _static_image_url(board), "canvas": { "extent": extent, "x_min": extent[0], "x_max": extent[1], "y_min": extent[2], "y_max": extent[3], "width": extent[1] - extent[0], "height": extent[3] - extent[2], }, } return payload @app.get("/api/board-holds/{board}") def board_holds(board: str): config = _board_config(board) return json_safe({ "board_key": board, "display_name": config.display_name, "token_prefix": config.token_prefix, "role_definitions": config.role_definitions, "holds": _board_available_holds(board), }) @app.get("/api/debug/images") def debug_images(): payload = {} for board, image_path in BOARD_IMAGE_PATHS.items(): payload[board] = { "background_url": _static_image_url(board), "file": _file_info(image_path), } return payload @app.post("/api/generate") def generate(req: GenerateRequest): generator = _require_generator() config = _board_config(req.board) attempts = req.max_attempts if req.valid_only else 1 result = None sampled_results = [] for attempt in range(1, attempts + 1): candidate = generate_route( generator=generator, board_config=config, angle=req.angle, grade=req.grade, temperature=req.temperature, top_k=req.top_k, max_new_tokens=req.max_new_tokens, ) candidate["generation_attempt"] = attempt combined_validity = _combined_validity( list(candidate.get("tokens", [])), requested_board_prefix=config.token_prefix, ) candidate.update(combined_validity) candidate["webapp_validity"] = combined_validity sampled_results.append( { "attempt": attempt, "basic_valid": bool(candidate.get("basic_valid")), "base_basic_valid": bool(candidate.get("base_basic_valid")), "role_limits_valid": bool(candidate.get("role_limits_valid")), "n_hold_tokens": int(candidate.get("n_hold_tokens", 0)), "n_start_holds": int(candidate.get("n_start_holds", 0)), "n_finish_holds": int(candidate.get("n_finish_holds", 0)), "frames": candidate.get("frames", ""), } ) result = candidate if not req.valid_only or bool(candidate.get("basic_valid")): break if result is None: raise HTTPException(status_code=500, detail="Generation failed unexpectedly.") result["requested_valid_only"] = bool(req.valid_only) result["max_attempts"] = int(attempts) result["attempts_used"] = int(result.get("generation_attempt", 1)) result["sampled_attempts"] = sampled_results warnings = [] if req.valid_only and not bool(result.get("basic_valid")): warnings.append( f"Could not sample a valid climb after {attempts} attempts. Showing the last sample." ) if not bool(result.get("role_limits_valid", True)): warnings.append( "Last sample violates start/finish role limits: at most two starts and at most two finishes." ) elif req.valid_only and result["attempts_used"] > 1: warnings.append( f"Sampled {result['attempts_used']} climbs before finding a valid one." ) elif not bool(result.get("basic_valid")): warnings.append("Generated climb is structurally invalid.") if app.state.grade_predictor is not None: grade_result = predict_route_grade(app.state.grade_predictor, result["tokens"]) result.update(grade_result) result["critic_v_error"] = ( int(result["predicted_grouped_v"]) - int(result["requested_grouped_v"]) ) if "warnings" in result and isinstance(result["warnings"], list): result["warnings"].extend(warnings) else: result["warnings"] = warnings return _payload(result) @app.post("/api/predict") def predict(req: PredictRequest): predictor = _require_grade_predictor() config = _board_config(req.board) result = predict_frames_grade( grade_predictor=predictor, frames=req.frames, angle=req.angle, board_config=config, df_token_meta=app.state.token_meta, ) tokens = list(result["tokens"]) validity = _combined_validity(tokens, requested_board_prefix=config.token_prefix) result.update(validity) result["webapp_validity"] = validity if not bool(validity.get("basic_valid", False)): reasons = _invalid_prediction_reasons(validity) raise HTTPException( status_code=400, detail=json_safe( { "message": "Cannot predict grade for an invalid climb.", "reasons": reasons, "validity": validity, } ), ) return _payload(result, tokens=tokens)