{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "5f167a6f-5139-46e6-afb2-a1fa4d12f3fd", "metadata": {}, "outputs": [], "source": [ "import time\n", "import logging\n", "import re\n", "import random\n", "import gc\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import evaluate\n", "\n", "from datasets import Dataset, DatasetDict, load_from_disk\n", "from transformers import (\n", " AutoModelForSeq2SeqLM,\n", " AutoTokenizer,\n", " TrainingArguments,\n", " Trainer,\n", " GenerationConfig,\n", " BitsAndBytesConfig,\n", ")\n", "from transformers.trainer_callback import EarlyStoppingCallback\n", "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training" ] }, { "cell_type": "code", "execution_count": 2, "id": "53684b5e-c27e-4eb9-815e-583aa194e096", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "# Enable cudnn benchmark for fixed input sizes (can speed up computation)\n", "torch.backends.cudnn.benchmark = True\n", "\n", "# Set device to RTX 4090\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": 3, "id": "a47bf3cd-752d-4d1c-9697-70098d6204fa", "metadata": {}, "outputs": [], "source": [ "random.seed(42)\n", "np.random.seed(42)\n", "torch.manual_seed(42)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(42)" ] }, { "cell_type": "code", "execution_count": 4, "id": "f16df21e-9797-4f78-83a1-a2943759ba55", "metadata": {}, "outputs": [], "source": [ "def clear_memory():\n", " gc.collect()\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "code", "execution_count": 5, "id": "196e83da-6c8c-4cd7-bd70-2598a5e2a16a", "metadata": {}, "outputs": [], "source": [ "logging.basicConfig(\n", " level=logging.INFO,\n", " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n", ")\n", "logger = logging.getLogger(__name__)" ] }, { "cell_type": "code", "execution_count": 6, "id": "cea22b9f-f309-4151-81ac-37547c8feeb0", "metadata": {}, "outputs": [], "source": [ "def preprocess(text: str) -> str:\n", " \"\"\"Remove extra whitespaces and newlines from a text string.\"\"\"\n", " if not isinstance(text, str):\n", " return \"\"\n", " return re.sub(r'\\s+', ' ', text.replace('\\n', ' ')).strip()\n", "\n", "def clean_df(df, rename=None, drop=None, select=None):\n", " \"\"\"\n", " Clean and rename dataframe columns:\n", " - drop: list of columns to drop\n", " - rename: dict mapping old column names to new names\n", " - select: list of columns to keep in final order\n", " \"\"\"\n", " if drop:\n", " df = df.drop(columns=drop, errors='ignore')\n", " if rename:\n", " df = df.rename(columns=rename)\n", " for col in ['query', 'context', 'response']:\n", " if col in df.columns:\n", " df[col] = df[col].apply(preprocess)\n", " if select:\n", " df = df[select]\n", " return df" ] }, { "cell_type": "code", "execution_count": 7, "id": "d4eb82ce-1713-40b6-981d-43ce35aaa6f6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 14:56:53,295 - INFO - Loading raw datasets from various sources...\n", "2025-03-19 14:57:25,655 - INFO - Total rows before dropping duplicates: 490241\n", "2025-03-19 14:57:27,208 - INFO - Total rows after dropping duplicates: 440785\n" ] } ], "source": [ "logger.info(\"Loading raw datasets from various sources...\")\n", "\n", "# Load datasets\n", "df1 = pd.read_json(\"hf://datasets/Clinton/Text-to-sql-v1/texttosqlv2.jsonl\", 
lines=True)\n", "df2 = pd.read_json(\"hf://datasets/b-mc2/sql-create-context/sql_create_context_v4.json\")\n", "df3 = pd.read_parquet(\"hf://datasets/gretelai/synthetic_text_to_sql/synthetic_text_to_sql_train.snappy.parquet\")\n", "df4 = pd.read_json(\"hf://datasets/knowrohit07/know_sql/know_sql_val3{ign}.json\")\n", "\n", "# Clean and rename columns to unify to 'query', 'context', 'response'\n", "df1 = clean_df(df1, rename={'instruction': 'query', 'input': 'context'}, drop=['source', 'text'])\n", "df2 = clean_df(df2, rename={'question': 'query', 'answer': 'response'})\n", "df3 = clean_df(df3, rename={'sql_prompt': 'query', 'sql_context': 'context', 'sql': 'response'},\n", " select=['query', 'context', 'response'])\n", "df4 = clean_df(df4, rename={'question': 'query', 'answer': 'response'})\n", "\n", "# Concatenate all DataFrames\n", "final_df = pd.concat([df1, df2, df3, df4], ignore_index=True)\n", "logger.info(\"Total rows before dropping duplicates: %d\", len(final_df))\n", "\n", "# Force correct column order and drop rows with missing fields\n", "final_df = final_df[['query', 'context', 'response']]\n", "final_df = final_df.dropna(subset=['query', 'context', 'response'])\n", "final_df = final_df.drop_duplicates()\n", "logger.info(\"Total rows after dropping duplicates: %d\", len(final_df))" ] }, { "cell_type": "code", "execution_count": 8, "id": "8446814e-5a2c-48a4-8c01-059afcf1d3c1", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (1113 > 512). Running this sequence through the model will result in indexing errors\n", "2025-03-19 15:01:13,787 - INFO - Total rows after filtering by token length (prompt <= 500 and response <= 250 tokens): 398481\n" ] } ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", "\n", "max_length_prompt = 500\n", "max_length_response = 250\n", "\n", "def tokenize_length_filter(row):\n", " start_prompt = \"Context:\\n\"\n", " middle_prompt = \"\\n\\nQuery:\\n\"\n", " end_prompt = \"\\n\\nResponse:\\n\"\n", " \n", " # Construct the prompt as used in the tokenize_function\n", " prompt = f\"{start_prompt}{row['context']}{middle_prompt}{row['query']}{end_prompt}\"\n", " \n", " # Encode without truncation to get the full token count\n", " prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True, truncation=False)\n", " response_tokens = tokenizer.encode(row['response'], add_special_tokens=True, truncation=False)\n", " \n", " return len(prompt_tokens) <= max_length_prompt and len(response_tokens) <= max_length_response\n", "\n", "final_df = final_df[final_df.apply(tokenize_length_filter, axis=1)]\n", "logger.info(\"Total rows after filtering by token length (prompt <= %d and response <= %d tokens): %d\", \n", " max_length_prompt, max_length_response, len(final_df))\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "177e1e6d-9fbc-442d-9774-5a3e5234329f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 15:01:13,794 - INFO - Sample from filtered final_df:\n", " query \\\n", "0 Name the home team for carlton away team \n", "1 what will the population of Asia be when Latin... \n", "2 How many faculty members do we have for each g... \n", "\n", " context \\\n", "0 CREATE TABLE table_name_77 ( home_team VARCHAR... \n", "1 CREATE TABLE table_22767 ( \"Year\" real, \"World... \n", "2 CREATE TABLE Student ( StuID INTEGER, LName VA... 
\n", "\n", " response \n", "0 SELECT home_team FROM table_name_77 WHERE away... \n", "1 SELECT \"Asia\" FROM table_22767 WHERE \"Latin Am... \n", "2 SELECT Sex, COUNT(*) FROM Faculty GROUP BY Sex... \n" ] } ], "source": [ "logger.info(\"Sample from filtered final_df:\\n%s\", final_df.head(3))\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 10, "id": "0b639efe-ebeb-4b34-bc3f-accf776ba0da", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 15:01:14,006 - INFO - Final split sizes: Train: 338708, Test: 39848, Validation: 19925\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "81e753f720e44f40b5f0dfa5263e2bf5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Saving the dataset (0/1 shards): 0%| | 0/338708 [00:00 dict:\n", " \"\"\"\n", " Tokenizes a batch of examples for T5 fine-tuning.\n", " Constructs a prompt in the format:\n", " Context:\n", " \n", " \n", " Query:\n", " \n", " \n", " Response:\n", " \"\"\"\n", " start_prompt = \"Context:\\n\"\n", " middle_prompt = \"\\n\\nQuery:\\n\"\n", " end_prompt = \"\\n\\nResponse:\\n\"\n", "\n", " prompts = [\n", " f\"{start_prompt}{ctx}{middle_prompt}{qry}{end_prompt}\"\n", " for ctx, qry in zip(batch['context'], batch['query'])\n", " ]\n", "\n", " tokenized_inputs = tokenizer(\n", " prompts,\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=512\n", " )\n", " tokenized_labels = tokenizer(\n", " batch['response'],\n", " padding=\"max_length\",\n", " truncation=True,\n", " max_length=256\n", " )\n", " labels = [\n", " [-100 if token == tokenizer.pad_token_id else token for token in seq]\n", " for seq in tokenized_labels['input_ids']\n", " ]\n", "\n", " batch['input_ids'] = tokenized_inputs['input_ids']\n", " batch['attention_mask'] = tokenized_inputs['attention_mask']\n", " batch['labels'] = labels\n", " return batch\n", "\n", "try:\n", " tokenized_datasets = load_from_disk(\"tokenized_datasets\")\n", " logger.info(\"Loaded Tokenized Dataset from disk.\")\n", "except Exception as e:\n", " logger.info(\"Tokenized dataset not found. 
Creating a new one...\")\n", " tokenized_datasets = dataset.map(\n", " tokenize_function,\n", " batched=True,\n", " remove_columns=['query', 'context', 'response'],\n", " num_proc=8\n", " )\n", " tokenized_datasets.save_to_disk(\"tokenized_datasets\")\n", " logger.info(\"Tokenized and Saved Dataset.\")\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "\n", "logger.info(\"Final tokenized dataset splits: %s\", tokenized_datasets.keys())\n", "logger.info(\"Sample tokenized record from train split:\\n%s\", tokenized_datasets['train'][0])" ] }, { "cell_type": "code", "execution_count": 12, "id": "7f004e55-181c-47aa-9f3e-c7c1ceae780c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n", "\n", "Query:\n", "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "BASELINE HUMAN ANSWER:\n", "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n", "\n", "----------------------------------------------------------------------------------------------------\n", "MODEL GENERATION - ZERO SHOT:\n", "USCYBERCOM, JTF-CND, Offensive Cyber Operations, 10th Fleet, Network Warfare\n" ] } ], "source": [ "model_name = 'google/flan-t5-base'\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n", "original_model = original_model.to(device)\n", "\n", "index = 0\n", "query = dataset['test'][index]['query']\n", "context = dataset['test'][index]['context']\n", "response = dataset['test'][index]['response']\n", "\n", "prompt = f\"\"\"Context:\n", "{context}\n", "\n", "Query:\n", "{query}\n", "\n", "Response:\n", "\"\"\"\n", "inputs = tokenizer(prompt, return_tensors='pt').to(device)\n", "baseline_output = tokenizer.decode(\n", " original_model.generate(\n", " inputs[\"input_ids\"],\n", " max_new_tokens=200,\n", " )[0],\n", " skip_special_tokens=True\n", ")\n", "dash_line = '-' * 100\n", "print(dash_line)\n", "print(f'INPUT PROMPT:\\n{prompt}')\n", "print(dash_line)\n", "print(f'BASELINE HUMAN ANSWER:\\n{response}\\n')\n", "print(dash_line)\n", "print(f'MODEL GENERATION - ZERO SHOT:\\n{baseline_output}')\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 13, "id": "f50e56c7-98b3-42bc-9129-89f3eff802e7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 15:01:30,827 - INFO - Attempting to load the fine-tuned model...\n", "2025-03-19 15:01:32,195 - INFO - Fine-tuned model loaded successfully.\n" ] } ], "source": [ "import math\n", "\n", "try:\n", " logger.info(\"Attempting to load the fine-tuned model...\")\n", " finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\"text2sql_flant5base_finetuned\")\n", " tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", " finetuned_model = 
finetuned_model.to(device)\n", " to_train = False\n", " logger.info(\"Fine-tuned model loaded successfully.\")\n", "except Exception as e:\n", " logger.info(\"Fine-tuned model not found.\")\n", " logger.info(\"Initializing model and tokenizer for QLoRA fine-tuning...\")\n", " to_train = True\n", "\n", " quant_config = BitsAndBytesConfig(\n", " load_in_4bit=True,\n", " bnb_4bit_quant_type=\"nf4\",\n", " bnb_4bit_use_double_quant=True,\n", " bnb_4bit_compute_dtype=torch.bfloat16,\n", " )\n", "\n", " finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(\n", " model_name,\n", " quantization_config=quant_config,\n", " device_map=\"auto\",\n", " torch_dtype=torch.bfloat16,\n", " )\n", " finetuned_model = prepare_model_for_kbit_training(finetuned_model)\n", " \n", " lora_config = LoraConfig(\n", " r=32,\n", " lora_alpha=64,\n", " target_modules=[\"q\", \"v\"],\n", " lora_dropout=0.1,\n", " bias=\"none\",\n", " task_type=\"SEQ_2_SEQ_LM\"\n", " )\n", " finetuned_model = get_peft_model(finetuned_model, lora_config)\n", " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", " logger.info(\"Base model loaded and prepared for QLoRA fine-tuning.\")\n", " clear_memory()\n", "\n", "if to_train:\n", " output_dir = f\"./sql-training-{int(time.time())}\"\n", " logger.info(\"Starting training. Output directory: %s\", output_dir)\n", "\n", " # Compute total optimizer steps:\n", " num_train_samples = len(tokenized_datasets[\"train\"])\n", " per_device_train_batch_size = 64\n", " per_device_eval_batch_size = 64\n", " num_train_epochs = 6\n", " gradient_accumulation_steps = 2\n", " # One optimizer step per (per-device batch size * accumulation steps) examples\n", " total_steps = math.ceil(num_train_samples / (per_device_train_batch_size * gradient_accumulation_steps)) * num_train_epochs\n", " # Warmup is configured via warmup_ratio below; this count is logged for reference only\n", " warmup_steps = int(total_steps * 0.1)\n", " \n", " logger.info(\"Total training steps: %d, Warmup steps (10%%): %d\", total_steps, warmup_steps)\n", " \n", " training_args = TrainingArguments(\n", " output_dir=output_dir,\n", " gradient_checkpointing=True,\n", " gradient_checkpointing_kwargs={\"use_reentrant\": True},\n", " gradient_accumulation_steps=gradient_accumulation_steps,\n", " learning_rate=2e-4,\n", " optim=\"adamw_bnb_8bit\", # Memory-efficient optimizer\n", " num_train_epochs=num_train_epochs,\n", " per_device_train_batch_size=per_device_train_batch_size,\n", " per_device_eval_batch_size=per_device_eval_batch_size,\n", " weight_decay=0.01,\n", " logging_steps=200, \n", " logging_dir=f\"{output_dir}/logs\",\n", " eval_strategy=\"epoch\", # Evaluate at the end of each epoch\n", " save_strategy=\"epoch\", # Save the model at the end of each epoch\n", " save_total_limit=3,\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"eval_loss\",\n", " bf16=True, \n", " warmup_ratio=0.1, # Warmup 10% of total steps\n", " lr_scheduler_type=\"cosine\",\n", " )\n", " trainer = Trainer(\n", " model=finetuned_model,\n", " args=training_args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"validation\"],\n", " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],\n", " )\n", " logger.info(\"Beginning fine-tuning...\")\n", " trainer.train()\n", " logger.info(\"Training completed.\")\n", " save_path = \"text2sql_flant5base_finetuned\"\n", " finetuned_model.save_pretrained(save_path)\n", " logger.info(\"Model saved to %s\", save_path)\n", " clear_memory()" ] }, { "cell_type": "code", "execution_count": 14, "id": "f364eb6b-56cb-4533-8ef6-b5e7f56895aa", "metadata": {}, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ "2025-03-19 15:01:32,235 - INFO - Running inference on 5 examples (displaying real responses).\n", "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "====================================================================================================\n", "----------------------------------------------------------------------------------------------------\n", "Example 1\n", "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE SCHEMA IF NOT EXISTS defense_security;CREATE TABLE IF NOT EXISTS defense_security.Military_Cyber_Commands (id INT PRIMARY KEY, command_name VARCHAR(255), type VARCHAR(255));INSERT INTO defense_security.Military_Cyber_Commands (id, command_name, type) VALUES (1, 'USCYBERCOM', 'Defensive Cyber Operations'), (2, 'JTF-CND', 'Offensive Cyber Operations'), (3, '10th Fleet', 'Network Warfare');\n", "\n", "Query:\n", "Show the name and type of military cyber commands in the 'Military_Cyber_Commands' table.\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "HUMAN RESPONSE:\n", "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n", "----------------------------------------------------------------------------------------------------\n", "ORIGINAL MODEL OUTPUT:\n", "USCYBERCOM, JTF-CND, Offensive Cyber Operations\n", "----------------------------------------------------------------------------------------------------\n", "FINE-TUNED MODEL OUTPUT:\n", "SELECT command_name, type FROM defense_security.Military_Cyber_Commands;\n", "====================================================================================================\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Example 2\n", "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE TABLE incidents (id INT, cause VARCHAR(255), cost INT, date DATE); INSERT INTO incidents (id, cause, cost, date) VALUES (1, 'insider threat', 10000, '2022-01-01'); INSERT INTO incidents (id, cause, cost, date) VALUES (2, 'phishing', 5000, '2022-01-02');\n", "\n", "Query:\n", "Find the total cost of all security incidents caused by insider threats in the last 6 months\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "HUMAN RESPONSE:\n", "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n", "----------------------------------------------------------------------------------------------------\n", "ORIGINAL MODEL OUTPUT:\n", "10000, 2022-01-01\n", "----------------------------------------------------------------------------------------------------\n", "FINE-TUNED MODEL OUTPUT:\n", "SELECT SUM(cost) FROM incidents WHERE cause = 'insider threat' AND date >= DATE_SUB(CURRENT_DATE, INTERVAL 6 MONTH);\n", 
"====================================================================================================\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Example 3\n", "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE TABLE libraries (name VARCHAR(255), state VARCHAR(255), population DECIMAL(10,2), libraries DECIMAL(5,2)); INSERT INTO libraries (name, state, population, libraries) VALUES ('Library1', 'California', 39512223, 3154), ('Library2', 'Texas', 29528404, 2212), ('Library3', 'Florida', 21644287, 1835);\n", "\n", "Query:\n", "Show the top 3 states with the most public libraries per capita.\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "HUMAN RESPONSE:\n", "SELECT state, (libraries / population) AS libraries_per_capita FROM libraries ORDER BY libraries_per_capita DESC LIMIT 3;\n", "----------------------------------------------------------------------------------------------------\n", "ORIGINAL MODEL OUTPUT:\n", "California, 39512223, 3154\n", "----------------------------------------------------------------------------------------------------\n", "FINE-TUNED MODEL OUTPUT:\n", "SELECT state, population, RANK() OVER (ORDER BY population DESC) as rank FROM libraries GROUP BY state ORDER BY rank DESC LIMIT 3;\n", "====================================================================================================\n", "\n", "----------------------------------------------------------------------------------------------------\n", "Example 4\n", "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE TABLE users (id INT, location VARCHAR(50)); CREATE TABLE posts (id INT, user_id INT, created_at DATETIME);\n", "\n", "Query:\n", "What is the total number of posts made by users located in Australia, in the last month?\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "HUMAN RESPONSE:\n", "SELECT COUNT(posts.id) FROM posts INNER JOIN users ON posts.user_id = users.id WHERE users.location = 'Australia' AND posts.created_at >= DATE_SUB(NOW(), INTERVAL 1 MONTH);\n", "----------------------------------------------------------------------------------------------------\n", "ORIGINAL MODEL OUTPUT:\n", "The total number of posts made by users located in Australia is 50.\n", "----------------------------------------------------------------------------------------------------\n", "FINE-TUNED MODEL OUTPUT:\n", "SELECT COUNT(*) FROM posts p JOIN users u ON p.user_id = u.id WHERE u.location = 'Australia' AND p.created_at >= DATE_SUB(CURRENT_DATE, INTERVAL 1 MONTH);\n", "====================================================================================================\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 15:01:40,448 - INFO - Starting evaluation on the full test set using batching.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------------------------------------------\n", "Example 5\n", "----------------------------------------------------------------------------------------------------\n", "INPUT PROMPT:\n", "Context:\n", "CREATE TABLE WindFarms (FarmID INT, 
FarmName VARCHAR(255), Capacity DECIMAL(5,2), Country VARCHAR(255)); INSERT INTO WindFarms (FarmID, FarmName, Capacity, Country) VALUES (1, 'WindFarm1', 150, 'USA'), (2, 'WindFarm2', 200, 'Canada'), (3, 'WindFarm3', 120, 'Mexico');\n", "\n", "Query:\n", "List the total installed capacity of wind farms in the WindEnergy schema for each country?\n", "\n", "Response:\n", "\n", "----------------------------------------------------------------------------------------------------\n", "HUMAN RESPONSE:\n", "SELECT Country, SUM(Capacity) as TotalCapacity FROM WindFarms GROUP BY Country;\n", "----------------------------------------------------------------------------------------------------\n", "ORIGINAL MODEL OUTPUT:\n", "1, 150, USA, 2, 200, Canada, 3, 120, Mexico\n", "----------------------------------------------------------------------------------------------------\n", "FINE-TUNED MODEL OUTPUT:\n", "SELECT Country, SUM(Capacity) FROM WindFarms GROUP BY Country;\n", "====================================================================================================\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a7beecee09a34f9790be1e4538a87442", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading builder script: 0%| | 0.00/5.94k [00:00 str:\n", " \"\"\"Post-process the generated output to remove repeated text.\"\"\"\n", " # Keep only the first valid SQL query (everything before the first semicolon)\n", " return output_text.split(\";\")[0] + \";\" if \";\" in output_text else output_text\n", "\n", "# Define a helper function for generating outputs with the given generation parameters.\n", "def generate_with_params(model, input_ids):\n", " generated_ids = model.generate(\n", " input_ids=input_ids,\n", " max_new_tokens=100, \n", " num_beams=5,\n", " repetition_penalty=1.2,\n", " temperature=0.1, # No effect with beam search (do_sample=False); transformers warns about this\n", " early_stopping=True\n", " )\n", " # Decode the generated tokens (post-processing is applied by the caller)\n", " output_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", " return output_text\n", "\n", "# Helper functions for SQL normalization and evaluation metrics\n", "def normalize_sql(sql):\n", " \"\"\"Normalize SQL by stripping whitespace and lowercasing.\"\"\"\n", " return \" \".join(sql.strip().lower().split())\n", "\n", "def compute_exact_match(predictions, references):\n", " \"\"\"Computes the exact match accuracy after normalization.\"\"\"\n", " matches = sum(1 for pred, ref in zip(predictions, references)\n", " if normalize_sql(pred) == normalize_sql(ref))\n", " return (matches / len(predictions)) * 100 if predictions else 0\n", "\n", "def compute_fuzzy_match(predictions, references):\n", " \"\"\"Computes a soft matching score using token_set_ratio from rapidfuzz.\"\"\"\n", " scores = [fuzz.token_set_ratio(pred, ref) for pred, ref in zip(predictions, references)]\n", " return sum(scores) / len(scores) if scores else 0\n", "\n", "# Free accelerator memory between batches (same behavior as the clear_memory helper defined earlier).\n", "def clear_memory():\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", "\n", "logger = logging.getLogger(__name__)\n", "logger.setLevel(logging.INFO)\n", "\n", "# --- Part A: Inference on 5 Examples with Real Responses ---\n", "logger.info(\"Running inference on 5 examples (displaying real responses).\")\n", "\n", "num_examples = 5\n", "sample_queries = dataset[\"test\"][:num_examples][\"query\"]\n", "sample_contexts = dataset[\"test\"][:num_examples][\"context\"]\n", "sample_human_responses = 
dataset[\"test\"][:num_examples][\"response\"]\n", "\n", "print(\"\\n\" + \"=\" * 100)\n", "for idx in range(num_examples):\n", " prompt = f\"\"\"Context:\n", "{sample_contexts[idx]}\n", "\n", "Query:\n", "{sample_queries[idx]}\n", "\n", "Response:\n", "\"\"\"\n", " # Tokenize the prompt and move to device\n", " inputs = tokenizer(prompt, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n", " \n", " # Generate outputs using the modified generation parameters\n", " orig_out = generate_with_params(original_model, inputs[\"input_ids\"])\n", " finetuned_out = post_process_output(generate_with_params(finetuned_model, inputs[\"input_ids\"]))\n", " \n", " print(\"-\" * 100)\n", " print(f\"Example {idx+1}\")\n", " print(\"-\" * 100)\n", " print(\"INPUT PROMPT:\")\n", " print(prompt)\n", " print(\"-\" * 100)\n", " print(\"HUMAN RESPONSE:\")\n", " print(sample_human_responses[idx])\n", " print(\"-\" * 100)\n", " print(\"ORIGINAL MODEL OUTPUT:\")\n", " print(orig_out)\n", " print(\"-\" * 100)\n", " print(\"FINE-TUNED MODEL OUTPUT:\")\n", " print(finetuned_out)\n", " print(\"=\" * 100 + \"\\n\")\n", " clear_memory()\n", "\n", "# --- Part B: Evaluation on Full Test Set with Batching (Optimized) ---\n", "logger.info(\"Starting evaluation on the full test set using batching.\")\n", "\n", "all_human_responses = []\n", "all_original_responses = []\n", "all_finetuned_responses = []\n", "\n", "batch_size = 128 # Adjust based on GPU memory\n", "test_dataset = dataset[\"test\"]\n", "\n", "for i in range(0, len(test_dataset), batch_size):\n", " # Slicing the dataset returns a dict of lists\n", " batch = test_dataset[i:i + batch_size]\n", " \n", " # Construct prompts for each example in the batch\n", " prompts = [\n", " f\"Context:\\n{batch['context'][j]}\\n\\nQuery:\\n{batch['query'][j]}\\n\\nResponse:\"\n", " for j in range(len(batch[\"context\"]))\n", " ]\n", " \n", " # Extend human responses\n", " all_human_responses.extend(batch[\"response\"])\n", " \n", " # Tokenize the batch of prompts with padding and truncation\n", " inputs = tokenizer(prompts, return_tensors=\"pt\", padding=True, truncation=True, max_length=512).to(device)\n", " \n", " # Generate outputs for the batch for both models\n", " orig_ids = original_model.generate(\n", " input_ids=inputs[\"input_ids\"],\n", " max_new_tokens=100,\n", " num_beams=5,\n", " repetition_penalty=1.2,\n", " temperature=0.1,\n", " early_stopping=True\n", " )\n", " finetuned_ids = finetuned_model.generate(\n", " input_ids=inputs[\"input_ids\"],\n", " max_new_tokens=100,\n", " num_beams=5,\n", " repetition_penalty=1.2,\n", " temperature=0.1,\n", " early_stopping=True\n", " )\n", " \n", " # Decode and post-process each sample in the batch\n", " orig_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in orig_ids]\n", " finetuned_texts = [post_process_output(tokenizer.decode(ids, skip_special_tokens=True)) for ids in finetuned_ids]\n", " \n", " all_original_responses.extend(orig_texts)\n", " all_finetuned_responses.extend(finetuned_texts)\n", " clear_memory()\n", "\n", "# Create a DataFrame for a quick comparison of results\n", "zipped_all = list(zip(all_human_responses, all_original_responses, all_finetuned_responses))\n", "df_full = pd.DataFrame(zipped_all, columns=[\"Human Response\", \"Original Model Output\", \"Fine-Tuned Model Output\"])\n", "df_full.to_csv('evaluation_results.csv', index=False)\n", "clear_memory()\n", "\n", "# --- Compute Evaluation Metrics ---\n", "rouge = evaluate.load(\"rouge\")\n", "bleu = 
evaluate.load(\"bleu\")\n", "\n", "# Compute metrics for the original (non-fine-tuned) model\n", "orig_rouge = rouge.compute(\n", " predictions=all_original_responses,\n", " references=all_human_responses,\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")\n", "orig_bleu = bleu.compute(\n", " predictions=all_original_responses,\n", " references=[[ref] for ref in all_human_responses]\n", ")\n", "orig_fuzzy = compute_fuzzy_match(all_original_responses, all_human_responses)\n", "orig_exact = compute_exact_match(all_original_responses, all_human_responses)\n", "\n", "# Compute metrics for the fine-tuned model\n", "finetuned_rouge = rouge.compute(\n", " predictions=all_finetuned_responses,\n", " references=all_human_responses,\n", " use_aggregator=True,\n", " use_stemmer=True,\n", ")\n", "finetuned_bleu = bleu.compute(\n", " predictions=all_finetuned_responses,\n", " references=[[ref] for ref in all_human_responses]\n", ")\n", "finetuned_fuzzy = compute_fuzzy_match(all_finetuned_responses, all_human_responses)\n", "finetuned_exact = compute_exact_match(all_finetuned_responses, all_human_responses)\n", "\n", "print(\"\\n\" + \"=\" * 100)\n", "print(\"Evaluation Metrics:\")\n", "print(\"=\" * 100)\n", "print(\"ORIGINAL MODEL:\")\n", "print(f\" ROUGE: {orig_rouge}\")\n", "print(f\" BLEU: {orig_bleu}\")\n", "print(f\" Fuzzy Match Score: {orig_fuzzy:.2f}%\")\n", "print(f\" Exact Match Accuracy: {orig_exact:.2f}%\\n\")\n", "print(\"FINE-TUNED MODEL:\")\n", "print(f\" ROUGE: {finetuned_rouge}\")\n", "print(f\" BLEU: {finetuned_bleu}\")\n", "print(f\" Fuzzy Match Score: {finetuned_fuzzy:.2f}%\")\n", "print(f\" Exact Match Accuracy: {finetuned_exact:.2f}%\")\n", "print(\"=\" * 100)\n", "clear_memory()" ] }, { "cell_type": "code", "execution_count": 15, "id": "462546a7-6928-4723-b00e-23c3a4091d99", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2025-03-19 16:51:05,225 - INFO - Running inference with deterministic decoding and beam search.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Prompt:\n", "Context:\n", "CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), (3, 'Charlie', 'Canada'), (4, 'David', 'USA'); INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES (101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), (103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), (105, 4, 900, '2024-03-05');\n", "\n", "Query:\n", "Retrieve the total order amount for each customer, showing only customers from the USA, and sort the result by total order amount in descending order.\n", "\n", "Response:\n", "SELECT customer_id, SUM(total_amount) as total_amount FROM orders JOIN customers ON orders.customer_id = customers.id WHERE customers.country = 'USA' GROUP BY customer_id ORDER BY total_amount DESC;\n" ] } ], "source": [ "import torch\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "import logging\n", "\n", "# Set up logging\n", "logging.basicConfig(\n", " level=logging.INFO,\n", " format=\"%(asctime)s - %(levelname)s - %(message)s\",\n", ")\n", "logger = logging.getLogger(__name__)\n", "\n", "# Ensure device is set (GPU if available)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# 
Load the fine-tuned model and tokenizer\n", "model_name = \"text2sql_flant5base_finetuned\" \n", "finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)\n", "tokenizer = AutoTokenizer.from_pretrained(\"google/flan-t5-base\")\n", "finetuned_model.to(device)\n", "\n", "def run_inference(prompt_text: str) -> str:\n", " \"\"\"\n", " Runs inference on the fine-tuned model using deterministic decoding\n", " with beam search, returning the generated SQL query.\n", " \"\"\"\n", " inputs = tokenizer(prompt_text, return_tensors=\"pt\").to(device)\n", " generated_ids = finetuned_model.generate(\n", " input_ids=inputs[\"input_ids\"],\n", " max_new_tokens=100, # Adjust based on query complexity\n", " temperature=0.1, # Not used here: sampling is disabled, so beam search is already deterministic\n", " num_beams=5, # Beam search for better output quality\n", " early_stopping=True, # Stop early if possible\n", " )\n", " generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "\n", " # Post-processing to remove repeated text\n", " generated_sql = generated_sql.split(\";\")[0] + \";\" # Keep only the first valid SQL query\n", "\n", " return generated_sql\n", "\n", "# Sample context and query (example)\n", "context = (\n", " \"CREATE TABLE customers (id INT PRIMARY KEY, name VARCHAR(100), country VARCHAR(50)); \"\n", " \"CREATE TABLE orders (order_id INT PRIMARY KEY, customer_id INT, total_amount DECIMAL(10,2), \"\n", " \"order_date DATE, FOREIGN KEY (customer_id) REFERENCES customers(id)); \"\n", " \"INSERT INTO customers (id, name, country) VALUES (1, 'Alice', 'USA'), (2, 'Bob', 'UK'), \"\n", " \"(3, 'Charlie', 'Canada'), (4, 'David', 'USA'); \"\n", " \"INSERT INTO orders (order_id, customer_id, total_amount, order_date) VALUES \"\n", " \"(101, 1, 500, '2024-01-15'), (102, 2, 300, '2024-01-20'), \"\n", " \"(103, 1, 700, '2024-02-10'), (104, 3, 450, '2024-02-15'), \"\n", " \"(105, 4, 900, '2024-03-05');\"\n", ")\n", "query = (\n", " \"Retrieve the total order amount for each customer, showing only customers from the USA, \"\n", " \"and sort the result by total order amount in descending order.\"\n", ")\n", "\n", "# Construct the prompt\n", "sample_prompt = f\"\"\"Context:\n", "{context}\n", "\n", "Query:\n", "{query}\n", "\n", "Response:\n", "\"\"\"\n", "\n", "logger.info(\"Running inference with deterministic decoding and beam search.\")\n", "generated_sql = run_inference(sample_prompt)\n", "\n", "# Print output in the given format\n", "print(\"Prompt:\")\n", "print(\"Context:\")\n", "print(context)\n", "print(\"\\nQuery:\")\n", "print(query)\n", "print(\"\\nResponse:\")\n", "print(generated_sql)\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "a69f268e-bc69-4633-9c15-4e118c20178e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ LoRA adapter saved at: text2sql_flant5base_finetuned\n", "✅ Fully merged fine-tuned model saved at: text2sql_flant5base_finetuned_full\n" ] } ], "source": [ "import torch\n", "import json\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "from peft import PeftModel\n", "\n", "# Define paths\n", "base_model_name = \"google/flan-t5-base\" # Base model name\n", "lora_model_path = \"text2sql_flant5base_finetuned\" # Folder where LoRA adapter is saved\n", "full_model_output_path = \"text2sql_flant5base_finetuned_full\" # For merged full model\n", "\n", "# Load base model and tokenizer\n", "base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)\n", "tokenizer = 
AutoTokenizer.from_pretrained(base_model_name)\n", "\n", "# Load fine-tuned LoRA adapter model\n", "lora_model = PeftModel.from_pretrained(base_model, lora_model_path)\n", "\n", "# ✅ Save the LoRA adapter separately (for users who want lightweight adapters)\n", "lora_model.save_pretrained(lora_model_path)\n", "tokenizer.save_pretrained(lora_model_path)\n", "\n", "# ✅ Merge LoRA into the base model to create a fully fine-tuned model\n", "merged_model = lora_model.merge_and_unload()\n", "\n", "# ✅ Save the full fine-tuned model\n", "merged_model.save_pretrained(full_model_output_path)\n", "tokenizer.save_pretrained(full_model_output_path)\n", "\n", "# ✅ Save generation config (optional but recommended for inference settings)\n", "generation_config = {\n", " \"max_new_tokens\": 100,\n", " \"temperature\": 0.1,\n", " \"num_beams\": 5,\n", " \"early_stopping\": True\n", "}\n", "with open(f\"{full_model_output_path}/generation_config.json\", \"w\") as f:\n", " json.dump(generation_config, f)\n", "\n", "print(f\"✅ LoRA adapter saved at: {lora_model_path}\")\n", "print(f\"✅ Fully merged fine-tuned model saved at: {full_model_output_path}\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "f1c95dfc-6662-44d8-8ecc-bff414fecee5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n", " warnings.warn(\n", "/venv/main/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:629: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.1` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", "2025-03-19 16:51:49,933 - INFO - Running inference with beam search decoding.\n" ] } ], "source": [ "import torch\n", "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n", "import logging\n", "\n", "# Set up logging\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", "logger = logging.getLogger(__name__)\n", "\n", "# Ensure device is set (GPU if available)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Load the fine-tuned model and tokenizer\n", "model_name = \"aarohanverma/text2sql-flan-t5-base-qlora-finetuned\"\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)\n", "tokenizer = AutoTokenizer.from_pretrained(\"aarohanverma/text2sql-flan-t5-base-qlora-finetuned\")\n", "\n", "# Ensure decoder start token is set\n", "if model.config.decoder_start_token_id is None:\n", " model.config.decoder_start_token_id = tokenizer.pad_token_id\n", "\n", "def run_inference(prompt_text: str) -> str:\n", " \"\"\"\n", " Runs inference on the fine-tuned model using beam search with fixes for repetition.\n", " \"\"\"\n", " inputs = tokenizer(prompt_text, return_tensors=\"pt\", truncation=True, max_length=512).to(device)\n", "\n", " generated_ids = model.generate(\n", " input_ids=inputs[\"input_ids\"],\n", " decoder_start_token_id=model.config.decoder_start_token_id, \n", " max_new_tokens=100, \n", " temperature=0.1, \n", " num_beams=5, \n", " repetition_penalty=1.2, \n", " early_stopping=True, \n", " )\n", "\n", " generated_sql = tokenizer.decode(generated_ids[0], skip_special_tokens=True)\n", "\n", " # Post-processing to remove repeated text\n", " generated_sql = generated_sql.split(\";\")[0] + \";\" # Keep only the first valid SQL query\n", "\n", " return generated_sql\n", "\n", "# Example usage:\n", "context = (\n", " \"CREATE TABLE employees (id INT PRIMARY KEY, name VARCHAR(100), department VARCHAR(50), salary INT); \"\n", " \"CREATE TABLE projects (project_id INT PRIMARY KEY, project_name VARCHAR(100), budget INT); \"\n", " \"CREATE TABLE employee_projects (employee_id INT, project_id INT, role VARCHAR(50), \"\n", " \"FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (project_id) REFERENCES projects(project_id)); \"\n", " \"INSERT INTO employees (id, name, department, salary) VALUES \"\n", " \"(1, 'Alice', 'Engineering', 90000), (2, 'Bob', 'Marketing', 70000), \"\n", " \"(3, 'Charlie', 'Engineering', 95000), (4, 'David', 'HR', 60000), (5, 'Eve', 'Engineering', 110000); \"\n", " \"INSERT INTO projects (project_id, project_name, budget) VALUES \"\n", " \"(101, 'AI Research', 500000), (102, 'Marketing Campaign', 200000), (103, 'Cloud Migration', 300000); \"\n", " \"INSERT INTO employee_projects (employee_id, project_id, role) VALUES \"\n", " \"(1, 101, 'Lead Engineer'), (2, 102, 'Marketing Specialist'), (3, 101, 'Engineer'), \"\n", " \"(4, 103, 'HR Coordinator'), (5, 101, 'AI Scientist');\"\n", ")\n", "\n", "query = (\"Find the names of employees who are working on the 'AI Research' project along with their roles.\")\n", "\n", "\n", "\n", "# Construct the prompt\n", "sample_prompt = f\"\"\"Context:\n", "{context}\n", "\n", "Query:\n", "{query}\n", "\n", "Response:\n", "\"\"\"\n", "\n", "logger.info(\"Running inference with beam search decoding.\")\n", "generated_sql = run_inference(sample_prompt)\n", "\n", "print(\"Prompt:\")\n", "print(\"Context:\")\n", 
"print(context)\n", "print(\"\\nQuery:\")\n", "print(query)\n", "print(\"\\nResponse:\")\n", "print(generated_sql)" ] }, { "cell_type": "code", "execution_count": null, "id": "97425ac4-ad46-4f38-b22d-071e161da20a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }