{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/parquet_explorer.ipynb","timestamp":1754497857381},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754475181338},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754312448728},{"file_id":"https://huggingface.co/datasets/codeShare/chroma_prompts/blob/main/parquet_explorer.ipynb","timestamp":1754310418707},{"file_id":"https://huggingface.co/datasets/codeShare/lora-training-data/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1754223895158},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1747490904984},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1740037333374},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1736477078136},{"file_id":"https://huggingface.co/codeShare/JupyterNotebooks/blob/main/YT-playlist-to-mp3.ipynb","timestamp":1725365086834}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["Download a parquet file to your Google drive and load it from there into this notebook.\n","\n","Parquet files: https://huggingface.co/datasets/codeShare/chroma_prompts/tree/main\n","\n","E621 JSON files: https://huggingface.co/datasets/lodestones/e621-captions/tree/main"],"metadata":{"id":"LeCfcqgiQvCP"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"HFy5aDxM3G7O"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"KYv7Y2gNPW_i"},"outputs":[],"source":["#@markdown Investigate a json file\n","\n","import json\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Initialize lists to store data\n","data = []\n","\n","# Read the .jsonl file line by line\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            # Parse each line as a JSON object\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","\n","# Convert the list of JSON objects to a Pandas DataFrame for easier exploration\n","df = pd.DataFrame(data)\n","\n","# Display basic information about the DataFrame\n","print(\"=== File Overview ===\")\n","print(f\"Number of records: {len(df)}\")\n","print(\"\\nColumn names:\")\n","print(df.columns.tolist())\n","print(\"\\nData types:\")\n","print(df.dtypes)\n","\n","# Display the first few rows\n","print(\"\\n=== First 5 Rows ===\")\n","print(df.head())\n","\n","# Display basic statistics\n","print(\"\\n=== Basic Statistics ===\")\n","print(df.describe(include='all'))\n","\n","# Check for missing values\n","print(\"\\n=== Missing Values ===\")\n","print(df.isnull().sum())\n","\n","# Optional: Display unique values in each column\n","print(\"\\n=== Unique Values per Column ===\")\n","for col in df.columns:\n","    print(f\"{col}: {df[col].nunique()} unique 
values\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dnIWOPPqSTnw"},"outputs":[],"source":["#@markdown Investigate a json file pt 2\n","\n","import json\n","import pandas as pd\n","import matplotlib.pyplot as plt\n","import seaborn as sns\n","from collections import Counter\n","import numpy as np\n","\n","# Set up plotting style\n","plt.style.use('default')\n","%matplotlib inline\n","\n","# Path to the uploaded .jsonl file\n","#file_path = ''\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","df = pd.DataFrame(data)\n","\n","# 1. Rating Distribution\n","print(\"=== Rating Distribution ===\")\n","rating_counts = df['rating'].value_counts()\n","plt.figure(figsize=(8, 5))\n","sns.barplot(x=rating_counts.index, y=rating_counts.values)\n","plt.title('Distribution of Image Ratings')\n","plt.xlabel('Rating')\n","plt.ylabel('Count')\n","plt.show()\n","print(rating_counts)\n","\n","# 2. Tag Analysis\n","print(\"\\n=== Top 10 Most Common Tags ===\")\n","# Combine all tags into a single list\n","all_tags = []\n","for tags in df['tag_string'].dropna():\n","    all_tags.extend(tags.split())\n","tag_counts = Counter(all_tags)\n","top_tags = pd.DataFrame(tag_counts.most_common(10), columns=['Tag', 'Count'])\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x='Count', y='Tag', data=top_tags)\n","plt.title('Top 10 Most Common Tags')\n","plt.show()\n","print(top_tags)\n","\n","# 3. Image Dimensions Analysis\n","print(\"\\n=== Image Dimensions Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'], df['image_height'], alpha=0.5, s=50)\n","plt.title('Image Width vs. Height')\n","plt.xlabel('Width (pixels)')\n","plt.ylabel('Height (pixels)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Median Width: {df['image_width'].median()}\")\n","print(f\"Median Height: {df['image_height'].median()}\")\n","print(f\"Aspect Ratio (Width/Height) Stats:\\n{df['image_width'].div(df['image_height']).describe()}\")\n","\n","# 4. Score and Voting Analysis\n","print(\"\\n=== Score and Voting Analysis ===\")\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['score'], bins=30, kde=True)\n","plt.title('Distribution of Image Scores')\n","plt.xlabel('Score')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Score Stats:\\n{df['score'].describe()}\")\n","print(f\"\\nCorrelation between Up Score and Down Score: {df['up_score'].corr(df['down_score'])}\")\n","\n","# 5. Summary Length Analysis\n","print(\"\\n=== Summary Length Analysis ===\")\n","df['summary_length'] = df['regular_summary'].dropna().apply(lambda x: len(str(x).split()))\n","plt.figure(figsize=(10, 6))\n","sns.histplot(df['summary_length'], bins=30, kde=True)\n","plt.title('Distribution of Regular Summary Word Counts')\n","plt.xlabel('Word Count')\n","plt.ylabel('Count')\n","plt.show()\n","print(f\"Summary Length Stats:\\n{df['summary_length'].describe()}\")\n","\n","# 6. Missing Data Heatmap\n","print(\"\\n=== Missing Data Heatmap ===\")\n","plt.figure(figsize=(12, 8))\n","sns.heatmap(df.isnull(), cbar=False, cmap='viridis')\n","plt.title('Missing Data Heatmap')\n","plt.show()\n","\n","# 7. 
Source Platform Analysis\n","print(\"\\n=== Source Platform Analysis ===\")\n","# Extract domain from source URLs\n","df['source_domain'] = df['source'].dropna().str.extract(r'(https?://[^/]+)')\n","source_counts = df['source_domain'].value_counts().head(10)\n","plt.figure(figsize=(10, 6))\n","sns.barplot(x=source_counts.values, y=source_counts.index)\n","plt.title('Top 10 Source Domains')\n","plt.xlabel('Count')\n","plt.ylabel('Domain')\n","plt.show()\n","print(source_counts)\n","\n","# 8. File Size vs. Image Dimensions\n","print(\"\\n=== File Size vs. Image Dimensions ===\")\n","plt.figure(figsize=(10, 6))\n","plt.scatter(df['image_width'] * df['image_height'], df['file_size'], alpha=0.5)\n","plt.title('File Size vs. Image Area')\n","plt.xlabel('Image Area (Width * Height)')\n","plt.ylabel('File Size (bytes)')\n","plt.xscale('log')\n","plt.yscale('log')\n","plt.grid(True)\n","plt.show()\n","print(f\"Correlation between Image Area and File Size: {df['file_size'].corr(df['image_width'] * df['image_height'])}\")"]},{"cell_type":"code","source":["#@markdown  convert E621 JSON to parquet file\n","\n","import json,os\n","import pandas as pd\n","\n","# Path to the uploaded .jsonl file\n","file_path = '' #@param {type:'string'}\n","\n","# Read the .jsonl file into a DataFrame\n","data = []\n","with open(file_path, 'r') as file:\n","    for line in file:\n","        try:\n","            json_obj = json.loads(line.strip())\n","            data.append(json_obj)\n","        except json.JSONDecodeError as e:\n","            print(f\"Error decoding JSON line: {e}\")\n","            continue\n","df = pd.DataFrame(data)\n","\n","# Define columns that likely contain prompts/image descriptions\n","description_columns = [\n","    'regular_summary',\n","    'individual_parts',\n","    'midjourney_style_summary',\n","    'deviantart_commission_request',\n","    'brief_summary'\n","]\n","\n","# Initialize a list to store all descriptions\n","all_descriptions = []\n","\n","# Iterate through each row and collect non-empty descriptions\n","for index, row in df.iterrows():\n","    record_descriptions = []\n","    for col in description_columns:\n","        if pd.notnull(row[col]) and row[col]:  # Check for non-null and non-empty values\n","            record_descriptions.append(f\"{col}: {row[col]}\")\n","    if record_descriptions:\n","        all_descriptions.append({\n","            'id': row['id'],\n","            'descriptions': '; '.join(record_descriptions)  # Join descriptions with semicolon\n","        })\n","\n","# Convert to DataFrame for Parquet\n","output_df = pd.DataFrame(all_descriptions)\n","\n","# Save to Parquet file\n","output_path = '' #@param {type:'string'}\n","output_df.to_parquet(output_path, index=False)\n","os.remove(f'{file_path}')\n","print(f\"\\nDescriptions have been saved to '{output_path}'.\")"],"metadata":{"id":"-NXBRSv4jsUS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Step 1: Mount Google Drive\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","#@markdown paste .parquet file stored on your Google Drive folder to see its characteristics\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Basic exploration of the Parquet file\n","print(\"First 5 rows of the dataset:\")\n","print(df.head())\n","\n","print(\"\\nDataset 
Info:\")\n","print(df.info())\n","\n","print(\"\\nBasic Statistics:\")\n","print(df.describe())\n","\n","print(\"\\nColumn Names:\")\n","print(df.columns.tolist())\n","\n","print(\"\\nMissing Values:\")\n","print(df.isnull().sum())\n","\n","# Optional: Display number of rows and columns\n","print(f\"\\nShape of the dataset: {df.shape}\")"],"metadata":{"id":"So-PKtbo5AVA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Read contents of a .parquet file\n","\n","# Import pandas\n","import pandas as pd\n","\n","# Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","parquet_column = 'descriptions' #@param {type:'string'}\n","# Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Set pandas display options to show full text without truncation\n","pd.set_option('display.max_colwidth', None)  # Show full content of columns\n","pd.set_option('display.width', None)         # Use full display width\n","\n","# Create sliders for selecting the range of captions\n","#@markdown Caption Range { run: \"auto\", display_mode: \"form\" }\n","start_at = 16814 #@param {type:\"slider\", min:0, max:33147, step:1}\n","range = 247 #@param {type:'slider',min:1,max:1000,step:1}\n","start_index = start_at\n","end_index = start_at + range\n","###@param {type:\"slider\", min:1, max:33148, step:1}\n","\n","include_either_words = '' #@param {type:'string', placeholder:'item1,item2...'}\n","#display_only = True #@param {type:'boolean'}\n","\n","_include_either_words = ''\n","for include_word in include_either_words.split(','):\n","  if include_word.strip()=='':continue\n","  _include_either_words= include_either_words + include_word.lower()+','+include_word.title() +','\n","#-----#\n","_include_either_words = _include_either_words[:len(_include_either_words)-1]\n","\n","\n","# Ensure end_index is greater than start_index and within bounds\n","if end_index <= start_index:\n","    print(\"Error: End index must be greater than start index.\")\n","elif end_index > len(df):\n","    print(f\"Error: End index cannot exceed {len(df)}. Setting to maximum value.\")\n","    end_index = len(df)\n","elif start_index < 0:\n","    print(\"Error: Start index cannot be negative. 
Setting to 0.\")\n","    start_index = 0\n","\n","# Display the selected range of captions\n","tmp =''\n","\n","categories= ['regular_summary:',';midjourney_style_summary:', 'individual_parts:']\n","\n","print(f\"\\nDisplaying captions from index {start_index} to {end_index-1}:\")\n","for index, caption in df[f'{parquet_column}'][start_index:end_index].items():\n","  for include_word in _include_either_words.split(','):\n","    found = True\n","    if (include_word.strip() in caption) or include_word.strip()=='':\n","      #----#\n","      if not found: continue\n","      tmp= caption + '\\n\\n'\n","      for category in categories:\n","        tmp = tmp.replace(f'{category}',f'\\n\\n{category}\\n')\n","      #----#\n","      print(f'Index {index}: {tmp}')\n"],"metadata":{"id":"wDhyb8M_7pkD"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["\n","#@markdown Build a dataset for training using a .parquet file\n","\n","num_dataset_items = 200 #@param {type:'slider',max:1000}\n","\n","\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets pandas pillow requests\n","\n","# Step 2: Import required libraries\n","import pandas as pd\n","from datasets import Dataset\n","from PIL import Image\n","import requests\n","from io import BytesIO\n","import numpy as np\n","\n","# Step 3: Define the path to the Parquet file\n","file_path = '' #@param {type:'string'}\n","\n","# Step 4: Read the Parquet file\n","df = pd.read_parquet(file_path)\n","\n","# Step 5: Randomly select 300 rows to account for potential image loading failures\n","df_sample = df.sample(n=math.floor(num_dataset_items*1.2), random_state=42).reset_index(drop=True)\n","\n","# Step 6: Function to download, resize, and process images\n","def load_and_resize_image_from_url(url, max_size=(1024, 1024)):\n","    try:\n","        response = requests.get(url, timeout=10)\n","        response.raise_for_status()  # Raise an error for bad status codes\n","        img = Image.open(BytesIO(response.content)).convert('RGB')\n","        # Resize image to fit within 1024x1024 while maintaining aspect ratio\n","        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n","        return img\n","    except Exception as e:\n","        print(f\"Error loading image from {url}: {e}\")\n","        return None\n","\n","# Step 7: Create lists for images and captions\n","images = []\n","texts = []\n","\n","for index, row in df_sample.iterrows():\n","    if len(images) >= num_dataset_items:  # Stop once we have 200 valid images\n","        break\n","    url = row['url']\n","    caption = row['original_caption'] + ', ' + row['vlm_caption'].replace('This image displays:','').replace('This image displays','')\n","\n","    # Load and resize image\n","    img = load_and_resize_image_from_url(url)\n","    if img is not None:\n","        images.append(img)\n","        texts.append(caption)\n","    else:\n","        print(f\"Skipping row {index} due to image loading failure.\")\n","\n","# Step 8: Check if we have enough images\n","if len(images) < num_dataset_items:\n","    print(f\"Warning: Only {len(images)} images were successfully loaded.\")\n","else:\n","    # Truncate to exactly 200 if we have more\n","    images = images[:num_dataset_items]\n","    texts = texts[:num_dataset_items]\n","\n","# Step 9: Create a Hugging Face Dataset\n","dataset = Dataset.from_dict({\n","    'image': images,\n","    'text': texts\n","})\n","\n","# Step 10: Verify the dataset\n","print(dataset)\n","\n","# Step 11: Example of 
accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)\n","\n","# Optional: Save the dataset to disk (if needed)\n","#dataset.save_to_disk('/content/drive/MyDrive/Chroma prompts/custom_dataset')"],"metadata":{"id":"XZvpJ5zw0fzR"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["dataset_name=''#@param {type:'string'}\n","\n","if dataset_name.strip()=='':\n","  dataset_name='my_dataset'\n","\n","\n","dataset.save_to_disk(f'/content/drive/MyDrive/{dataset_name}')\n","\n","\n"],"metadata":{"id":"iTyxazlM1OAn"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"sQmoYDLHUXxF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"jFnWBQHa142R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown load two .parquet datasets for merging\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset1_path = '' #@param {type: 'string'}\n","\n","dataset2_path = '' #@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset1 = load_from_disk(dataset1_path)\n","    dataset2 = load_from_disk(dataset2_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the dataset\n","print(dataset1)\n","print(dataset2)\n","\n","# Step 7: Example of accessing an image and text\n","#print(\"\\nExample of accessing first item:\")\n","#print(\"Text:\", redcaps_dataset['text'][0])\n","#print(\"Image type:\", type(dataset['image'][0]))\n","#print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"LoCcBJqs4pzL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"AmLgPcrdRqCJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"X5HLZqjTRt7L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Merge the two .parquet files into one\n","\n","# Step 1: Import required libraries\n","from datasets import load_from_disk, concatenate_datasets\n","from google.colab import drive\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","drive.mount('/content/drive')\n","\n","# Step 3: Define paths for the datasets\n","dataset1_path = '' #@param {type:'string'}\n","dataset2_path = '' #@param {type:'string'}\n","merged_dataset_path = ''  #@param {type:'string'}\n","\n","# Step 4: Load the datasets\n","try:\n","    dataset1 = load_from_disk(dataset1_path)\n","    dataset2 = load_from_disk(dataset2_path)\n","    print(\"Datasets loaded 
successfully!\")\n","except Exception as e:\n","    print(f\"Error loading datasets: {e}\")\n","    raise\n","\n","# Step 5: Verify the datasets\n","print(\"Dataset 1:\", dataset1)\n","print(\"Dataset 2:\", dataset2)\n","\n","# Step 6: Merge the datasets\n","try:\n","    merged_dataset = concatenate_datasets([dataset1, dataset2])\n","    print(\"Datasets merged successfully!\")\n","except Exception as e:\n","    print(f\"Error merging datasets: {e}\")\n","    raise\n","\n","# Step 7: Verify the merged dataset\n","print(\"Merged Dataset:\", merged_dataset)\n","\n","# Step 8: Save the merged dataset to Google Drive\n","try:\n","    merged_dataset.save_to_disk(merged_dataset_path)\n","    print(f\"Merged dataset saved successfully to {merged_dataset_path}\")\n","except Exception as e:\n","    print(f\"Error saving merged dataset: {e}\")\n","    raise\n","\n","# Step 9: Optional - Verify the saved dataset by loading it back\n","try:\n","    loaded_merged_dataset = load_from_disk(merged_dataset_path)\n","    print(\"Saved merged dataset loaded successfully for verification:\")\n","    print(loaded_merged_dataset)\n","except Exception as e:\n","    print(f\"Error loading saved merged dataset: {e}\")\n","    raise"],"metadata":{"id":"HF_cmJu1EMJV"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["🔄 Change to T4 Runtime  : Past this point you can train a LoRa on the Dataset , but you need to change the runtime to T4 for that first\n","\n","See original file at:https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3_(4B)-Vision.ipynb"],"metadata":{"id":"0Kmf1OP6Se4Q"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"id":"ESLqweKz4xM_"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Test the merged dataset\n","\n","# Step 1: Install required libraries (if not already installed)\n","# !pip install datasets\n","\n","# Step 2: Mount Google Drive (only needed in Google Colab)\n","#from google.colab import drive\n","#drive.mount('/content/drive')\n","\n","# Step 3: Import required library\n","from datasets import load_from_disk\n","\n","# Step 4: Define the path to the saved dataset on Google Drive\n","dataset_path = ''#@param {type:'string'}\n","\n","# Step 5: Load the dataset\n","try:\n","    dataset = load_from_disk(dataset_path)\n","    print(\"Dataset loaded successfully!\")\n","except Exception as e:\n","    print(f\"Error loading dataset: {e}\")\n","    raise\n","\n","# Step 6: Verify the dataset\n","print(dataset)\n","\n","# Step 7: Example of accessing an image and text\n","print(\"\\nExample of accessing first item:\")\n","print(\"Text:\", dataset['text'][0])\n","print(\"Image type:\", type(dataset['image'][0]))\n","print(\"Image size:\", dataset['image'][0].size)"],"metadata":{"id":"xUA37h2APkWc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display an image from the dataset\n","index = 85 #@param {type:'slider',max:200}\n","dataset['image'][index]\n","\n","\n"],"metadata":{"id":"4hCnrtv6R9B1"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@markdown Display matching prompt text caption\n","dataset['text'][index]"],"metadata":{"id":"MSetS3MCR2qJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"K9CBpiISFa6C"},"source":["To format the dataset, all vision fine-tuning tasks should follow this format:\n","\n","```python\n","[\n","    {\n","        \"role\": \"user\",\n"," 
       \"content\": [\n","            {\"type\": \"text\", \"text\": instruction},\n","            {\"type\": \"image\", \"image\": sample[\"image\"]},\n","        ],\n","    },\n","    {\n","        \"role\": \"user\",\n","        \"content\": [\n","            {\"type\": \"text\", \"text\": instruction},\n","            {\"type\": \"image\", \"image\": sample[\"image\"]},\n","        ],\n","    },\n","]\n","```"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oPXzJZzHEgXe"},"outputs":[],"source":["#@markdown Convert the merged dataset to the 'correct' format for training the Gemma LoRa model\n","\n","instruction = \"Describe this image.\" # <- Select the prompt for your use case here\n","\n","def convert_to_conversation(sample):\n","    conversation = [\n","        {\n","            \"role\": \"user\",\n","            \"content\": [\n","                {\"type\": \"text\", \"text\": instruction},\n","                {\"type\": \"image\", \"image\": sample[\"image\"]},\n","            ],\n","        },\n","        {\"role\": \"assistant\", \"content\": [{\"type\": \"text\", \"text\": sample[\"text\"]}]},\n","    ]\n","    return {\"messages\": conversation}\n","pass"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gFW2qXIr7Ezy"},"outputs":[],"source":["converted_dataset = [convert_to_conversation(sample) for sample in dataset]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gGFzmplrEy9I"},"outputs":[],"source":["converted_dataset[0]"]},{"cell_type":"markdown","metadata":{"id":"529CsYil1qc6"},"source":["### Installation"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"9vJOucOw1qc6"},"outputs":[],"source":["%%capture\n","import os\n","if \"COLAB_\" not in \"\".join(os.environ.keys()):\n","    !pip install unsloth\n","else:\n","    # Do this only in Colab notebooks! Otherwise use pip install unsloth\n","    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo\n","    !pip install sentencepiece protobuf \"datasets>=3.4.1,<4.0.0\" \"huggingface_hub>=0.34.0\" hf_transfer\n","    !pip install --no-deps unsloth"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QmUBVEnvCDJv"},"outputs":[],"source":["from unsloth import FastVisionModel # FastLanguageModel for LLMs\n","import torch\n","\n","# 4bit pre quantized models we support for 4x faster downloading + no OOMs.\n","fourbit_models = [\n","    \"unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit\", # Llama 3.2 vision support\n","    \"unsloth/Llama-3.2-11B-Vision-bnb-4bit\",\n","    \"unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit\", # Can fit in a 80GB card!\n","    \"unsloth/Llama-3.2-90B-Vision-bnb-4bit\",\n","\n","    \"unsloth/Pixtral-12B-2409-bnb-4bit\",              # Pixtral fits in 16GB!\n","    \"unsloth/Pixtral-12B-Base-2409-bnb-4bit\",         # Pixtral base model\n","\n","    \"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit\",          # Qwen2 VL support\n","    \"unsloth/Qwen2-VL-7B-Instruct-bnb-4bit\",\n","    \"unsloth/Qwen2-VL-72B-Instruct-bnb-4bit\",\n","\n","    \"unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit\",      # Any Llava variant works!\n","    \"unsloth/llava-1.5-7b-hf-bnb-4bit\",\n","] # More models at https://huggingface.co/unsloth\n","\n","model, processor = FastVisionModel.from_pretrained(\n","    \"unsloth/gemma-3-4b-pt\",\n","    load_in_4bit = True, # Use 4bit to reduce memory use. 
False for 16bit LoRA.\n","    use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for long context\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bEzvL7Sm1CrS"},"outputs":[],"source":["from unsloth import get_chat_template\n","\n","processor = get_chat_template(\n","    processor,\n","    \"gemma-3\"\n",")"]},{"cell_type":"markdown","metadata":{"id":"SXd9bTZd1aaL"},"source":["We now add LoRA adapters for parameter efficient fine-tuning, allowing us to train only 1% of all model parameters efficiently.\n","\n","**[NEW]** We also support fine-tuning only the vision component, only the language component, or both. Additionally, you can choose to fine-tune the attention modules, the MLP layers, or both!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6bZsfBuZDeCL"},"outputs":[],"source":["model = FastVisionModel.get_peft_model(\n","    model,\n","    finetune_vision_layers     = True, # False if not finetuning vision layers\n","    finetune_language_layers   = True, # False if not finetuning language layers\n","    finetune_attention_modules = True, # False if not finetuning attention layers\n","    finetune_mlp_modules       = True, # False if not finetuning MLP layers\n","\n","    r = 16,                           # The larger, the higher the accuracy, but might overfit\n","    lora_alpha = 16,                  # Recommended alpha == r at least\n","    lora_dropout = 0,\n","    bias = \"none\",\n","    random_state = 3408,\n","    use_rslora = False,               # We support rank stabilized LoRA\n","    loftq_config = None,               # And LoftQ\n","    target_modules = \"all-linear\",    # Optional now! Can specify a list if needed\n","    modules_to_save=[\n","        \"lm_head\",\n","        \"embed_tokens\",\n","    ],\n",")"]},{"cell_type":"markdown","metadata":{"id":"FecKS-dA82f5"},"source":["Before fine-tuning, let us evaluate the base model's performance. We do not expect strong results, as it has not encountered this chat template before."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vcat4UxA81vr"},"outputs":[],"source":["FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","image = dataset[2][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                        use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"idAEIeSQ3xdS"},"source":["<a name=\"Train\"></a>\n","### Train the model\n","Now let's use Huggingface TRL's `SFTTrainer`! More docs here: [TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer). We do 60 steps to speed things up, but you can set `num_train_epochs=1` for a full run, and turn off `max_steps=None`. 
We also support TRL's `DPOTrainer`!\n","\n","We use our new `UnslothVisionDataCollator` which will help in our vision finetuning setup."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"95_Nn-89DhsL"},"outputs":[],"source":["from unsloth.trainer import UnslothVisionDataCollator\n","from trl import SFTTrainer, SFTConfig\n","\n","FastVisionModel.for_training(model) # Enable for training!\n","\n","trainer = SFTTrainer(\n","    model=model,\n","    train_dataset=converted_dataset,\n","    processing_class=processor.tokenizer,\n","    data_collator=UnslothVisionDataCollator(model, processor),\n","    args = SFTConfig(\n","        per_device_train_batch_size = 1,\n","        gradient_accumulation_steps = 4,\n","        gradient_checkpointing = True,\n","\n","        # use reentrant checkpointing\n","        gradient_checkpointing_kwargs = {\"use_reentrant\": False},\n","        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper\n","        warmup_ratio = 0.03,\n","        #max_steps = 30,\n","        num_train_epochs = 5,          # Set this instead of max_steps for full training runs\n","        learning_rate = 2e-4,\n","        logging_steps = 1,\n","        save_strategy=\"steps\",\n","        optim = \"adamw_torch_fused\",\n","        weight_decay = 0.01,\n","        lr_scheduler_type = \"cosine\",\n","        seed = 3407,\n","        output_dir = \"outputs\",\n","        report_to = \"none\",             # For Weights and Biases\n","\n","        # You MUST put the below items for vision finetuning:\n","        remove_unused_columns = False,\n","        dataset_text_field = \"\",\n","        dataset_kwargs = {\"skip_prepare_dataset\": True},\n","        max_length = 2048,\n","    )\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"2ejIt2xSNKKp"},"outputs":[],"source":["# @title Show current memory stats\n","gpu_stats = torch.cuda.get_device_properties(0)\n","start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n","print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n","print(f\"{start_gpu_memory} GB of memory reserved.\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yqxqAZ7KJ4oL"},"outputs":[],"source":["trainer_stats = trainer.train()\n"]},{"cell_type":"code","execution_count":null,"metadata":{"cellView":"form","id":"pCqnaKmlO1U9"},"outputs":[],"source":["# @title Show final memory and time stats\n","used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n","used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n","used_percentage = round(used_memory / max_memory * 100, 3)\n","lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n","print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n","print(\n","    f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\"\n",")\n","print(f\"Peak reserved memory = {used_memory} GB.\")\n","print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n","print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n","print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"]},{"cell_type":"markdown","metadata":{"id":"ekOmTR1hSNcr"},"source":["<a name=\"Inference\"></a>\n","### Inference\n","Let's run the model! 
You can modify the instruction and input—just leave the output blank.\n","\n","We'll use the best hyperparameters for inference on Gemma: `top_p=0.95`, `top_k=64`, and `temperature=1.0`."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kR3gIAX-SM2q"},"outputs":[],"source":["FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","image = dataset[10][\"image\"]\n","instruction = \"Describe this image.\"\n","\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                        use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"code","source":["# Step 1: Import required libraries\n","from PIL import Image\n","import io\n","import torch\n","from google.colab import files  # For file upload in Colab\n","\n","# Step 2: Assume model and processor are already loaded and configured\n","FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","# Step 3: Upload image from user\n","print(\"Please upload an image file (e.g., .jpg, .png):\")\n","uploaded = files.upload()  # Opens a file upload widget in Colab\n","\n","# Step 4: Load the uploaded image\n","if not uploaded:\n","    raise ValueError(\"No file uploaded. Please upload an image.\")\n","\n","# Get the first uploaded file\n","file_name = list(uploaded.keys())[0]\n","try:\n","    image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')\n","except Exception as e:\n","    raise ValueError(f\"Error loading image: {e}\")\n","\n","# Step 5: Define the instruction\n","instruction = \"Describe this image.\"\n","\n","# Step 6: Prepare messages for the model\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [{\"type\": \"image\"}, {\"type\": \"text\", \"text\": instruction}],\n","    }\n","]\n","\n","# Step 7: Apply chat template and prepare inputs\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","# Step 8: Generate output with text streaming\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor, skip_prompt=True)\n","result = model.generate(\n","    **inputs,\n","    streamer=text_streamer,\n","    max_new_tokens=512,\n","    use_cache=True,\n","    temperature=1.0,\n","    top_p=0.95,\n","    top_k=64\n",")"],"metadata":{"id":"oOyy5FUh8fBi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"uMuVrWbjAzhc"},"source":["<a name=\"Save\"></a>\n","### Saving, loading finetuned models\n","To save the final model as LoRA adapters, use Hugging Face’s `push_to_hub` for online saving, or `save_pretrained` for local storage.\n","\n","**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. 
To save to 16bit or GGUF, scroll down!"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"upcOlWe7A1vc"},"outputs":[],"source":["model.save_pretrained(\"lora_model\")  # Local saving\n","processor.save_pretrained(\"lora_model\")\n","# model.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving\n","# processor.push_to_hub(\"your_name/lora_model\", token = \"...\") # Online saving"]},{"cell_type":"markdown","metadata":{"id":"AEEcJ4qfC7Lp"},"source":["Now if you want to load the LoRA adapters we just saved for inference, set `False` to `True`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MKX_XKs_BNZR"},"outputs":[],"source":["if False:\n","    from unsloth import FastVisionModel\n","\n","    model, processor = FastVisionModel.from_pretrained(\n","        model_name=\"lora_model\",  # YOUR MODEL YOU USED FOR TRAINING\n","        load_in_4bit=True,  # Set to False for 16bit LoRA\n","    )\n","    FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","FastVisionModel.for_inference(model)  # Enable for inference!\n","\n","sample = dataset[1]\n","image = sample[\"image\"].convert(\"RGB\")\n","messages = [\n","    {\n","        \"role\": \"user\",\n","        \"content\": [\n","            {\n","                \"type\": \"text\",\n","                \"text\": sample[\"text\"],\n","            },\n","            {\n","                \"type\": \"image\",\n","            },\n","        ],\n","    },\n","]\n","input_text = processor.apply_chat_template(messages, add_generation_prompt=True)\n","inputs = processor(\n","    image,\n","    input_text,\n","    add_special_tokens=False,\n","    return_tensors=\"pt\",\n",").to(\"cuda\")\n","\n","from transformers import TextStreamer\n","\n","text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)\n","_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128,\n","                   use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)"]},{"cell_type":"markdown","metadata":{"id":"f422JgM9sdVT"},"source":["### Saving to float16 for VLLM\n","\n","We also support saving to `float16` directly. Select `merged_16bit` for float16. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens."]}]}