Added & fixed some documentation
This commit is contained in:
@@ -44,6 +44,8 @@ from climbingboardgpt.utils import json_safe
|
||||
from climbingboardgpt.visualization import BOARD_CANVAS, load_token_metadata, tokens_to_route_records
|
||||
|
||||
|
||||
# Environment variables keep deployment-specific paths and resource limits out
|
||||
# of code, while the defaults make a local checkout runnable without config.
|
||||
DEVICE = os.getenv("CBGPT_DEVICE") or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
TORCH_THREADS = os.getenv("CBGPT_TORCH_THREADS")
|
||||
TORCH_THREADS_INT = int(TORCH_THREADS) if TORCH_THREADS else None
|
||||
@@ -66,6 +68,8 @@ BOARD_IMAGE_PATHS = {
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""JSON body for ``POST /api/generate``."""
|
||||
|
||||
board: str = Field(..., pattern="^(tb2|kilter)$")
|
||||
angle: int = Field(40, ge=0, le=80)
|
||||
grade: int = Field(6, ge=0, le=16)
|
||||
@@ -77,12 +81,15 @@ class GenerateRequest(BaseModel):
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""JSON body for ``POST /api/predict``."""
|
||||
|
||||
board: str = Field(..., pattern="^(tb2|kilter)$")
|
||||
angle: int = Field(..., ge=0, le=80)
|
||||
frames: str = Field(..., min_length=1, max_length=500)
|
||||
|
||||
|
||||
def _board_config(board: str):
|
||||
"""Return the loaded board config or translate unknown boards to HTTP 400."""
|
||||
try:
|
||||
return app.state.board_configs[board]
|
||||
except KeyError as exc:
|
||||
@@ -90,6 +97,7 @@ def _board_config(board: str):
|
||||
|
||||
|
||||
def _require_generator():
|
||||
"""Return the loaded generator or raise HTTP 503 with a useful path hint."""
|
||||
if app.state.generator is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
@@ -99,6 +107,7 @@ def _require_generator():
|
||||
|
||||
|
||||
def _require_grade_predictor():
|
||||
"""Return the loaded grade predictor or raise HTTP 503 with a path hint."""
|
||||
if app.state.grade_predictor is None:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
@@ -108,6 +117,7 @@ def _require_grade_predictor():
|
||||
|
||||
|
||||
def _tokens_to_holds(board: str, tokens: list[str]) -> list[dict[str, Any]]:
|
||||
"""Join route tokens to board coordinates for browser-side SVG drawing."""
|
||||
route_records = tokens_to_route_records(tokens)
|
||||
if route_records.empty:
|
||||
return []
|
||||
@@ -151,11 +161,13 @@ def _file_version(path: Path) -> str:
|
||||
|
||||
|
||||
def _static_image_url(board: str) -> str:
|
||||
"""Return the static board image URL with a cache-busting query string."""
|
||||
image_path = BOARD_IMAGE_PATHS[board]
|
||||
return f"/board-images/{image_path.name}?board={board}&v={_file_version(image_path)}"
|
||||
|
||||
|
||||
def _file_info(path: Path) -> dict[str, Any]:
|
||||
"""Return lightweight debug metadata for a static file."""
|
||||
if not path.exists():
|
||||
return {
|
||||
"path": str(path),
|
||||
@@ -285,6 +297,7 @@ def _invalid_prediction_reasons(validity: dict[str, Any]) -> list[str]:
|
||||
|
||||
|
||||
def _angle_key(angle: Any) -> int:
|
||||
"""Normalize angle-like values for signatures and selectors."""
|
||||
try:
|
||||
return int(round(float(angle)))
|
||||
except Exception:
|
||||
@@ -309,6 +322,7 @@ def _route_signature_from_holds(board: str, angle: Any, holds: list[dict[str, An
|
||||
|
||||
|
||||
def _holds_from_sequence(sequence: str) -> list[dict[str, Any]]:
|
||||
"""Extract exact-match hold-role records from a no-grade token sequence."""
|
||||
holds: list[dict[str, Any]] = []
|
||||
for record in tokens_to_hold_records(str(sequence).split()):
|
||||
holds.append(
|
||||
@@ -439,6 +453,7 @@ def _load_known_route_lookup(path: Path) -> dict[str, dict[str, Any]]:
|
||||
|
||||
|
||||
def _known_route_status(board: str, angle: Any, holds: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Check whether a route exactly matches a tokenized dataset route."""
|
||||
signature = _route_signature_from_holds(board, angle, holds)
|
||||
if signature is None:
|
||||
return {
|
||||
@@ -462,6 +477,7 @@ def _known_route_status(board: str, angle: Any, holds: list[dict[str, Any]]) ->
|
||||
|
||||
|
||||
def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[str, Any]:
|
||||
"""Attach drawable holds, canvas data, image URL, and known-route status."""
|
||||
board = str(result["board_key"])
|
||||
tokens = list(tokens if tokens is not None else result.get("tokens", []))
|
||||
extent = [float(v) for v in BOARD_CANVAS[board]["extent"]]
|
||||
@@ -491,6 +507,7 @@ def _payload(result: dict[str, Any], tokens: list[str] | None = None) -> dict[st
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Load model/checkpoint state once for the process lifetime."""
|
||||
if TORCH_THREADS_INT is not None:
|
||||
torch.set_num_threads(TORCH_THREADS_INT)
|
||||
|
||||
@@ -530,11 +547,13 @@ app.mount("/board-images", StaticFiles(directory=REPO_ROOT / "images"), name="bo
|
||||
|
||||
@app.get("/")
|
||||
def index():
|
||||
"""Serve the single-page web UI."""
|
||||
return FileResponse(REPO_ROOT / "webapp" / "static" / "index.html")
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
def health():
|
||||
"""Return runtime readiness and deployment diagnostics."""
|
||||
return {
|
||||
"ok": True,
|
||||
"device": DEVICE,
|
||||
@@ -550,6 +569,7 @@ def health():
|
||||
|
||||
@app.get("/api/boards")
|
||||
def boards():
|
||||
"""Return board metadata needed to initialize client-side controls."""
|
||||
payload = {}
|
||||
for board, config in app.state.board_configs.items():
|
||||
extent = [float(v) for v in BOARD_CANVAS[board]["extent"]]
|
||||
@@ -576,6 +596,7 @@ def boards():
|
||||
|
||||
@app.get("/api/board-holds/{board}")
|
||||
def board_holds(board: str):
|
||||
"""Return all clickable holds for a board."""
|
||||
config = _board_config(board)
|
||||
return json_safe({
|
||||
"board_key": board,
|
||||
@@ -588,6 +609,7 @@ def board_holds(board: str):
|
||||
|
||||
@app.get("/api/debug/images")
|
||||
def debug_images():
|
||||
"""Return static-board-image debug metadata."""
|
||||
payload = {}
|
||||
for board, image_path in BOARD_IMAGE_PATHS.items():
|
||||
payload[board] = {
|
||||
@@ -599,6 +621,7 @@ def debug_images():
|
||||
|
||||
@app.post("/api/generate")
|
||||
def generate(req: GenerateRequest):
|
||||
"""Generate a climb and optionally retry until webapp validity passes."""
|
||||
generator = _require_generator()
|
||||
config = _board_config(req.board)
|
||||
|
||||
@@ -682,6 +705,7 @@ def generate(req: GenerateRequest):
|
||||
|
||||
@app.post("/api/predict")
|
||||
def predict(req: PredictRequest):
|
||||
"""Predict grade for a user-supplied frames string after validity checks."""
|
||||
predictor = _require_grade_predictor()
|
||||
config = _board_config(req.board)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user