Updated evaluation function to give tolerance for slight floating point differences
Browse files- test_pretrained.ipynb +60 -17
- train-data/sql_train.tsv +1 -1
test_pretrained.ipynb
CHANGED
|
@@ -385,7 +385,7 @@
|
|
| 385 |
},
|
| 386 |
{
|
| 387 |
"cell_type": "code",
|
| 388 |
-
"execution_count":
|
| 389 |
"metadata": {},
|
| 390 |
"outputs": [
|
| 391 |
{
|
|
@@ -400,21 +400,24 @@
|
|
| 400 |
"name": "stdout",
|
| 401 |
"output_type": "stream",
|
| 402 |
"text": [
|
| 403 |
-
"
|
| 404 |
-
"SELECT
|
| 405 |
-
"
|
| 406 |
"SQLite:\n",
|
| 407 |
-
"SELECT
|
| 408 |
"FROM game \n",
|
| 409 |
-
"WHERE
|
| 410 |
"\n",
|
| 411 |
-
"[(
|
| 412 |
-
"
|
|
|
|
| 413 |
"Result matched? True\n"
|
| 414 |
]
|
| 415 |
}
|
| 416 |
],
|
| 417 |
"source": [
|
|
|
|
|
|
|
| 418 |
"def compare_result(sample_query, sample_result, query_output):\n",
|
| 419 |
" # Clean model output to only have the query output\n",
|
| 420 |
" if query_output[0:7] == \"SQLite:\":\n",
|
|
@@ -435,38 +438,77 @@
|
|
| 435 |
" sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
| 436 |
" query_match = (query == sample_query)\n",
|
| 437 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
" # Check if this is a multi-line query\n",
|
| 439 |
" if \"|\" in sample_result or \"(\" in sample_result:\n",
|
|
|
|
|
|
|
| 440 |
" if \"(\" in sample_result:\n",
|
| 441 |
" sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
|
| 442 |
" result_list = sample_result.split(\",\") \n",
|
| 443 |
" else:\n",
|
| 444 |
" result_list = sample_result.split(\"|\") \n",
|
| 445 |
"\n",
|
|
|
|
| 446 |
" for i in range(len(result_list)):\n",
|
| 447 |
" result_list[i] = str(result_list[i]).strip()\n",
|
|
|
|
|
|
|
| 448 |
" result = False\n",
|
| 449 |
" for row in rows:\n",
|
| 450 |
" for r in row:\n",
|
| 451 |
-
"
|
| 452 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
" if len(rows) == 1:\n",
|
| 454 |
" for r in rows[0]:\n",
|
| 455 |
" if r == str(len(result_list)):\n",
|
| 456 |
-
" return query_match, True\n",
|
| 457 |
-
"
|
|
|
|
|
|
|
| 458 |
" else:\n",
|
| 459 |
" print(rows)\n",
|
| 460 |
" result = False\n",
|
|
|
|
| 461 |
" for row in rows:\n",
|
| 462 |
" for r in row:\n",
|
|
|
|
| 463 |
" if str(r) in str(sample_result):\n",
|
| 464 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
"\n",
|
| 466 |
" # Compare results and return\n",
|
| 467 |
-
" return query_match, result\n",
|
| 468 |
" except:\n",
|
| 469 |
-
" return False, False\n",
|
| 470 |
"\n",
|
| 471 |
"# Obtain sample\n",
|
| 472 |
"sample = df.sample(n=1)\n",
|
|
@@ -484,8 +526,9 @@
|
|
| 484 |
"print(query_output)\n",
|
| 485 |
"\n",
|
| 486 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
| 487 |
-
"print(\"
|
| 488 |
-
"print(\"
|
|
|
|
| 489 |
]
|
| 490 |
}
|
| 491 |
],
|
|
|
|
| 385 |
},
|
| 386 |
{
|
| 387 |
"cell_type": "code",
|
| 388 |
+
"execution_count": 65,
|
| 389 |
"metadata": {},
|
| 390 |
"outputs": [
|
| 391 |
{
|
|
|
|
| 400 |
"name": "stdout",
|
| 401 |
"output_type": "stream",
|
| 402 |
"text": [
|
| 403 |
+
"What is the total number of assists by the Chicago Bulls at home?\n",
|
| 404 |
+
"SELECT SUM(ast_home) as total_assists FROM game WHERE team_name_home = 'Chicago Bulls';\n",
|
| 405 |
+
"45090.0\n",
|
| 406 |
"SQLite:\n",
|
| 407 |
+
"SELECT SUM(ast_home) \n",
|
| 408 |
"FROM game \n",
|
| 409 |
+
"WHERE team_name_home = 'Chicago Bulls';\n",
|
| 410 |
"\n",
|
| 411 |
+
"[(45090.0,)]\n",
|
| 412 |
+
"Statement valid? True\n",
|
| 413 |
+
"SQLite matched? False\n",
|
| 414 |
"Result matched? True\n"
|
| 415 |
]
|
| 416 |
}
|
| 417 |
],
|
| 418 |
"source": [
|
| 419 |
+
"import math\n",
|
| 420 |
+
"\n",
|
| 421 |
"def compare_result(sample_query, sample_result, query_output):\n",
|
| 422 |
" # Clean model output to only have the query output\n",
|
| 423 |
" if query_output[0:7] == \"SQLite:\":\n",
|
|
|
|
| 438 |
" sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
| 439 |
" query_match = (query == sample_query)\n",
|
| 440 |
"\n",
|
| 441 |
+
" # If the queries match, the results clearly also match\n",
|
| 442 |
+
" if query_match:\n",
|
| 443 |
+
" return True, True, True\n",
|
| 444 |
+
"\n",
|
| 445 |
" # Check if this is a multi-line query\n",
|
| 446 |
" if \"|\" in sample_result or \"(\" in sample_result:\n",
|
| 447 |
+
" print(rows)\n",
|
| 448 |
+
" # Create list of results by stripping separators and splitting on them\n",
|
| 449 |
" if \"(\" in sample_result:\n",
|
| 450 |
" sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
|
| 451 |
" result_list = sample_result.split(\",\") \n",
|
| 452 |
" else:\n",
|
| 453 |
" result_list = sample_result.split(\"|\") \n",
|
| 454 |
"\n",
|
| 455 |
+
" # Strip all results in list\n",
|
| 456 |
" for i in range(len(result_list)):\n",
|
| 457 |
" result_list[i] = str(result_list[i]).strip()\n",
|
| 458 |
+
" \n",
|
| 459 |
+
" # Loop through model result and see if it matches training example\n",
|
| 460 |
" result = False\n",
|
| 461 |
" for row in rows:\n",
|
| 462 |
" for r in row:\n",
|
| 463 |
+
" for res in result_list:\n",
|
| 464 |
+
" try:\n",
|
| 465 |
+
" if math.isclose(float(r), float(res), abs_tol=0.5):\n",
|
| 466 |
+
" return True, query_match, True\n",
|
| 467 |
+
" except:\n",
|
| 468 |
+
" if r in res or res in r:\n",
|
| 469 |
+
" return True, query_match, True\n",
|
| 470 |
+
" \n",
|
| 471 |
+
" # Check if the model returned a count of the examples instead of listing them all\n",
|
| 472 |
" if len(rows) == 1:\n",
|
| 473 |
" for r in rows[0]:\n",
|
| 474 |
" if r == str(len(result_list)):\n",
|
| 475 |
+
" return True, query_match, True\n",
|
| 476 |
+
" \n",
|
| 477 |
+
" return True, query_match, result\n",
|
| 478 |
+
" # Else the sample result is a single value or string\n",
|
| 479 |
" else:\n",
|
| 480 |
" print(rows)\n",
|
| 481 |
" result = False\n",
|
| 482 |
+
" # Loop through model result and see if it contains the sample result\n",
|
| 483 |
" for row in rows:\n",
|
| 484 |
" for r in row:\n",
|
| 485 |
+
" # Check by string\n",
|
| 486 |
" if str(r) in str(sample_result):\n",
|
| 487 |
+
" try:\n",
|
| 488 |
+
" if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
|
| 489 |
+
" return True, query_match, True\n",
|
| 490 |
+
" except:\n",
|
| 491 |
+
" return True, query_match, True\n",
|
| 492 |
+
" # Check by number, using try in case the cast to float fails\n",
|
| 493 |
+
" try:\n",
|
| 494 |
+
" if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
|
| 495 |
+
" return True, query_match, True\n",
|
| 496 |
+
" except:\n",
|
| 497 |
+
" pass\n",
|
| 498 |
+
"\n",
|
| 499 |
+
" # Check if the model returned a list of examples instead of a total count (both acceptable)\n",
|
| 500 |
+
" try:\n",
|
| 501 |
+
" if len(rows) > 1 and len(rows) == int(sample_result):\n",
|
| 502 |
+
" return True, query_match, True\n",
|
| 503 |
+
" if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):\n",
|
| 504 |
+
" return True, query_match, True\n",
|
| 505 |
+
" except:\n",
|
| 506 |
+
" pass\n",
|
| 507 |
"\n",
|
| 508 |
" # Compare results and return\n",
|
| 509 |
+
" return True, query_match, result\n",
|
| 510 |
" except:\n",
|
| 511 |
+
" return False, False, False\n",
|
| 512 |
"\n",
|
| 513 |
"# Obtain sample\n",
|
| 514 |
"sample = df.sample(n=1)\n",
|
|
|
|
| 526 |
"print(query_output)\n",
|
| 527 |
"\n",
|
| 528 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
| 529 |
+
"print(\"Statement valid? \" + str(result[0]))\n",
|
| 530 |
+
"print(\"SQLite matched? \" + str(result[1]))\n",
|
| 531 |
+
"print(\"Result matched? \" + str(result[2]))"
|
| 532 |
]
|
| 533 |
}
|
| 534 |
],
|
train-data/sql_train.tsv
CHANGED
|
@@ -476,7 +476,7 @@ How many away games did the Chicago Bulls play in the 2022 season? SELECT COUNT(
|
|
| 476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|
| 477 |
How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
|
| 478 |
What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
|
| 479 |
-
In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1;
|
| 480 |
What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
|
| 481 |
In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
|
| 482 |
How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0
|
|
|
|
| 476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|
| 477 |
How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
|
| 478 |
What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
|
| 479 |
+
In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 41969|36.6666666666667
|
| 480 |
What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
|
| 481 |
In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
|
| 482 |
How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0
|