This notebook creates captions from images using a LoRA adaptation of Google's Gemma 3 LLM. The LoRA is very basic: it was trained on only 400 images (Reddit posts and e621 NSFW posts) for 5 epochs.

Created by Adcom: https://tensor.art/u/754389913230900026

### Installation

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [None]:
from unsloth import FastVisionModel

# Repository id of the Gemma 3 captioning LoRA described at the top of this notebook.
MODEL_NAME = 'codeShare/flux_chroma_image_captioner'

model, processor = FastVisionModel.from_pretrained(
    model_name=MODEL_NAME,
    load_in_4bit=True,  # 4-bit quantized load; set to False to load the LoRA in 16-bit instead
)

# Put the model into unsloth's inference mode before any captions are generated.
FastVisionModel.for_inference(model)

In [3]:
from unsloth import get_chat_template

# Attach the "gemma-3" chat template to the processor; the captioning cells below call
# processor.apply_chat_template() to format their prompts with it.
processor = get_chat_template(processor, "gemma-3")

Running this cell displays a prompt to upload an image for captioning. If you upload several files, only the first one is captioned.

In [None]:
# Step 1: Import required libraries
from PIL import Image
import io
import torch
from google.colab import files  # For file upload in Colab
from transformers import TextStreamer

# Step 2: Assume model and processor are already loaded and configured
FastVisionModel.for_inference(model)  # Enable for inference!

# Step 3: Upload image from user.
# NOTE(review): Colab's files.upload() also appears to save each uploaded file into the
# working directory (normally /content/). If so, the batch cell below will caption that
# file again and then delete it from /content/ -- confirm.
print("Please upload an image file (e.g., .jpg, .png):")
uploaded = files.upload()  # Opens an upload widget; returns {file name: file bytes}

# Step 4: Load the uploaded image
if not uploaded:
    raise ValueError("No file uploaded. Please upload an image.")

# Only the first uploaded file is captioned here; say so instead of silently ignoring the rest.
file_name = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"{len(uploaded)} files uploaded; only {file_name} will be captioned. "
          "Use the batch captioning cell for multiple images.")
try:
    image = Image.open(io.BytesIO(uploaded[file_name])).convert('RGB')
except Exception as e:
    # Chain the original error so its traceback is kept.
    raise ValueError(f"Error loading image: {e}") from e

# Step 5: Define the instruction
instruction = "Describe this image."

# Step 6: Prepare messages for the model (a single user turn: the image, then the instruction)
messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": instruction}],
    }
]

# Step 7: Apply chat template and prepare inputs
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to(model.device)  # follow the model's device instead of hard-coding "cuda"

# Step 8: Generate output, streaming only the newly generated text (skip_prompt=True)
text_streamer = TextStreamer(processor, skip_prompt=True)
result = model.generate(
    **inputs,
    streamer=text_streamer,
    max_new_tokens=512,
    use_cache=True,
    temperature=1.0,
    top_p=0.95,
    top_k=64,
)

# Step 9: Keep the caption as a string for later use. `result` holds the prompt tokens
# followed by the generated ones, so slice the prompt off before decoding.
prompt_length = inputs["input_ids"].shape[-1]
caption = processor.decode(result[0][prompt_length:], skip_special_tokens=True).strip()

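<---- Upload a set of images to /content/ before running this cell. Alternatively, extract a .zip file and rename the folder containing the images to '/content/input'. Note: images placed directly in /content/ are deleted after they have been captioned and copied to /content/output/; images in /content/input/ are kept.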

In [None]:
# Step 1: Import required libraries
from PIL import Image
import torch
import os
import shutil
from pathlib import Path

# Step 2: Assume model and processor are already loaded and configured
FastVisionModel.for_inference(model)  # Enable for inference!

# Step 3: Define input and output directories.
# /content/ is scanned at its top level only. A recursive scan would also find
# /content/input/ (captioning those images twice), earlier results in /content/output/,
# and anything else below /content/ -- and the cleanup in Step 7 would then delete them.
# /content/input/ (e.g. an extracted .zip) is scanned recursively.
content_dir = Path('/content')
input_dir = content_dir / 'input'
output_dir = '/content/output/'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Step 4: Define supported image extensions
image_extensions = {'.jpg', '.webp', '.jpeg', '.png', '.bmp', '.gif'}


def _find_images(directory, recursive):
    """Return the sorted image files in `directory` (recursively if `recursive`); [] if it is missing."""
    if not directory.is_dir():
        print(f"Directory {directory}/ does not exist, skipping...")
        return []
    candidates = directory.rglob('*') if recursive else directory.iterdir()
    return sorted(p for p in candidates if p.is_file() and p.suffix.lower() in image_extensions)


# Step 5: Collect image files: top level of /content/, then everything under /content/input/
image_files = _find_images(content_dir, recursive=False) + _find_images(input_dir, recursive=True)

if not image_files:
    raise ValueError("No images found in /content/ or /content/input/")

# Step 6: Build the prompt once -- it does not depend on the image.
instruction = "Describe this image."
messages = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": instruction}],
    }
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)


def _caption_image(image):
    """Generate a caption for one RGB PIL image and return only the newly generated text.

    `model.generate` returns the prompt tokens followed by the new tokens, so the prompt
    part is sliced off before decoding. Otherwise the caption would start with the prompt text.
    """
    inputs = processor(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)  # follow the model's device instead of hard-coding "cuda"
    result = model.generate(
        **inputs,
        max_new_tokens=512,
        use_cache=True,
        temperature=1.0,
        top_p=0.95,
        top_k=64,
    )
    prompt_length = inputs["input_ids"].shape[-1]
    return processor.decode(result[0][prompt_length:], skip_special_tokens=True).strip()


# Step 7: Process each image
seen_stems = {}  # caption file stem -> first image that used it
for image_path in image_files:
    # Images in different sub-folders (or with different extensions) can share a stem;
    # their output files would overwrite each other, so warn about it.
    if image_path.stem in seen_stems:
        print(f"Warning: {image_path} shares its name with {seen_stems[image_path.stem]}; "
              f"its output will overwrite the earlier one.")
    seen_stems.setdefault(image_path.stem, image_path)

    try:
        # Load image; the context manager closes the file before the original may be deleted.
        with Image.open(image_path) as source:
            image = source.convert('RGB')

        print(f"\nProcessing {image_path.name}...")
        caption = _caption_image(image)

        # Print caption with extra whitespace for easy selection
        print(f"\n=== Caption for {image_path.name} ===\n\n{caption}\n\n====================\n")

        output_image_path = os.path.join(output_dir, image_path.name)
        output_caption_path = os.path.join(output_dir, f"{image_path.stem}.txt")

        # Copy the original file byte-for-byte. Re-saving the decoded RGB image would
        # re-compress JPEGs and drop alpha channels / extra animation frames.
        shutil.copy2(image_path, output_image_path)

        # Save caption to text file
        with open(output_caption_path, 'w', encoding='utf-8') as f:
            f.write(caption)

        print(f"Saved image and caption for {image_path.name}")

        # Delete the original only if it was uploaded directly into /content/ (not /content/input/).
        if image_path.parent == content_dir:
            try:
                os.remove(image_path)
                print(f"Deleted original image: {image_path}")
            except Exception as e:
                print(f"Error deleting {image_path}: {e}")

    except Exception as e:
        # Best-effort batch: report the failure and move on to the next image.
        print(f"Error processing {image_path.name}: {e}")

print(f"\nProcessing complete. Output saved to {output_dir}")

In [None]:
# @markdown 💾 Create .zip file of output to /content/
output_filename ='' #@param {type:'string'}
if output_filename.strip() == '':
    output_filename = 'chroma_prompts.zip'
#-----#
import shutil
from pathlib import Path

# Use the form value as the archive name. make_archive appends ".zip" itself, so accept
# the name with or without that suffix. Only the final path component is used, so the
# archive always lands in /content/.
archive_name = Path(output_filename.strip()).name
if archive_name.lower().endswith('.zip'):
    archive_name = archive_name[:-len('.zip')]
if not archive_name:
    archive_name = 'chroma_prompts'

# Absolute path so the cell does not depend on the current working directory.
source_dir = Path('/content/output')
if not source_dir.is_dir():
    raise FileNotFoundError(f"{source_dir} does not exist -- run the captioning cell first.")

archive_path = shutil.make_archive(str(Path('/content') / archive_name), 'zip', root_dir=str(source_dir))
print(f"Created {archive_path}")



In [None]:

# @markdown 🧹 Clear all images/.txt files/.zip files from /content/
import os
from pathlib import Path

# Directory to clean. Path.iterdir() is not recursive, so files inside sub-folders
# (e.g. /content/output/) are not touched.
directory_to_clean = '/content/'

# Image, caption (.txt) and archive (.zip) suffixes removed by this cell.
extensions_to_delete = {'.zip', '.webp', '.jpg', '.jpeg', '.png', '.bmp', '.gif', '.txt'}

targets = [entry for entry in Path(directory_to_clean).iterdir()
           if entry.suffix.lower() in extensions_to_delete]

for target in targets:
    try:
        os.remove(target)
    except Exception as err:
        # Report and keep going so one locked/odd entry does not stop the cleanup.
        print(f"Error deleting {target}: {err}")
    else:
        print(f"Deleted: {target}")

print(f"\nCleaning of {directory_to_clean} complete.")