Add web demo polish and smoke-test pipeline
This commit is contained in:
@@ -49,6 +49,7 @@ from climbingboardgpt.evaluation import (
|
||||
tokens_to_hold_records,
|
||||
validity_from_records,
|
||||
)
|
||||
from climbingboardgpt.checkpoints import load_checkpoint
|
||||
from climbingboardgpt.grades import to_grouped_v
|
||||
from climbingboardgpt.models import JointRouteTransformerRegressor
|
||||
|
||||
@@ -86,10 +87,7 @@ def load_grade_critic(model_path: Path, device: torch.device):
|
||||
"""
|
||||
if not model_path.exists():
|
||||
return None
|
||||
try:
|
||||
checkpoint = torch.load(model_path, map_location=device, weights_only=False)
|
||||
except TypeError:
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
checkpoint = load_checkpoint(model_path, map_location=device, trusted=True)
|
||||
|
||||
cfg = checkpoint["config"]
|
||||
stoi = {str(k): int(v) for k, v in checkpoint["stoi"].items()}
|
||||
@@ -333,4 +331,4 @@ def main() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user