Next version. Models + scripts updated. 2

This commit is contained in:
Pawel
2026-05-21 22:21:26 -04:00
parent 0002ef1545
commit 86d582a572
23 changed files with 1768 additions and 293 deletions

View File

@@ -80,7 +80,7 @@
" make_placement_lookup,\n",
" vocab_payload,\n",
")\n",
"from climbingboardgpt.utils import safe_train_test_split, write_json, json_safe"
"from climbingboardgpt.utils import assign_group_splits, write_json, json_safe"
]
},
{
@@ -314,30 +314,17 @@
"source": [
"df_routes[\"ids_with_grade\"] = df_routes[\"tokens_with_grade\"].apply(lambda tokens: encode(tokens, stoi))\n",
"df_routes[\"ids_no_grade\"] = df_routes[\"tokens_no_grade\"].apply(lambda tokens: encode(tokens, stoi))\n",
"\n",
"df_routes[\"split_stratum\"] = df_routes[\"board_key\"].astype(str) + \"__V\" + df_routes[\"grouped_v\"].astype(str)\n",
"\n",
"train_df, temp_df = safe_train_test_split(\n",
"df_routes[\"split\"] = assign_group_splits(\n",
" df_routes,\n",
" group_cols=[\"board_key\", \"uuid\"],\n",
" test_size=0.20,\n",
" random_state=42,\n",
" stratify_col=\"split_stratum\",\n",
")\n",
"val_df, test_df = safe_train_test_split(\n",
" temp_df,\n",
" test_size=0.50,\n",
" random_state=42,\n",
" val_size_within_temp=0.50,\n",
" random_state=3,\n",
" stratify_col=\"split_stratum\",\n",
")\n",
"\n",
"split_map = {}\n",
"split_map.update({uuid: \"train\" for uuid in train_df[\"uuid\"]})\n",
"split_map.update({uuid: \"val\" for uuid in val_df[\"uuid\"]})\n",
"split_map.update({uuid: \"test\" for uuid in test_df[\"uuid\"]})\n",
"df_routes[\"split\"] = df_routes[\"uuid\"].map(split_map)\n",
"\n",
"print(\"Split counts by board:\")\n",
"print(df_routes.groupby([\"board_key\", \"split\"]).size().unstack(fill_value=0))"
"df_routes.groupby([\"board_key\", \"split\"]).size().unstack(fill_value=0)"
]
},
{
@@ -457,8 +444,16 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.11"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,