Add changes to save new rag dadataframe
Browse files- rag_helper.ipynb +13 -3
rag_helper.ipynb
CHANGED
|
@@ -251,6 +251,9 @@
|
|
| 251 |
"outputs": [],
|
| 252 |
"source": [
|
| 253 |
"def run_evaluation(nba_df):\n",
|
|
|
|
|
|
|
|
|
|
| 254 |
" for index, row in nba_df.iterrows():\n",
|
| 255 |
" # Create message with sample query and run model\n",
|
| 256 |
" message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
|
|
@@ -259,9 +262,16 @@
|
|
| 259 |
"\n",
|
| 260 |
" # Obtain output\n",
|
| 261 |
" query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
| 262 |
-
"\n",
|
| 263 |
-
"
|
| 264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
]
|
| 266 |
},
|
| 267 |
{
|
|
|
|
| 251 |
"outputs": [],
|
| 252 |
"source": [
|
| 253 |
"def run_evaluation(nba_df):\n",
|
| 254 |
+
" team_flags = []\n",
|
| 255 |
+
" game_flags = []\n",
|
| 256 |
+
" other_stats_flags =[]\n",
|
| 257 |
" for index, row in nba_df.iterrows():\n",
|
| 258 |
" # Create message with sample query and run model\n",
|
| 259 |
" message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
|
|
|
|
| 262 |
"\n",
|
| 263 |
" # Obtain output\n",
|
| 264 |
" query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
| 265 |
+
" team_flags.append(\"team\" in query_output.lower())\n",
|
| 266 |
+
" game_flags.append(\"game\" in query_output.lower())\n",
|
| 267 |
+
" other_stats_flags.append(\"other_stats\" in query_output.lower())\n",
|
| 268 |
+
" #print(\"Query: \", + row[\"sql_query\"])\n",
|
| 269 |
+
" #print(\"Response: \",query_output)\n",
|
| 270 |
+
" \n",
|
| 271 |
+
" nba_df[\"team_flag\"] = team_flags\n",
|
| 272 |
+
" nba_df[\"game_flag\"] = game_flags\n",
|
| 273 |
+
" nba_df[\"other_stats_flag\"] = other_stats_flags\n",
|
| 274 |
+
" nba_df.to_csv(get_path(\"expanded_dta.tsv\"), sep=\"\\t\", index=False)\n"
|
| 275 |
]
|
| 276 |
},
|
| 277 |
{
|