# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset

## First load dataset into pandas dataframe

In [34]:
import pandas as pd 
import warnings
warnings.filterwarnings("ignore")

# Load dataset and check length
df = pd.read_csv("./train-data/sql_train.tsv", sep='\t')
print("Total dataset examples: " + str(len(df)))
print("\n")

# Test sampling
sample = df.sample(n=1)
print(sample["natural_query"].values[0])
print(sample["sql_query"].values[0])
print(sample["result"].values[0])

Total dataset examples: 1044


What was the largest deficit overcome by the Miami Heat in any home victory?
SELECT o.largest_lead_away AS max_deficit_overcome FROM game g JOIN other_stats o ON g.game_id = o.game_id WHERE g.team_name_home = 'Miami Heat' AND g.wl_home = 'W' ORDER BY o.largest_lead_away DESC LIMIT 1;
46


## Load pre-trained DeepSeek model using transformers and pytorch packages

In [35]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Set device to cuda if available, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("./deepseek-coder-1.3b-instruct")
model = AutoModelForCausalLM.from_pretrained("./deepseek-coder-1.3b-instruct", torch_dtype=torch.bfloat16, device_map=device) 
model.generation_config.pad_token_id = tokenizer.pad_token_id

## Create prompt to setup the model for better performance

In [36]:
input_text = """You are an AI assistant that converts natural language queries into valid SQLite queries.
Database Schema and Explanations

team Table
Stores information about NBA teams.
CREATE TABLE IF NOT EXISTS "team" (
  "id" TEXT PRIMARY KEY,      -- Unique identifier for the team
  "full_name" TEXT,           -- Full official name of the team (e.g., "Los Angeles Lakers")
  "abbreviation" TEXT,        -- Shortened team name (e.g., "LAL")
  "nickname" TEXT,            -- Commonly used nickname for the team (e.g., "Lakers")
  "city" TEXT,                -- City where the team is based
  "state" TEXT,               -- State where the team is located
  "year_founded" REAL         -- Year the team was established
);

game Table
Contains detailed statistics for each NBA game, including home and away team performance.
CREATE TABLE IF NOT EXISTS "game" (
  "season_id" TEXT,            -- Season identifier, formatted as "2YYYY" (e.g., "21970" for the 1970 season)
  "team_id_home" TEXT,         -- ID of the home team (matches "id" in team table)
  "team_abbreviation_home" TEXT, -- Abbreviation of the home team
  "team_name_home" TEXT,       -- Full name of the home team
  "game_id" TEXT PRIMARY KEY,  -- Unique identifier for the game
  "game_date" TIMESTAMP,       -- Date the game was played (YYYY-MM-DD format)
  "matchup_home" TEXT,         -- Matchup details including opponent (e.g., "LAL vs. BOS")
  "wl_home" TEXT,              -- "W" if the home team won, "L" if they lost
  "min" INTEGER,               -- Total minutes played in the game
  "fgm_home" REAL,             -- Field goals made by the home team
  "fga_home" REAL,             -- Field goals attempted by the home team
  "fg_pct_home" REAL,          -- Field goal percentage of the home team
  "fg3m_home" REAL,            -- Three-point field goals made by the home team
  "fg3a_home" REAL,            -- Three-point attempts by the home team
  "fg3_pct_home" REAL,         -- Three-point field goal percentage of the home team
  "ftm_home" REAL,             -- Free throws made by the home team
  "fta_home" REAL,             -- Free throws attempted by the home team
  "ft_pct_home" REAL,          -- Free throw percentage of the home team
  "oreb_home" REAL,            -- Offensive rebounds by the home team
  "dreb_home" REAL,            -- Defensive rebounds by the home team
  "reb_home" REAL,             -- Total rebounds by the home team
  "ast_home" REAL,             -- Assists by the home team
  "stl_home" REAL,             -- Steals by the home team
  "blk_home" REAL,             -- Blocks by the home team
  "tov_home" REAL,             -- Turnovers by the home team
  "pf_home" REAL,              -- Personal fouls by the home team
  "pts_home" REAL,             -- Total points scored by the home team
  "plus_minus_home" INTEGER,   -- Plus/minus rating for the home team
  "video_available_home" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
  "team_id_away" TEXT,         -- ID of the away team
  "team_abbreviation_away" TEXT, -- Abbreviation of the away team
  "team_name_away" TEXT,       -- Full name of the away team
  "matchup_away" TEXT,         -- Matchup details from the away teamâ€™s perspective
  "wl_away" TEXT,              -- "W" if the away team won, "L" if they lost
  "fgm_away" REAL,             -- Field goals made by the away team
  "fga_away" REAL,             -- Field goals attempted by the away team
  "fg_pct_away" REAL,          -- Field goal percentage of the away team
  "fg3m_away" REAL,            -- Three-point field goals made by the away team
  "fg3a_away" REAL,            -- Three-point attempts by the away team
  "fg3_pct_away" REAL,         -- Three-point field goal percentage of the away team
  "ftm_away" REAL,             -- Free throws made by the away team
  "fta_away" REAL,             -- Free throws attempted by the away team
  "ft_pct_away" REAL,          -- Free throw percentage of the away team
  "oreb_away" REAL,            -- Offensive rebounds by the away team
  "dreb_away" REAL,            -- Defensive rebounds by the away team
  "reb_away" REAL,             -- Total rebounds by the away team
  "ast_away" REAL,             -- Assists by the away team
  "stl_away" REAL,             -- Steals by the away team
  "blk_away" REAL,             -- Blocks by the away team
  "tov_away" REAL,             -- Turnovers by the away team
  "pf_away" REAL,              -- Personal fouls by the away team
  "pts_away" REAL,             -- Total points scored by the away team
  "plus_minus_away" INTEGER,   -- Plus/minus rating for the away team
  "video_available_away" INTEGER, -- Indicates whether video is available (1 = Yes, 0 = No)
  "season_type" TEXT           -- Regular season or playoffs
);

other_stats Table
Stores additional statistics, linked to the game table via game_id.
CREATE TABLE IF NOT EXISTS "other_stats" (
  "game_id" TEXT,             -- Unique game identifier, matches id column from game table
  "league_id" TEXT,           -- League identifier
  "team_id_home" TEXT,        -- Home team identifier
  "team_abbreviation_home" TEXT, -- Home team abbreviation
  "team_city_home" TEXT,      -- Home team city
  "pts_paint_home" INTEGER,   -- Points in the paint by the home team
  "pts_2nd_chance_home" INTEGER, -- Second chance points by the home team
  "pts_fb_home" INTEGER,      -- Fast break points by the home team
  "largest_lead_home" INTEGER,-- Largest lead by the home team
  "lead_changes" INTEGER,     -- Number of lead changes 
  "times_tied" INTEGER,       -- Number of times the score was tied
  "team_turnovers_home" INTEGER, -- Home team turnovers
  "total_turnovers_home" INTEGER, -- Total turnovers by the home team
  "team_rebounds_home" INTEGER, -- Home team rebounds
  "pts_off_to_home" INTEGER,  -- Points off turnovers by the home team
  "team_id_away" TEXT,        -- Away team identifier
  "team_abbreviation_away" TEXT,  -- Away team abbreviation
  "pts_paint_away" INTEGER,   -- Points in the paint by the away team
  "pts_2nd_chance_away" INTEGER, -- Second chance points by the away team
  "pts_fb_away" INTEGER,      -- Fast break points by the away team
  "largest_lead_away" INTEGER,-- Largest lead by the away team
  "team_turnovers_away" INTEGER, -- Away team turnovers
  "total_turnovers_away" INTEGER, -- Total turnovers by the away team
  "team_rebounds_away" INTEGER, -- Away team rebounds
  "pts_off_to_away" INTEGER   -- Points off turnovers by the away team
);


Team Name Information
In the plaintext user questions, only the full team names will be used, but in the queries you may use the full team names or the abbreviations. 
The full team names can be used with the game table, while the abbreviations should be used with the other_stats table.
Notice they are separated by the | character in the following list:

Atlanta Hawks|ATL
Boston Celtics|BOS
Cleveland Cavaliers|CLE
New Orleans Pelicans|NOP
Chicago Bulls|CHI
Dallas Mavericks|DAL
Denver Nuggets|DEN
Golden State Warriors|GSW
Houston Rockets|HOU
Los Angeles Clippers|LAC
Los Angeles Lakers|LAL
Miami Heat|MIA
Milwaukee Bucks|MIL
Minnesota Timberwolves|MIN
Brooklyn Nets|BKN
New York Knicks|NYK
Orlando Magic|ORL
Indiana Pacers|IND
Philadelphia 76ers|PHI
Phoenix Suns|PHX
Portland Trail Blazers|POR
Sacramento Kings|SAC
San Antonio Spurs|SAS
Oklahoma City Thunder|OKC
Toronto Raptors|TOR
Utah Jazz|UTA
Memphis Grizzlies|MEM
Washington Wizards|WAS
Detroit Pistons|DET
Charlotte Hornets|CHA

Query Guidelines
Use team_name_home and team_name_away to match teams to the game table. Use team_abbreviation_home and team_abbreviation away to match teams to the other_stats table.

To filter by season, use season_id = '2YYYY'.

Example: To get statistics from 2005, use a statement like: season_id = '22005'. To get statistics from 1972, use a statement like: season_id = "21972". To get statistics from 2015, use a statement like: season_id = "22015".

The game_id column can be used to join the game and other_stats tables.

Ensure queries return relevant columns and avoid unnecessary joins.

When obtaining certain statistics by team from the game table, use the team_name_home and team_name_away columns. 
For example, to obtain home game data for the Washington Wizards from the game table use a statement like: team_name_home = 'Washington Wizards'
To obtain away game data from the Los Angeles Lakers from the game table use a statement like: team_name_away = 'Los Angeles Lakers'
To obtain general game data where home or away is not specified for the Chicago Bulls from the game table, use a statement like: (team_name_home = 'Chicago Bulls' OR team_name_away = 'Chicago Bulls')

When obtaining certain statistics by team from the other_stats table, use the team_abbreviation_home and team_abbreviation away columns.
For example, to obtain home statistics from the Charlotte Hornets from the other_stats table use a statement like: team_abbreviation_home = 'CHA'
To obtain away statistics from the Dallas Mavericks from the other_stats table, use a statement like: team_abbreviation_away = 'DAL'
To obtain general statistics from the other_stats table where home or away is not specified for the Detroit Pistons use a statement like: (team_abbreviation_home = 'DET' OR team_abbreviation_away = 'DET)


Example User Requests and SQLite Queries
Request:
"What is the most points the Los Angeles Lakers have ever scored at home?"
SQLite:
SELECT MAX(pts_home) 
FROM game 
WHERE team_name_home = 'Los Angeles Lakers';

Request:
"Which teams are located in the state of California?"
SQLite:
SELECT full_name FROM team WHERE state = 'California';

Request:
"How many total team rebounds did the Los Angeles Clippers have in away games where they scored over 15 fast break points?"
SQLite:
SELECT SUM(os.team_rebounds_away) 
FROM other_stats os  
JOIN game g ON os.game_id = g.game_id  
WHERE g.team_abbreviation_away = 'LAC' AND os.pts_fb_away > 15;

Request:
"How many points did the Miami Heat score on January 10, 2010?"
SQLite:
SELECT team_name_home, pts_home, team_name_away, pts_away 
FROM game 
WHERE DATE(game_date) = '2010-01-10' 
AND (team_name_home = 'Miami Heat' OR team_name_away = 'Miami Heat');

Request:
"Which team had the highest number of team turnovers in an away game?"
SQLite:
SELECT team_abbreviation_away FROM other_stats ORDER BY team_turnovers_away DESC LIMIT 1;

Request:
"Which team won the most home games in the 2000 season?"
SQLite:
SELECT team_name_home, COUNT(*) AS wins
FROM game
WHERE wl_home = 'W' AND season_id = '22000'
GROUP BY team_name_home
ORDER BY wins DESC
LIMIT 1;

Request:
"Which teams were founded before 1979?"
SQLite:
SELECT full_name FROM team WHERE year_founded < 1979;

Request:
"Which game had the most lead changes in the 2020 season?"
SQLite:
SELECT game_id, lead_changes  
FROM other_stats 
WHERE game_id IN 
(SELECT game_id FROM game WHERE season_id = '22020')
ORDER BY lead_changes DESC LIMIT 1;

Request:
"Find the Boston Celtics largest home victory margin in the 2008 season."
SQLite:
SELECT MAX(pts_home - pts_away) AS biggest_win
FROM game
WHERE team_name_home = 'Boston Celtics' AND season_id = '22008';

Request:
"How many fast break points did the Atlanta Hawks score at home?"
SQLite:
SELECT SUM(pts_fb_home) as total_fb_points  FROM other_stats  WHERE team_abbreviation_home = 'ATL';

Generate only the SQLite query prefaced by SQLite: and no other text, do not output an explanation of the query. Now generate an SQLite query for the following user request. Request:
"""

## Test model performance on a single example

In [37]:
# Create message with sample query and run model
message=[{ 'role': 'user', 'content': input_text + sample["natural_query"].values[0]}]
inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

# Print output
query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
print(query_output)

SQLite:
SELECT MAX(pts_home - pts_away) AS largest_deficit
FROM game
WHERE wl_home = 'W' AND team_name_home = 'Miami Heat';



# Test sample output on sqlite3 database

In [38]:
import sqlite3 as sql

# Create connection to sqlite3 database
connection = sql.connect('./nba-data/nba.sqlite')
cursor = connection.cursor()

# Execute query from model output and print result
if query_output[0:7] == "SQLite:":
    print("cleaned")
    query = query_output[7:]
elif query_output[0:4] == "SQL:":
    query = query_output[4:]
else:
    query = query_output

try:
    cursor.execute(query)
    rows = cursor.fetchall()
    for row in rows:
        print(row)
except:
    pass

cleaned
(43.0,)


## Create function to compare output to ground truth result from examples

In [39]:
import math

def compare_result(sample_query, sample_result, query_output):
    # Clean model output to only have the query output
    if query_output[0:7] == "SQLite:":
        query = query_output[7:]
    elif query_output[0:4] == "SQL:":
        query = query_output[4:]
    else:
        query = query_output
    
    # Try to execute query, if it fails, then this is a failure of the model
    try:
        # Execute query and obtain result
        cursor.execute(query)
        rows = cursor.fetchall()

        # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.
        query = query.replace(" ", "").replace("\n", "").replace("\t", "")
        sample_query = sample_query.replace(" ", "").replace("\n", "").replace("\t", "")
        query_match = (query == sample_query)

        # If the queries match, the results clearly also match
        if query_match:
            return True, True, True

        # Check if this is a multi-line query
        if "|" in sample_result or "(" in sample_result:
            #print(rows)
            # Create list of results by stripping separators and splitting on them
            if "(" in sample_result:
                sample_result = sample_result.replace("(", "").replace(")", "")
                result_list = sample_result.split(",") 
            else:
                result_list = sample_result.split("|") 

            # Strip all results in list
            for i in range(len(result_list)):
                result_list[i] = str(result_list[i]).strip()
            
            # Loop through model result and see if it matches training example
            result = False
            for row in rows:
                for r in row:
                    for res in result_list:
                        try:
                            if math.isclose(float(r), float(res), abs_tol=0.5):
                                return True, query_match, True
                        except:
                            if r in res or res in r:
                                return True, query_match, True
                    
            # Check if the model returned a sum of examples as opposed to the whole thing
            if len(rows) == 1:
                for r in rows[0]:
                    if r == str(len(result_list)):
                        return True, query_match, True
                    
            return True, query_match, result
        # Else the sample result is a single value or string
        else:
            #print(rows)
            result = False
            # Loop through model result and see if it contains the sample result
            for row in rows:
                for r in row:
                    # Check by string
                    if str(r) in str(sample_result):
                        try:
                            if math.isclose(float(r), float(sample_result), abs_tol=0.5):
                                return True, query_match, True
                        except:
                            return True, query_match, True
                    # Check by number, using try incase the cast as float fails
                    try:
                        if math.isclose(float(r), float(sample_result), abs_tol=0.5):
                            return True, query_match, True
                    except:
                        pass

            # Check if the model returned a list of examples instead of a total sum (both acceptable)
            try:
                if len(rows) > 1 and len(rows) == int(sample_result):
                    return True, query_match, True
                if len(rows[0]) > 1 and rows[0][1] is not None and  len(rows[0]) == int(sample_result):
                    return True, query_match, True
            except:
                pass

            # Compare results and return
            return True, query_match, result
    except:
        return False, False, False

# Obtain sample
sample = df.sample(n=1)
print(sample["natural_query"].values[0])
print(sample["sql_query"].values[0])
print(sample["result"].values[0])

# Create message with sample query and run model
message=[{ 'role': 'user', 'content': input_text + sample["natural_query"].values[0]}]
inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

# Print output
query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
print(query_output)

result = compare_result(sample["sql_query"].values[0], sample["result"].values[0], query_output)
print("Statement valid? " + str(result[0]))
print("SQLite matched? " + str(result[1]))
print("Result matched? " + str(result[2]))

How many home games did the Chicago Bulls play in the 2020 season?
SELECT COUNT(*) FROM game WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';
36.0
SQLite:
SELECT COUNT(*) as total_home_games 
FROM game 
WHERE team_name_home = 'Chicago Bulls' AND season_id = '22020';

Statement valid? True
SQLite matched? False
Result matched? True


## Create function to evaluate pretrained model on full datasets

In [40]:
def run_evaluation(nba_df, title):
    counter = 0
    num_valid = 0
    num_sql_matched = 0
    num_result_matched = 0
    for index, row in nba_df.iterrows():
        # Create message with sample query and run model
        message=[{ 'role': 'user', 'content': input_text + row["natural_query"]}]
        inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors="pt").to(model.device)
        outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)

        # Obtain output
        query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

        # Evaluate model result
        valid, sql_matched, result_matched = compare_result(row["sql_query"], row["result"], query_output)
        if valid:
            num_valid += 1
        if sql_matched:
            num_sql_matched += 1
        if result_matched:
            num_result_matched += 1

        # Break after predefined number of examples
        counter += 1
        if counter % 50 == 0:
            print("Completed " + str(counter))

    # Print evaluation results
    print("\n" + title + " results:")
    print("Percent valid: " + str(num_valid / len(nba_df)))
    print("Percent SQLite matched: " + str(num_sql_matched / len(nba_df)))
    print("Percent result matched: " + str(num_result_matched / len(nba_df)))

# Evaluate on less than 90 dataset

In [41]:
less_than_90_df = pd.read_csv("./train-data/less_than_90.tsv", sep='\t')
run_evaluation(less_than_90_df.sample(n=50), "Less than 90")
print("Dataset length: " + str(len(less_than_90_df)))

Completed 50

Less than 90 results:
Percent valid: 0.62
Percent SQLite matched: 0.12
Percent result matched: 0.4
Dataset length: 245


# Evaluate on game table queries

In [25]:
game_queries = pd.read_csv("./train-data/queries_from_game.tsv", sep='\t')
run_evaluation(game_queries, "Queries from game")
print("Dataset length: " + str(len(game_queries)))

Completed 50
Completed 100
Completed 150
Completed 200
Completed 250
Completed 300
Completed 350
Completed 400
Completed 450
Completed 500
Completed 550
Completed 600
Completed 650
Completed 700
Completed 750
Completed 800

Queries from game results:
Percent valid: 0.6181384248210023
Percent SQLite matched: 0.015513126491646777
Percent result matched: 0.24343675417661098
Dataset length: 838


## Evaluate on other stats queries

In [23]:
other_stats_queries = pd.read_csv("./train-data/queries_from_other_stats.tsv", sep='\t')
run_evaluation(other_stats_queries, "Queries from other stats")
print("Dataset length: " + str(len(other_stats_queries)))

Completed 50
Completed 100
Completed 150

Queries from other stats results:
Percent valid: 0.6168831168831169
Percent SQLite matched: 0.06493506493506493
Percent result matched: 0.34415584415584416
Dataset length: 154


## Evaluate on team queries

In [24]:
team_queries = pd.read_csv("./train-data/queries_from_team.tsv", sep='\t')
run_evaluation(team_queries, "Queries from team")
print("Dataset length: " + str(len(team_queries)))

Completed 50

Queries from team results:
Percent valid: 0.8846153846153846
Percent SQLite matched: 0.6346153846153846
Percent result matched: 0.8269230769230769
Dataset length: 52


## Evaluate on queries requiring join statements

In [12]:
join_queries = pd.read_csv("./train-data/with_join.tsv", sep='\t')
run_evaluation(join_queries, "Queries with join")
print("Dataset length: " + str(len(join_queries)))

Completed 50
Completed 100
Completed 150

Queries with join results:
Percent valid: 0.06486486486486487
Percent SQLite matched: 0.0
Percent result matched: 0.010810810810810811
Dataset length: 185


## Evaluate on queries not requiring join statements

In [13]:
no_join_queries = pd.read_csv("./train-data/without_join.tsv", sep='\t')
run_evaluation(no_join_queries, "Queries without join")
print("Dataset length: " + str(len(no_join_queries)))

Completed 50
Completed 100
Completed 150
Completed 200
Completed 250
Completed 300
Completed 350
Completed 400
Completed 450
Completed 500
Completed 550
Completed 600
Completed 650
Completed 700
Completed 750
Completed 800
Completed 850

Queries without join results:
Percent valid: 0.7974388824214202
Percent SQLite matched: 0.1559953434225844
Percent result matched: 0.4318975552968568
Dataset length: 859


## Evaluate on full training dataset

In [14]:
# Run evaluation on all training data
run_evaluation(df, "All training data")
print("Dataset length: " + str(len(df)))

Completed 50
Completed 100
Completed 150
Completed 200
Completed 250
Completed 300
Completed 350
Completed 400
Completed 450
Completed 500
Completed 550
Completed 600
Completed 650
Completed 700
Completed 750
Completed 800
Completed 850
Completed 900
Completed 950
Completed 1000

All training data results:
Percent valid: 0.6676245210727969
Percent SQLite matched: 0.12835249042145594
Percent result matched: 0.35823754789272033
Dataset length: 1044
