webapp
This commit is contained in:
741
webapp/app.py
Normal file
741
webapp/app.py
Normal file
@@ -0,0 +1,741 @@
|
||||
"""FastAPI web demo for ClimbingBoardGPT.
|
||||
|
||||
Inference-only design:
|
||||
- model checkpoints and token metadata are loaded once at startup;
|
||||
- board background images are served as static files;
|
||||
- each request returns route hold coordinates/roles as JSON;
|
||||
- the browser draws the overlay as SVG on top of the already-loaded image.
|
||||
|
||||
Local run:
|
||||
|
||||
uvicorn webapp.app:app --host 127.0.0.1 --port 8055
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(REPO_ROOT / "src"))
|
||||
|
||||
from climbingboardgpt.config import load_board_config
|
||||
from climbingboardgpt.generation import validity_summary
|
||||
from climbingboardgpt.inference import (
|
||||
generate_route,
|
||||
load_grade_predictor,
|
||||
load_route_generator,
|
||||
predict_frames_grade,
|
||||
predict_route_grade,
|
||||
)
|
||||
from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records
|
||||
|
||||
|
||||
DEVICE = os.getenv("CBGPT_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
TORCH_THREADS = os.getenv("CBGPT_TORCH_THREADS")
|
||||
TORCH_THREADS_INT = int(TORCH_THREADS) if TORCH_THREADS else None
|
||||
|
||||
MODEL_DIR = Path(os.getenv("CBGPT_MODEL_DIR", REPO_ROOT / "models"))
|
||||
TOKENIZED_DIR = Path(
|
||||
os.getenv("CBGPT_TOKENIZED_DIR", REPO_ROOT / "data" / "processed" / "tokenized")
|
||||
)
|
||||
KNOWN_ROUTES_PATH = Path(
|
||||
os.getenv("CBGPT_KNOWN_ROUTES_PATH", TOKENIZED_DIR / "route_sequences.csv")
|
||||
)
|
||||
|
||||
GENERATOR_PATH = MODEL_DIR / "joint_route_gpt_generator.pth"
|
||||
GRADE_MODEL_PATH = MODEL_DIR / "joint_transformer_grade_predictor.pth"
|
||||
|
||||
BOARD_IMAGE_PATHS = {
|
||||
"tb2": REPO_ROOT / "images" / "tb2_board_12x12_composite.png",
|
||||
"kilter": REPO_ROOT / "images" / "kilter-original-16x12_compose.png",
|
||||
}
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
board: str = Field(..., pattern="^(tb2|kilter)$")
|
||||
angle: int = Field(40, ge=0, le=80)
|
||||
grade: int = Field(6, ge=0, le=16)
|
||||
temperature: float = Field(0.9, ge=0.1, le=2.0)
|
||||
top_k: int = Field(50, ge=1, le=500)
|
||||
max_new_tokens: int = Field(40, ge=4, le=80)
|
||||
valid_only: bool = Field(True, description="Retry generation until a basic-valid climb is sampled.")
|
||||
max_attempts: int = Field(8, ge=1, le=25, description="Maximum attempts when valid_only is true.")
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
board: str = Field(..., pattern="^(tb2|kilter)$")
|
||||
angle: int = Field(..., ge=0, le=80)
|
||||
frames: str = Field(..., min_length=1, max_length=500)
|
||||
|
||||
|
||||
def _json_safe(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _json_safe(v) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_json_safe(v) for v in value]
|
||||
if isinstance(value, tuple):
|
||||
return [_json_safe(v) for v in value]
|
||||
if hasattr(value, "item"):
|
||||
try:
|
||||
return value.item()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if pd.isna(value):
|
||||
return None
|
||||
except Exception:
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
def _board_config(board: str):
|
||||
try:
|
||||
return app.state.board_configs[board]
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown board: {board}") from exc
|
||||
|
||||
|
||||
def _require_generator():
|
||||
if app.state.generator is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Route generator checkpoint not loaded. Expected {GENERATOR_PATH}.",
|
||||
)
|
||||
return app.state.generator
|
||||
|
||||
|
||||
def _require_grade_predictor():
|
||||
if app.state.grade_predictor is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Grade predictor checkpoint not loaded. Expected {GRADE_MODEL_PATH}.",
|
||||
)
|
||||
return app.state.grade_predictor
|
||||
|
||||
|
||||
def _tokens_to_holds(board: str, tokens: list[str]) -> list[dict[str, Any]]:
|
||||
route_records = tokens_to_route_records(tokens)
|
||||
if route_records.empty:
|
||||
return []
|
||||
|
||||
token_meta = app.state.token_meta
|
||||
coords = token_meta[
|
||||
(token_meta["kind"] == "hold")
|
||||
& (token_meta["board_key"].astype(str) == str(board))
|
||||
][["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates(
|
||||
["board_key", "placement_id"]
|
||||
)
|
||||
|
||||
merged = route_records.merge(
|
||||
coords,
|
||||
on=["board_token_prefix", "placement_id"],
|
||||
how="left",
|
||||
).dropna(subset=["x", "y"])
|
||||
|
||||
holds = []
|
||||
for _, row in merged.iterrows():
|
||||
holds.append(
|
||||
{
|
||||
"token": str(row["token"]),
|
||||
"placement_id": int(row["placement_id"]),
|
||||
"role": str(row["role"]),
|
||||
"x": float(row["x"]),
|
||||
"y": float(row["y"]),
|
||||
}
|
||||
)
|
||||
return holds
|
||||
|
||||
|
||||
|
||||
def _file_version(path: Path) -> str:
|
||||
"""Small cache-busting version string based on mtime and file size."""
|
||||
try:
|
||||
stat = path.stat()
|
||||
return f"{int(stat.st_mtime)}-{stat.st_size}"
|
||||
except FileNotFoundError:
|
||||
return "missing"
|
||||
|
||||
|
||||
def _static_image_url(board: str) -> str:
|
||||
image_path = BOARD_IMAGE_PATHS[board]
|
||||
return f"/board-images/{image_path.name}?board={board}&v={_file_version(image_path)}"
|
||||
|
||||
|
||||
def _file_info(path: Path) -> dict[str, Any]:
|
||||
if not path.exists():
|
||||
return {
|
||||
"path": str(path),
|
||||
"exists": False,
|
||||
"size_bytes": None,
|
||||
"mtime": None,
|
||||
"sha256_16": None,
|
||||
}
|
||||
data = path.read_bytes()
|
||||
stat = path.stat()
|
||||
return {
|
||||
"path": str(path),
|
||||
"exists": True,
|
||||
"size_bytes": stat.st_size,
|
||||
"mtime": stat.st_mtime,
|
||||
"sha256_16": hashlib.sha256(data).hexdigest()[:16],
|
||||
}
|
||||
|
||||
|
||||
def _board_available_holds(board: str) -> list[dict[str, Any]]:
|
||||
"""Return clickable hold coordinates for a board.
|
||||
|
||||
Some token-metadata rows can contain missing coordinates for hold-role tokens
|
||||
that exist in the vocabulary but cannot be plotted directly. Those rows must
|
||||
be removed before returning JSON, because FastAPI's JSONResponse rejects NaN.
|
||||
"""
|
||||
token_meta = app.state.token_meta
|
||||
holds = token_meta[
|
||||
(token_meta["kind"] == "hold")
|
||||
& (token_meta["board_key"].astype(str) == str(board))
|
||||
][["board_key", "board_token_prefix", "placement_id", "x", "y"]].drop_duplicates(
|
||||
["board_key", "placement_id"]
|
||||
)
|
||||
|
||||
holds = holds.copy()
|
||||
holds["x"] = pd.to_numeric(holds["x"], errors="coerce")
|
||||
holds["y"] = pd.to_numeric(holds["y"], errors="coerce")
|
||||
holds = holds.dropna(subset=["x", "y"])
|
||||
|
||||
return [
|
||||
{
|
||||
"placement_id": int(row["placement_id"]),
|
||||
"x": float(row["x"]),
|
||||
"y": float(row["y"]),
|
||||
}
|
||||
for _, row in holds.sort_values(["y", "x", "placement_id"]).iterrows()
|
||||
]
|
||||
|
||||
|
||||
|
||||
|
||||
def _role_limit_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]:
|
||||
"""Extra webapp validity checks for start/finish counts.
|
||||
|
||||
The lower-level validity check requires at least one start and at least one
|
||||
finish, but BoardLib-style climbs should not have arbitrarily many starts or
|
||||
finishes. For this demo we enforce at most two starts and at most two finishes.
|
||||
"""
|
||||
counts = {
|
||||
"start": 0,
|
||||
"middle": 0,
|
||||
"finish": 0,
|
||||
"foot": 0,
|
||||
"unknown": 0,
|
||||
}
|
||||
|
||||
for token in tokens:
|
||||
match = HOLD_TOKEN_RE.match(str(token))
|
||||
if match is None:
|
||||
continue
|
||||
board_prefix, _placement_id, role = match.groups()
|
||||
if str(board_prefix) != str(requested_board_prefix):
|
||||
continue
|
||||
counts[role] = counts.get(role, 0) + 1
|
||||
|
||||
n_start = int(counts.get("start", 0))
|
||||
n_finish = int(counts.get("finish", 0))
|
||||
|
||||
return {
|
||||
"role_counts": counts,
|
||||
"n_start_holds": n_start,
|
||||
"n_finish_holds": n_finish,
|
||||
"has_at_most_two_starts": n_start <= 2,
|
||||
"has_at_most_two_finishes": n_finish <= 2,
|
||||
"role_limits_valid": n_start <= 2 and n_finish <= 2,
|
||||
}
|
||||
|
||||
|
||||
def _combined_validity(tokens: list[str], requested_board_prefix: str) -> dict[str, Any]:
|
||||
"""Combined structural validity used by webapp generation and prediction."""
|
||||
base = validity_summary(tokens, requested_board_prefix=requested_board_prefix)
|
||||
role_limits = _role_limit_validity(tokens, requested_board_prefix=requested_board_prefix)
|
||||
|
||||
base_basic = bool(base.get("basic_valid", False))
|
||||
combined_basic = bool(base_basic and role_limits["role_limits_valid"])
|
||||
|
||||
return {
|
||||
**base,
|
||||
**role_limits,
|
||||
"base_basic_valid": base_basic,
|
||||
"basic_valid": combined_basic,
|
||||
"webapp_basic_valid": combined_basic,
|
||||
}
|
||||
|
||||
|
||||
def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]:
|
||||
"""Human-readable reasons a route should not be grade-predicted."""
|
||||
reasons: list[str] = []
|
||||
|
||||
if int(validity.get("n_hold_tokens", 0)) == 0:
|
||||
reasons.append("no hold tokens were found in the frames string")
|
||||
if not bool(validity.get("one_board_only", True)):
|
||||
reasons.append("the route contains holds from more than one board")
|
||||
if not bool(validity.get("matches_requested_board", True)):
|
||||
reasons.append("the route contains holds from the wrong board")
|
||||
if bool(validity.get("has_duplicate_placements", False)):
|
||||
reasons.append("the route contains duplicate placements")
|
||||
if not bool(validity.get("has_start", False)):
|
||||
reasons.append("the route has no start hold")
|
||||
if not bool(validity.get("has_finish", False)):
|
||||
reasons.append("the route has no finish hold")
|
||||
if int(validity.get("n_start_holds", 0)) > 2:
|
||||
reasons.append("the route has more than two start holds")
|
||||
if int(validity.get("n_finish_holds", 0)) > 2:
|
||||
reasons.append("the route has more than two finish holds")
|
||||
if int(validity.get("n_hold_tokens", 0)) < 3:
|
||||
reasons.append("the route has fewer than 3 holds")
|
||||
|
||||
return reasons
|
||||
|
||||
|
||||
|
||||
HOLD_TOKEN_RE = re.compile(r"^<([A-Z0-9_]+)_p(\d+)_(start|middle|finish|foot|unknown)>$")
|
||||
|
||||
|
||||
def _angle_key(angle: Any) -> int:
|
||||
try:
|
||||
return int(round(float(angle)))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, Any]]) -> str | None:
|
||||
"""Canonical exact-match signature: board + angle + hold-role multiset."""
|
||||
parts = []
|
||||
for hold in holds:
|
||||
try:
|
||||
placement_id = int(hold["placement_id"])
|
||||
role = str(hold["role"])
|
||||
except Exception:
|
||||
continue
|
||||
parts.append(f"{role}:{placement_id}")
|
||||
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
return f"{board}|{_angle_key(angle)}|" + "|".join(sorted(parts))
|
||||
|
||||
|
||||
def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]:
|
||||
holds: list[dict[str, Any]] = []
|
||||
for token in str(sequence).split():
|
||||
match = HOLD_TOKEN_RE.match(token)
|
||||
if match is None:
|
||||
continue
|
||||
holds.append(
|
||||
{
|
||||
"board_prefix": match.group(1),
|
||||
"placement_id": int(match.group(2)),
|
||||
"role": match.group(3),
|
||||
}
|
||||
)
|
||||
return holds
|
||||
|
||||
|
||||
|
||||
def _load_available_angles(path: Path) -> dict[str, list[int]]:
|
||||
"""Load available wall angles by board from route_sequences.csv.
|
||||
|
||||
Falls back to a conservative common angle set if the processed route table
|
||||
is unavailable. This keeps the webapp usable for demos even if only model
|
||||
checkpoints/token metadata are present.
|
||||
"""
|
||||
fallback = {
|
||||
"tb2": [20, 25, 30, 35, 40, 45, 50],
|
||||
"kilter": [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70],
|
||||
}
|
||||
if not path.exists():
|
||||
return fallback
|
||||
|
||||
try:
|
||||
df = pd.read_csv(path, usecols=["board_key", "angle"])
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
out: dict[str, list[int]] = {}
|
||||
for board, frame in df.dropna(subset=["board_key", "angle"]).groupby("board_key"):
|
||||
values = sorted({_angle_key(angle) for angle in frame["angle"].tolist()})
|
||||
if values:
|
||||
out[str(board)] = values
|
||||
|
||||
for board, values in fallback.items():
|
||||
out.setdefault(board, values)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def _load_available_grades(path: Path) -> dict[str, list[int]]:
|
||||
"""Load available grouped V-grades by board from route_sequences.csv."""
|
||||
fallback = {
|
||||
"tb2": list(range(0, 16)),
|
||||
"kilter": list(range(0, 16)),
|
||||
}
|
||||
if not path.exists():
|
||||
return fallback
|
||||
|
||||
try:
|
||||
df = pd.read_csv(path, usecols=["board_key", "grouped_v"])
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
out: dict[str, list[int]] = {}
|
||||
for board, frame in df.dropna(subset=["board_key", "grouped_v"]).groupby("board_key"):
|
||||
values = sorted({int(v) for v in frame["grouped_v"].tolist()})
|
||||
if values:
|
||||
out[str(board)] = values
|
||||
|
||||
for board, values in fallback.items():
|
||||
out.setdefault(board, values)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _load_known_route_lookup(path: Path) -> dict[str, dict[str, Any]]:
|
||||
"""Load exact known-route signatures once at startup.
|
||||
|
||||
This is intentionally an exact O(1) lookup, not a nearest-neighbor search.
|
||||
The key is board + angle + exact hold-role set. Multiple real climbs can
|
||||
share the same signature, so each signature stores a count and a few examples.
|
||||
"""
|
||||
lookup: dict[str, dict[str, Any]] = {}
|
||||
if not path.exists():
|
||||
return lookup
|
||||
|
||||
usecols = [
|
||||
"uuid",
|
||||
"board_key",
|
||||
"climb_name",
|
||||
"setter_username",
|
||||
"angle",
|
||||
"grouped_v",
|
||||
"boulder_grade",
|
||||
"quality_average",
|
||||
"ascensionist_count",
|
||||
"frames",
|
||||
"sequence_no_grade",
|
||||
]
|
||||
|
||||
try:
|
||||
df = pd.read_csv(path, usecols=lambda col: col in set(usecols))
|
||||
except Exception:
|
||||
return lookup
|
||||
|
||||
for _, row in df.iterrows():
|
||||
board = str(row.get("board_key", ""))
|
||||
angle = row.get("angle", 0)
|
||||
holds = _holds_from_sequence(str(row.get("sequence_no_grade", "")))
|
||||
signature = _route_signature_from_holds(board, angle, holds)
|
||||
if signature is None:
|
||||
continue
|
||||
|
||||
entry = lookup.setdefault(signature, {"count": 0, "examples": []})
|
||||
entry["count"] += 1
|
||||
|
||||
if len(entry["examples"]) < 5:
|
||||
example = {
|
||||
"uuid": str(row.get("uuid", "")),
|
||||
"climb_name": None if pd.isna(row.get("climb_name", None)) else str(row.get("climb_name", "")),
|
||||
"setter_username": None if pd.isna(row.get("setter_username", None)) else str(row.get("setter_username", "")),
|
||||
"angle": _angle_key(angle),
|
||||
"grouped_v": None if pd.isna(row.get("grouped_v", None)) else int(row.get("grouped_v")),
|
||||
"boulder_grade": None if pd.isna(row.get("boulder_grade", None)) else str(row.get("boulder_grade", "")),
|
||||
"quality_average": None if pd.isna(row.get("quality_average", None)) else float(row.get("quality_average")),
|
||||
"ascensionist_count": None if pd.isna(row.get("ascensionist_count", None)) else int(row.get("ascensionist_count")),
|
||||
"frames": None if pd.isna(row.get("frames", None)) else str(row.get("frames", "")),
|
||||
}
|
||||
entry["examples"].append(example)
|
||||
|
||||
return lookup
|
||||
|
||||
|
||||
def _known_route_status(board: str, angle: Any, holds: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
signature = _route_signature_from_holds(board, angle, holds)
|
||||
if signature is None:
|
||||
return {
|
||||
"checked": bool(getattr(app.state, "known_route_lookup", None)),
|
||||
"is_known": False,
|
||||
"match_count": 0,
|
||||
"examples": [],
|
||||
"signature": None,
|
||||
}
|
||||
|
||||
lookup = getattr(app.state, "known_route_lookup", {})
|
||||
match = lookup.get(signature)
|
||||
|
||||
return {
|
||||
"checked": bool(lookup),
|
||||
"is_known": match is not None,
|
||||
"match_count": int(match["count"]) if match else 0,
|
||||
"examples": match["examples"] if match else [],
|
||||
"signature": signature,
|
||||
}
|
||||
|
||||
|
||||
def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[str, Any]:
|
||||
board = str(result["board_key"])
|
||||
tokens = list(tokens if tokens is not None else result.get("tokens", []))
|
||||
extent = [float(v) for v in BOARD_CANVAS[board]["extent"]]
|
||||
image_path = BOARD_IMAGE_PATHS[board]
|
||||
holds = _tokens_to_holds(board, tokens)
|
||||
angle = result.get("requested_angle", result.get("angle", 0))
|
||||
|
||||
return _json_safe(
|
||||
{
|
||||
**result,
|
||||
"tokens": tokens,
|
||||
"holds": holds,
|
||||
"known_climb": _known_route_status(board, angle, holds),
|
||||
"canvas": {
|
||||
"extent": extent,
|
||||
"x_min": extent[0],
|
||||
"x_max": extent[1],
|
||||
"y_min": extent[2],
|
||||
"y_max": extent[3],
|
||||
"width": extent[1] - extent[0],
|
||||
"height": extent[3] - extent[2],
|
||||
},
|
||||
"background_url": _static_image_url(board),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="ClimbingBoardGPT", version="0.1.0")
|
||||
app.mount("/static", StaticFiles(directory=REPO_ROOT / "webapp" / "static"), name="static")
|
||||
app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="board-images")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup() -> None:
|
||||
if TORCH_THREADS_INT is not None:
|
||||
torch.set_num_threads(TORCH_THREADS_INT)
|
||||
|
||||
app.state.board_configs = {
|
||||
"tb2": load_board_config("tb2", config_dir=REPO_ROOT / "configs"),
|
||||
"kilter": load_board_config("kilter", config_dir=REPO_ROOT / "configs"),
|
||||
}
|
||||
app.state.token_meta = load_token_metadata(TOKENIZED_DIR)
|
||||
app.state.available_angles = _load_available_angles(KNOWN_ROUTES_PATH)
|
||||
app.state.available_grades = _load_available_grades(KNOWN_ROUTES_PATH)
|
||||
started = time.time()
|
||||
app.state.known_route_lookup = _load_known_route_lookup(KNOWN_ROUTES_PATH)
|
||||
app.state.known_route_lookup_seconds = round(time.time() - started, 3)
|
||||
|
||||
app.state.generator = None
|
||||
if GENERATOR_PATH.exists():
|
||||
app.state.generator = load_route_generator(
|
||||
GENERATOR_PATH,
|
||||
device=DEVICE,
|
||||
torch_threads=TORCH_THREADS_INT,
|
||||
)
|
||||
|
||||
app.state.grade_predictor = None
|
||||
if GRADE_MODEL_PATH.exists():
|
||||
app.state.grade_predictor = load_grade_predictor(
|
||||
GRADE_MODEL_PATH,
|
||||
device=DEVICE,
|
||||
torch_threads=TORCH_THREADS_INT,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def index():
|
||||
return FileResponse(REPO_ROOT / "webapp" / "static" / "index.html")
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
return {
|
||||
"ok": True,
|
||||
"device": DEVICE,
|
||||
"generator_loaded": app.state.generator is not None,
|
||||
"grade_predictor_loaded": app.state.grade_predictor is not None,
|
||||
"known_route_signatures": len(getattr(app.state, "known_route_lookup", {})),
|
||||
"known_route_lookup_seconds": getattr(app.state, "known_route_lookup_seconds", None),
|
||||
"available_angles": getattr(app.state, "available_angles", {}),
|
||||
"available_grades": getattr(app.state, "available_grades", {}),
|
||||
"torch_threads": torch.get_num_threads(),
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/boards")
|
||||
def boards():
|
||||
payload = {}
|
||||
for board, config in app.state.board_configs.items():
|
||||
extent = [float(v) for v in BOARD_CANVAS[board]["extent"]]
|
||||
payload[board] = {
|
||||
"board_key": board,
|
||||
"display_name": config.display_name,
|
||||
"token_prefix": config.token_prefix,
|
||||
"role_definitions": config.role_definitions,
|
||||
"available_angles": getattr(app.state, "available_angles", {}).get(board, []),
|
||||
"available_grades": getattr(app.state, "available_grades", {}).get(board, []),
|
||||
"background_url": _static_image_url(board),
|
||||
"canvas": {
|
||||
"extent": extent,
|
||||
"x_min": extent[0],
|
||||
"x_max": extent[1],
|
||||
"y_min": extent[2],
|
||||
"y_max": extent[3],
|
||||
"width": extent[1] - extent[0],
|
||||
"height": extent[3] - extent[2],
|
||||
},
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
@app.get("/api/board-holds/{board}")
|
||||
def board_holds(board: str):
|
||||
config = _board_config(board)
|
||||
return _json_safe({
|
||||
"board_key": board,
|
||||
"display_name": config.display_name,
|
||||
"token_prefix": config.token_prefix,
|
||||
"role_definitions": config.role_definitions,
|
||||
"holds": _board_available_holds(board),
|
||||
})
|
||||
|
||||
|
||||
@app.get("/api/debug/images")
|
||||
def debug_images():
|
||||
payload = {}
|
||||
for board, image_path in BOARD_IMAGE_PATHS.items():
|
||||
payload[board] = {
|
||||
"background_url": _static_image_url(board),
|
||||
"file": _file_info(image_path),
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
@app.post("/api/generate")
|
||||
def generate(req: GenerateRequest):
|
||||
generator = _require_generator()
|
||||
config = _board_config(req.board)
|
||||
|
||||
attempts = req.max_attempts if req.valid_only else 1
|
||||
result = None
|
||||
sampled_results = []
|
||||
|
||||
for attempt in range(1, attempts + 1):
|
||||
candidate = generate_route(
|
||||
generator=generator,
|
||||
board_config=config,
|
||||
angle=req.angle,
|
||||
grade=req.grade,
|
||||
temperature=req.temperature,
|
||||
top_k=req.top_k,
|
||||
max_new_tokens=req.max_new_tokens,
|
||||
)
|
||||
candidate["generation_attempt"] = attempt
|
||||
combined_validity = _combined_validity(
|
||||
list(candidate.get("tokens", [])),
|
||||
requested_board_prefix=config.token_prefix,
|
||||
)
|
||||
candidate.update(combined_validity)
|
||||
candidate["webapp_validity"] = combined_validity
|
||||
|
||||
sampled_results.append(
|
||||
{
|
||||
"attempt": attempt,
|
||||
"basic_valid": bool(candidate.get("basic_valid")),
|
||||
"base_basic_valid": bool(candidate.get("base_basic_valid")),
|
||||
"role_limits_valid": bool(candidate.get("role_limits_valid")),
|
||||
"n_hold_tokens": int(candidate.get("n_hold_tokens", 0)),
|
||||
"n_start_holds": int(candidate.get("n_start_holds", 0)),
|
||||
"n_finish_holds": int(candidate.get("n_finish_holds", 0)),
|
||||
"frames": candidate.get("frames", ""),
|
||||
}
|
||||
)
|
||||
|
||||
result = candidate
|
||||
if not req.valid_only or bool(candidate.get("basic_valid")):
|
||||
break
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=500, detail="Generation failed unexpectedly.")
|
||||
|
||||
result["requested_valid_only"] = bool(req.valid_only)
|
||||
result["max_attempts"] = int(attempts)
|
||||
result["attempts_used"] = int(result.get("generation_attempt", 1))
|
||||
result["sampled_attempts"] = sampled_results
|
||||
|
||||
warnings = []
|
||||
if req.valid_only and not bool(result.get("basic_valid")):
|
||||
warnings.append(
|
||||
f"Could not sample a valid climb after {attempts} attempts. Showing the last sample."
|
||||
)
|
||||
if not bool(result.get("role_limits_valid", True)):
|
||||
warnings.append(
|
||||
"Last sample violates start/finish role limits: at most two starts and at most two finishes."
|
||||
)
|
||||
elif req.valid_only and result["attempts_used"] > 1:
|
||||
warnings.append(
|
||||
f"Sampled {result['attempts_used']} climbs before finding a valid one."
|
||||
)
|
||||
elif not bool(result.get("basic_valid")):
|
||||
warnings.append("Generated climb is structurally invalid.")
|
||||
|
||||
if app.state.grade_predictor is not None:
|
||||
grade_result = predict_route_grade(app.state.grade_predictor, result["tokens"])
|
||||
result.update(grade_result)
|
||||
result["critic_v_error"] = (
|
||||
int(result["predicted_grouped_v"]) - int(result["requested_grouped_v"])
|
||||
)
|
||||
|
||||
if "warnings" in result and isinstance(result["warnings"], list):
|
||||
result["warnings"].extend(warnings)
|
||||
else:
|
||||
result["warnings"] = warnings
|
||||
|
||||
return _payload(result)
|
||||
|
||||
|
||||
@app.post("/api/predict")
|
||||
def predict(req: PredictRequest):
|
||||
predictor = _require_grade_predictor()
|
||||
config = _board_config(req.board)
|
||||
|
||||
result = predict_frames_grade(
|
||||
grade_predictor=predictor,
|
||||
frames=req.frames,
|
||||
angle=req.angle,
|
||||
board_config=config,
|
||||
df_token_meta=app.state.token_meta,
|
||||
)
|
||||
|
||||
tokens = list(result["tokens"])
|
||||
validity = _combined_validity(tokens, requested_board_prefix=config.token_prefix)
|
||||
result.update(validity)
|
||||
result["webapp_validity"] = validity
|
||||
|
||||
if not bool(validity.get("basic_valid", False)):
|
||||
reasons = _invalid_prediction_reasons(validity)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=_json_safe(
|
||||
{
|
||||
"message": "Cannot predict grade for an invalid climb.",
|
||||
"reasons": reasons,
|
||||
"validity": validity,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
return _payload(result, tokens=tokens)
|
||||
Reference in New Issue
Block a user