initial commit
This commit is contained in:
328
scripts/04_evaluate_generated_routes.py
Normal file
328
scripts/04_evaluate_generated_routes.py
Normal file
@@ -0,0 +1,328 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ClimbingBoardGPT — Generated Route Evaluation Script
|
||||
|
||||
This script evaluates routes generated by the GPT model on four dimensions:
|
||||
|
||||
1. Validity: Does the route follow structural rules?
|
||||
- At least 3 holds
|
||||
- No duplicate placements
|
||||
- At least one start and one finish hold
|
||||
- All holds from the same board
|
||||
|
||||
2. Novelty: Is the route different from existing climbs?
|
||||
- Measured by Jaccard distance from the nearest real route
|
||||
|
||||
3. Geometric plausibility: Are holds in reasonable positions?
|
||||
- Height, width, mean hand reach distance
|
||||
|
||||
4. Grade consistency: Does the route's predicted grade match the request?
|
||||
- Uses the trained grade predictor as a "critic"
|
||||
|
||||
This is analogous to how language models are evaluated using BLEU, ROUGE,
|
||||
or human evaluation — but adapted for the climbing domain.
|
||||
|
||||
Usage:
|
||||
python scripts/04_evaluate_generated_routes.py
|
||||
python scripts/04_evaluate_generated_routes.py --grade-model-path models/joint_transformer_grade_predictor.pth
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(REPO_ROOT / "src"))
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from climbingboardgpt.evaluation import (
|
||||
build_placement_coords,
|
||||
frames_to_holds,
|
||||
holds_to_placement_set,
|
||||
nearest_real_route_same_board,
|
||||
parse_token_list,
|
||||
simple_route_features,
|
||||
tokens_to_hold_records,
|
||||
validity_from_records,
|
||||
)
|
||||
from climbingboardgpt.grades import to_grouped_v
|
||||
from climbingboardgpt.models import JointRouteTransformerRegressor
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command-line arguments for route evaluation."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Evaluate generated TB2/Kilter route candidates.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--tokenized-dir", type=Path, default=REPO_ROOT / "data" / "processed" / "tokenized")
|
||||
parser.add_argument("--generated-dir", type=Path, default=REPO_ROOT / "data" / "processed" / "generation")
|
||||
parser.add_argument("--out-dir", type=Path, default=REPO_ROOT / "data" / "processed" / "evaluation")
|
||||
parser.add_argument("--grade-model-path", type=Path, default=REPO_ROOT / "models" / "joint_transformer_grade_predictor.pth")
|
||||
parser.add_argument("--device", type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_grade_critic(model_path: Path, device: torch.device):
|
||||
"""Load the trained grade predictor model as a critic.
|
||||
|
||||
The critic is used to predict the difficulty of generated routes.
|
||||
If we asked for V6 and the critic predicts V6 ± 1, the generation
|
||||
is grade-consistent.
|
||||
|
||||
This is similar to how GANs use a discriminator, except our critic
|
||||
is a regression model rather than a binary classifier.
|
||||
|
||||
Args:
|
||||
model_path: Path to the saved model checkpoint
|
||||
device: torch device
|
||||
|
||||
Returns:
|
||||
Dictionary with model, vocabulary, and config, or None if not found
|
||||
"""
|
||||
if not model_path.exists():
|
||||
return None
|
||||
try:
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
coord_features = checkpoint["coord_features"]
|
||||
if not isinstance(coord_features, torch.Tensor):
|
||||
coord_features = torch.tensor(coord_features, dtype=torch.float32)
|
||||
|
||||
model = JointRouteTransformerRegressor(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_len=cfg["max_len"],
|
||||
coord_features=coord_features,
|
||||
d_model=cfg.get("d_model", 128),
|
||||
nhead=cfg.get("nhead", 4),
|
||||
num_layers=cfg.get("num_layers", 4),
|
||||
dim_feedforward=cfg.get("dim_feedforward", 256),
|
||||
dropout=cfg.get("dropout", 0.10),
|
||||
pad_id=cfg.get("pad_id", stoi["<PAD>"]),
|
||||
).to(device)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"stoi": stoi,
|
||||
"pad_id": stoi["<PAD>"],
|
||||
"unk_id": stoi["<UNK>"],
|
||||
"max_len": cfg["max_len"],
|
||||
}
|
||||
|
||||
|
||||
def predict_generated_grade(tokens: list[str], critic, device: torch.device) -> float:
|
||||
"""Use the critic model to predict the difficulty of a generated route.
|
||||
|
||||
Args:
|
||||
tokens: List of token strings (from generated route)
|
||||
critic: Dictionary with model and vocabulary
|
||||
device: torch device
|
||||
|
||||
Returns:
|
||||
Predicted difficulty score (continuous value)
|
||||
"""
|
||||
model = critic["model"]
|
||||
stoi = critic["stoi"]
|
||||
pad_id = critic["pad_id"]
|
||||
unk_id = critic["unk_id"]
|
||||
max_len = critic["max_len"]
|
||||
|
||||
# Remove grade tokens (we want the model to predict, not see the grade)
|
||||
tokens = [token for token in tokens if not token.startswith("<GRADE_")]
|
||||
# Replace <BOS> with <CLS> for the encoder model
|
||||
if tokens and tokens[0] == "<BOS>":
|
||||
tokens = ["<CLS>"] + tokens[1:]
|
||||
else:
|
||||
tokens = ["<CLS>"] + tokens
|
||||
|
||||
# Encode tokens to IDs and pad to max_len
|
||||
ids = [stoi.get(token, unk_id) for token in tokens][:max_len]
|
||||
mask = [1] * len(ids)
|
||||
if len(ids) < max_len:
|
||||
pad_n = max_len - len(ids)
|
||||
ids += [pad_id] * pad_n
|
||||
mask += [0] * pad_n
|
||||
|
||||
with torch.no_grad():
|
||||
input_ids = torch.tensor([ids], dtype=torch.long, device=device)
|
||||
attention_mask = torch.tensor([mask], dtype=torch.bool, device=device)
|
||||
return float(model(input_ids, attention_mask).cpu().item())
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main evaluation pipeline.
|
||||
|
||||
Steps:
|
||||
1. Load generated routes and real routes
|
||||
2. Parse tokens and check validity
|
||||
3. Compute novelty (Jaccard distance from nearest real route)
|
||||
4. Compute geometric features
|
||||
5. Optionally use critic model for grade consistency
|
||||
6. Rank routes by composite score
|
||||
7. Save evaluation results
|
||||
"""
|
||||
args = parse_args()
|
||||
args.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 1: Load data
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
generated_path = args.generated_dir / "generated_routes.csv"
|
||||
routes_path = args.tokenized_dir / "route_sequences.csv"
|
||||
token_meta_path = args.tokenized_dir / "token_metadata.csv"
|
||||
|
||||
if not generated_path.exists():
|
||||
raise FileNotFoundError("Missing generated routes. Run scripts/03_train_route_generator.py first.")
|
||||
if not routes_path.exists() or not token_meta_path.exists():
|
||||
raise FileNotFoundError("Missing tokenized artifacts. Run scripts/01_tokenize_routes.py first.")
|
||||
|
||||
df_generated = pd.read_csv(generated_path)
|
||||
df_real = pd.read_csv(routes_path)
|
||||
df_token_meta = pd.read_csv(token_meta_path)
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 2: Parse tokens and check validity
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Validity checks ensure generated routes are structurally sound:
|
||||
# - basic_valid: ≥3 holds, no duplicates, has start+finish, one board
|
||||
# - strict_valid: basic_valid + has middle + ≥4 holds
|
||||
df_generated["tokens_parsed"] = df_generated["tokens"].apply(parse_token_list)
|
||||
df_generated["hold_records"] = df_generated["tokens_parsed"].apply(tokens_to_hold_records)
|
||||
df_generated["hold_set"] = df_generated["hold_records"].apply(
|
||||
lambda records: frozenset(int(record["placement_id"]) for record in records)
|
||||
)
|
||||
|
||||
validity = pd.DataFrame(df_generated["hold_records"].apply(validity_from_records).tolist())
|
||||
df_eval = pd.concat([df_generated.reset_index(drop=True), validity], axis=1)
|
||||
|
||||
print(f"Evaluated generated routes: {len(df_eval):,}")
|
||||
print("\nBasic validity by board:")
|
||||
print(df_eval.groupby("board_key")["basic_valid_eval"].mean())
|
||||
print("\nStrict validity by board:")
|
||||
print(df_eval.groupby("board_key")["strict_valid_eval"].mean())
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 3: Novelty (Jaccard distance from nearest real route)
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# For each generated route, find the most similar real route on the
|
||||
# same board using Jaccard similarity of hold sets.
|
||||
# Novelty distance = 1 - Jaccard similarity
|
||||
# A value of 1.0 means completely novel (no shared holds)
|
||||
# A value of 0.0 means identical to an existing route
|
||||
df_real["real_holds"] = df_real["frames"].apply(frames_to_holds)
|
||||
df_real["hold_set"] = df_real["real_holds"].apply(holds_to_placement_set)
|
||||
|
||||
nearest = pd.DataFrame(
|
||||
df_eval.apply(
|
||||
lambda row: nearest_real_route_same_board(
|
||||
generated_set=row["hold_set"],
|
||||
generated_board_key=row["board_key"],
|
||||
real_df=df_real,
|
||||
),
|
||||
axis=1,
|
||||
).tolist()
|
||||
)
|
||||
df_eval = pd.concat([df_eval, nearest], axis=1)
|
||||
|
||||
print("\nNovelty statistics:")
|
||||
print(df_eval[["board_key", "nearest_real_jaccard", "novelty_distance"]].describe())
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 4: Geometric features
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Compute simple spatial features for each generated route:
|
||||
# - Number of holds
|
||||
# - Height gained (max Y - min Y)
|
||||
# - Width span (max X - min X)
|
||||
# - Mean hand reach distance
|
||||
coords = build_placement_coords(df_token_meta)
|
||||
geom = pd.DataFrame(
|
||||
df_eval.apply(
|
||||
lambda row: simple_route_features(
|
||||
board_key=row["board_key"],
|
||||
records=row["hold_records"],
|
||||
placement_coords=coords,
|
||||
),
|
||||
axis=1,
|
||||
).tolist()
|
||||
)
|
||||
df_eval = pd.concat([df_eval, geom], axis=1)
|
||||
|
||||
print("\nGeometric feature statistics:")
|
||||
print(df_eval[["board_key", "geom_n_holds", "geom_height", "geom_width", "geom_mean_hand_reach"]].describe())
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 5: Grade consistency (using critic model)
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# If a trained grade predictor is available, use it as a "critic"
|
||||
# to check whether generated routes have grades consistent with
|
||||
# what was requested.
|
||||
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
critic = load_grade_critic(args.grade_model_path, device)
|
||||
if critic is not None:
|
||||
print("\nUsing grade critic for consistency scoring...")
|
||||
df_eval["critic_pred_display_difficulty"] = df_eval["tokens_parsed"].apply(
|
||||
lambda tokens: predict_generated_grade(tokens, critic, device)
|
||||
)
|
||||
df_eval["critic_pred_grouped_v"] = df_eval["critic_pred_display_difficulty"].apply(to_grouped_v)
|
||||
df_eval["critic_v_error"] = df_eval["critic_pred_grouped_v"] - df_eval["requested_grouped_v"]
|
||||
|
||||
print("\nCritic grade consistency by board:")
|
||||
summary = df_eval.groupby("board_key")["critic_v_error"].agg(
|
||||
exact=lambda s: float((s == 0).mean() * 100),
|
||||
within_1=lambda s: float((s.abs() <= 1).mean() * 100),
|
||||
within_2=lambda s: float((s.abs() <= 2).mean() * 100),
|
||||
)
|
||||
print(summary)
|
||||
else:
|
||||
print("No trained grade critic found. Skipping critic-based scoring.")
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 6: Rank routes by composite score
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# The composite score rewards:
|
||||
# - Basic validity (weight 2.0)
|
||||
# - Strict validity (weight 1.0)
|
||||
# - Novelty (weight 1.0)
|
||||
# - Grade consistency (weight 1.0 for ±1 V-grade, penalty for larger errors)
|
||||
ranked = df_eval.copy()
|
||||
ranked["score"] = 0.0
|
||||
ranked["score"] += ranked["basic_valid_eval"].astype(float) * 2.0
|
||||
ranked["score"] += ranked["strict_valid_eval"].astype(float) * 1.0
|
||||
ranked["score"] += ranked["novelty_distance"].fillna(0.0)
|
||||
|
||||
if "critic_v_error" in ranked.columns:
|
||||
ranked["score"] += (ranked["critic_v_error"].abs() <= 1).astype(float)
|
||||
ranked["score"] -= 0.25 * ranked["critic_v_error"].abs()
|
||||
|
||||
top_candidates = ranked.sort_values("score", ascending=False).head(100).reset_index(drop=True)
|
||||
|
||||
print(f"\nTop 10 generated routes by composite score:")
|
||||
display_cols = ["board_key", "score", "basic_valid_eval", "strict_valid_eval", "novelty_distance"]
|
||||
if "critic_v_error" in top_candidates.columns:
|
||||
display_cols.append("critic_v_error")
|
||||
print(top_candidates[display_cols].head(10))
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Step 7: Save results
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
df_eval.to_csv(args.out_dir / "generated_route_evaluation.csv", index=False)
|
||||
top_candidates.to_csv(args.out_dir / "top_generated_candidates.csv", index=False)
|
||||
|
||||
print(f"\nSaved evaluation results to:")
|
||||
print(f" {args.out_dir / 'generated_route_evaluation.csv'}")
|
||||
print(f" {args.out_dir / 'top_generated_candidates.csv'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user