diff --git "a/how_to_finetune_paligemma_on_detection_dataset.ipynb" "b/how_to_finetune_paligemma_on_detection_dataset.ipynb" new file mode 100644--- /dev/null +++ "b/how_to_finetune_paligemma_on_detection_dataset.ipynb" @@ -0,0 +1,1754 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)\n", + "\n", + "# Fine-tune PaliGemma on Object Detection Dataset\n", + "\n", + "---\n", + "\n", + "[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)\n", + "[![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://blog.roboflow.com/how-to-fine-tune-paligemma/)\n", + "[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=OMBmVInx68M)\n", + "\n", + "PaliGemma is an open vision-language model (VLM) inspired by PaLI-3, built with\n", + "open components, such as\n", + "the [SigLIP vision model](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/image_text/SigLIP_demo.ipynb)\n", + "and\n", + "the [Gemma language model](https://ai.google.dev/gemma).\n", + "PaliGemma is designed as a versatile model for transfer to a wide range of\n", + "vision-language tasks such as image and short video caption, visual question\n", + "answering, text reading, object detection and object segmentation. 
Together with\n", + "the pretrained and transfer checkpoints at multiple resolutions, we provide a\n", + "checkpoint transferred to a mixture of tasks that can be used for off-the-shelf\n", + "exploration.\n", + "\n", + "This notebook is an extension of the [official notebook](https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb) prepared by Google Research.\n", + "\n", + "![PaliGemma model](https://storage.cloud.google.com/com-roboflow-marketing/notebooks/examples/paligemma.png)\n", + "\n", + "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n", + "\n", + "To make it runnable on a T4 colab runtime with 16GB HBM and 12GB RAM, we opt to only finetune the attention layers of the language model and freeze the other parameters." + ], + "metadata": { + "id": "4LqvmtZPzyY1" + } + }, + { + "cell_type": "markdown", + "source": [ + "## Setup" + ], + "metadata": { + "id": "lBp3Czz3GBmc" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Get access to PaliGemma\n", + "\n", + "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n", + "\n", + "1. Log in to [`Kaggle`](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n", + "1. Go to the [`PaliGemma Model Card`](https://www.kaggle.com/models/google/paligemma/) and click `Request Access`.\n", + "1. Complete the consent form and accept the terms and conditions." 
+ ], + "metadata": { + "id": "4ohXT9pQFjZs" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Configure your API keys\n", + "\n", + "To use PaliGemma, you need to provide your Kaggle username, Kaggle API key, and Roboflow API key. Follow these steps:\n", + "\n", + "- Open your [`Kaggle Settings`](https://www.kaggle.com/settings) page. Click `Create New Token`. This will download a `kaggle.json` file containing your API credentials.\n", + "- Go to your [`Roboflow Settings`](https://app.roboflow.com/settings/api) page. Click `Copy`. This will place your private key in the clipboard.\n", + "- In Colab, go to the left pane and click on `Secrets` (🔑).\n", + " - Store Kaggle Username under the name `KAGGLE_USERNAME`.\n", + " - Store Kaggle API Key under the name `KAGGLE_KEY`.\n", + " - Store Roboflow API Key under the name `ROBOFLOW_API_KEY`." + ], + "metadata": { + "id": "ADTkh-2y_9Yv" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Select the runtime\n", + "\n", + "Let's make sure that we have access to GPU. We can use `nvidia-smi` command to do that. In case of any problems navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `T4 GPU`, and then click `Save`." + ], + "metadata": { + "id": "4wyojKiG_hX9" + } + }, + { + "cell_type": "code", + "source": [ + "!nvidia-smi" + ], + "metadata": { + "id": "O_8BLW6R_x-z", + "outputId": "f2c36369-fd26-4b58-c93f-4ef2aebc1764", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Thu May 30 08:04:49 2024 \n", + "+---------------------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n", + "|-----------------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. 
ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|=========================================+======================+======================|\n", + "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", + "| N/A 67C P8 10W / 70W | 0MiB / 15360MiB | 0% Default |\n", + "| | | N/A |\n", + "+-----------------------------------------+----------------------+----------------------+\n", + " \n", + "+---------------------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=======================================================================================|\n", + "| No running processes found |\n", + "+---------------------------------------------------------------------------------------+\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Download dataset from Roboflow Universe\n", + "\n", + "To fine-tune PaliGemma, prepare your dataset in JSONL format. You can use Roboflow to easily convert any dataset into this format." 
+ ], + "metadata": { + "id": "FMlw3ru1YvLg" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -q roboflow\n", + "!pip install -q git+https://github.com/roboflow/supervision.git" + ], + "metadata": { + "id": "Wtvz4QZ9YuG8", + "outputId": "7ca5c1c8-4aa1-442c-c942-5cdc15fc277a", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.5/75.5 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m158.3/158.3 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m178.7/178.7 kB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.8/58.8 kB\u001b[0m \u001b[31m4.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.1/49.1 MB\u001b[0m \u001b[31m30.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.5/54.5 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from google.colab import userdata\n", + "from roboflow import Roboflow\n", + "\n", + "ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')\n", + "rf = Roboflow(api_key=ROBOFLOW_API_KEY)\n", + "\n", + "project = rf.workspace(\"roboflow-jvuqo\").project(\"number-ops-j1426\")\n", + "version = project.version(1)\n", + "dataset = version.download(\"paligemma\")" + ], + "metadata": { + "id": "TGDFTYVnY4zn" + }, + "execution_count": null, + "outputs": 
[] + }, + { + "cell_type": "code", + "source": [ + "!head -n 5 {dataset.location}/dataset/_annotations.train.jsonl" + ], + "metadata": { + "id": "WLhSenP5AtQe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "!head -n 5 {dataset.location}/dataset/_annotations.valid.jsonl" + ], + "metadata": { + "id": "YwHY21ABA0WG" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import cv2\n", + "import json\n", + "import supervision as sv\n", + "from typing import List\n", + "\n", + "def read_n_lines(file_path: str, n: int) -> List[str]:\n", + " with open(file_path, 'r') as file:\n", + " lines = [next(file).strip() for _ in range(n)]\n", + " return lines\n", + "\n", + "images = []\n", + "lines = read_n_lines(f\"{dataset.location}/dataset/_annotations.train.jsonl\", 25)\n", + "first = json.loads(lines[0])\n", + "\n", + "CLASSES = first.get('prefix').replace(\"detect \", \"\").split(\" ; \")\n", + "\n", + "for line in lines:\n", + " data = json.loads(line)\n", + " image = cv2.imread(f\"{dataset.location}/dataset/{data.get('image')}\")\n", + " (h, w, _) = image.shape\n", + " detections = sv.Detections.from_lmm(\n", + " lmm='paligemma',\n", + " result=data.get('suffix'),\n", + " resolution_wh=(w, h),\n", + " classes=CLASSES)\n", + "\n", + " image = sv.BoundingBoxAnnotator(thickness=4).annotate(image, detections)\n", + " image = sv.LabelAnnotator(text_scale=2, text_thickness=4).annotate(image, detections)\n", + " images.append(image)\n", + "\n", + "sv.plot_images_grid(images, (5, 5))" + ], + "metadata": { + "id": "6ihTTuTd747l" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Fetch the `big_vision` repository and install related dependencies\n", + "\n", + "Download the `big_vision` repository to your Colab notebook from GitHub and install dependencies related to `big_vision` by running the following code." 
+ ], + "metadata": { + "id": "eg3sqaoPFS3W" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "DfxKb3F839Ks" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# Colab runtimes with remote TPUs are not supported.\n", + "if \"COLAB_TPU_ADDR\" in os.environ:\n", + " raise RuntimeError(\"It seems you are using Colab with remote TPUs which is not supported.\")\n", + "\n", + "# Fetch big_vision repository if python doesn't know about it and install\n", + "# dependencies needed for this notebook.\n", + "if not os.path.exists(\"big_vision_repo\"):\n", + " !git clone --quiet --branch=main --depth=1 \\\n", + " https://github.com/google-research/big_vision big_vision_repo\n", + "\n", + "# Append big_vision code to python import path\n", + "if \"big_vision_repo\" not in sys.path:\n", + " sys.path.append(\"big_vision_repo\")\n", + "\n", + "# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n", + "!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\"\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Set environment variables\n", + "\n", + "Set the environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`." + ], + "metadata": { + "id": "YU2fs7d0F1Fo" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zGLIp1Cx3_CX" + }, + "outputs": [], + "source": [ + "import os\n", + "from google.colab import userdata\n", + "\n", + "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n", + "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n", + "\n", + "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Import JAX and other dependencies\n", + "\n", + "Import JAX and other dependencies required for PaliGemma, like TensorFlow and NumPy."
+ ], + "metadata": { + "id": "zx3dj5NzG93I" + } + }, + { + "cell_type": "code", + "source": [ + "import base64\n", + "import functools\n", + "import html\n", + "import io\n", + "import os\n", + "import warnings\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "import ml_collections\n", + "\n", + "import tensorflow as tf\n", + "import sentencepiece\n", + "\n", + "from IPython.display import display, HTML\n", + "from PIL import Image\n", + "from tqdm.notebook import tqdm\n", + "\n", + "# Import model definition from big_vision\n", + "from big_vision.models.proj.paligemma import paligemma\n", + "from big_vision.trainers.proj.paligemma import predict_fns\n", + "\n", + "# Import big vision utilities\n", + "import big_vision.datasets.jsonl\n", + "import big_vision.utils\n", + "import big_vision.sharding\n", + "\n", + "# Don't let TF use the GPU or TPUs\n", + "tf.config.set_visible_devices([], \"GPU\")\n", + "tf.config.set_visible_devices([], \"TPU\")\n", + "\n", + "backend = jax.lib.xla_bridge.get_backend()\n", + "print(f\"JAX version: {jax.__version__}\")\n", + "print(f\"JAX platform: {backend.platform}\")\n", + "print(f\"JAX devices: {jax.device_count()}\")" + ], + "metadata": { + "id": "OlWELn2FHB22" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Download and configure the model\n", + "\n", + "In this step, you'll download the model checkpoint and configure it so that you can fine-tune it later on. This step also shows you how to move model parameters into GPU/TPU memory, which is useful for fine-tuning models on devices with limited resources." + ], + "metadata": { + "id": "_PHhkFGuHMFF" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Download the model checkpoint\n", + "\n", + "PaliGemma includes several model variations. 
For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n", + "\n", + "Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." + ], + "metadata": { + "id": "9wU_sHbGHQka" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gQNOTfF24AV4" + }, + "outputs": [], + "source": [ + "import os\n", + "import kagglehub\n", + "\n", + "MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n", + "if not os.path.exists(MODEL_PATH):\n", + " print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n", + " # Note: kaggle archive contains the same checkpoint in multiple formats.\n", + " # Download only the float16 model.\n", + " MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n", + " print(f\"Model path: {MODEL_PATH}\")\n", + "\n", + "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", + "if not os.path.exists(TOKENIZER_PATH):\n", + " print(\"Downloading the model tokenizer...\")\n", + " !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}\n", + " print(f\"Tokenizer path: {TOKENIZER_PATH}\")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Configure the model\n", + "\n", + "Now configure the model that you're going to use.\n", + "\n", + "For this notebook, you need to be able to fit your model onto a T4 GPU. Because GPU memory and host RAM are limited, you have to be mindful of how your model is configured.\n", + "\n", + "If you fine-tune every parameter, your model won't be able to run in the notebook environment. 
As a result, in this part of the notebook, you'll set the model up so that most of its parameters are frozen and only the parameters that matter most for accurate results are fine-tuned. Parameters are said to be *frozen* when they are excluded from gradient updates, so they keep their pretrained values during training.\n", + "\n", + "In order to configure your model, you need to:\n", + "\n", + "* Initialize the `model_config` as a [`FrozenConfigDict`](https://github.com/google/ml_collections/tree/master#frozenconfigdict), an immutable configuration object (note that this is unrelated to freezing model parameters)\n", + "* Initialize an instance of the PaliGemma `Model` class using the `model_config` as its configurations\n", + "* Load the model parameters into RAM\n", + "* Define a `decode` function to sample outputs from the model\n", + "\n", + "The code in this cell takes about a minute to run." + ], + "metadata": { + "id": "fnCT0G9sHxsX" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1aghcULcEdtv" + }, + "outputs": [], + "source": [ + "# Define model\n", + "model_config = ml_collections.FrozenConfigDict({\n", + " \"llm\": {\"vocab_size\": 257_152},\n", + " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", + "})\n", + "model = paligemma.Model(**model_config)\n", + "tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)\n", + "\n", + "# Load params - this can take up to 1 minute in T4 colabs.\n", + "params = paligemma.load(None, MODEL_PATH, model_config)\n", + "\n", + "# Define `decode` function to sample outputs from the model.\n", + "decode_fn = predict_fns.get_all(model)['decode']\n", + "decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Move model parameters into GPU/TPU memory\n", + "\n", + "Now you need to move the model 
parameters into GPU/TPU memory. First, shard the parameters across the available GPUs, then load the parameters. Here, you'll load the parameters sequentially, one at a time. This takes longer than loading them all simultaneously, but loading them simultaneously requires more RAM than you have available in this notebook.\n", + "\n", + "Finally, print out all of the parameters to see what type each individual parameter is cast to. Frozen parameters are kept as `float16`, while the trainable parameters are cast to `float32`. When you inspect the list, you'll see that most of the parameters have been frozen and are `float16`." + ], + "metadata": { + "id": "jytrSroKIfLD" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RWOdf_fw2SAO" + }, + "outputs": [], + "source": [ + "# Create a pytree mask of the trainable params.\n", + "def is_trainable_param(name, param): # pylint: disable=unused-argument\n", + " if name.startswith(\"llm/layers/attn/\"): return True\n", + " if name.startswith(\"llm/\"): return False\n", + " if name.startswith(\"img/\"): return False\n", + " raise ValueError(f\"Unexpected param name {name}\")\n", + "trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)\n", + "\n", + "# If more than one device is available (e.g. 
multiple GPUs) the parameters can\n", + "# be sharded across them to reduce HBM usage per device.\n", + "mesh = jax.sharding.Mesh(jax.devices(), (\"data\"))\n", + "\n", + "data_sharding = jax.sharding.NamedSharding(\n", + " mesh, jax.sharding.PartitionSpec(\"data\"))\n", + "\n", + "params_sharding = big_vision.sharding.infer_sharding(\n", + " params, strategy=[('.*', 'fsdp(axis=\"data\")')], mesh=mesh)\n", + "\n", + "# Yes: Some donated buffers are not usable.\n", + "warnings.filterwarnings(\n", + " \"ignore\", message=\"Some donated buffers were not usable\")\n", + "\n", + "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", + "def maybe_cast_to_f32(params, trainable):\n", + " return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n", + " params, trainable)\n", + "\n", + "# Loading all params simultaneously - albeit much faster and more succinct -\n", + "# requires more RAM than the T4 colab runtimes have by default.\n", + "# Instead we do it param by param.\n", + "params, treedef = jax.tree.flatten(params)\n", + "sharding_leaves = jax.tree.leaves(params_sharding)\n", + "trainable_leaves = jax.tree.leaves(trainable_mask)\n", + "for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):\n", + " params[idx] = big_vision.utils.reshard(params[idx], sharding)\n", + " params[idx] = maybe_cast_to_f32(params[idx], trainable)\n", + " params[idx].block_until_ready()\n", + "params = jax.tree.unflatten(treedef, params)\n", + "\n", + "# Print params to show what the model is made of.\n", + "def parameter_overview(params):\n", + " for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:\n", + " print(f\"{path:80s} {str(arr.shape):22s} {arr.dtype}\")\n", + "\n", + "print(\" == Model params == \")\n", + "parameter_overview(params)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Prepare to tune the model\n", + "\n", + "Now that your model is configured, you can tune it. 
In this step, you'll create your model's inputs as well as the training and validation iterators, view the training examples, and define the training and validation loops." + ], + "metadata": { + "id": "Y1sfijh_Ix09" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Create model inputs\n", + "\n", + "The model checkpoint you're using has already been trained on images of various aspect ratios that have been resized to 224x224 pixels, and on tokenized text.\n", + "\n", + "The code below defines three functions that you'll use in the next step to create the model's inputs:\n", + "\n", + "* **`preprocess_image`:** Normalizes the image data. In this case, pre-processing expands greyscale images to three channels, removes the alpha channel, resizes the passed-in image to the size required by the model for image inputs (224x224 pixels), and rescales pixel values to the range [-1, 1].\n", + "* **`preprocess_tokens`:** Tokenizes the prefix and suffix and adds flags that mark whether each token belongs to the prefix or the suffix. These flags will be used later on in the code, during the training step and the evaluation loop.\n", + "* **`postprocess_tokens`:** Removes any tokens at and/or after the end-of-sequence (EOS) token and returns the remaining tokens decoded to text.\n" + ], + "metadata": { + "id": "6hNkSJwJI138" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8SRW0NuU4UcW" + }, + "outputs": [], + "source": [ + "def preprocess_image(image, size=224):\n", + " # Model has been trained to handle images of different aspects ratios\n", + " # resized to 224x224 in the range [-1, 1]. 
Bilinear and antialias resize\n", + " # options are helpful to improve quality in some tasks.\n", + " image = np.asarray(image)\n", + " if image.ndim == 2: # Expand greyscale (H, W) image to 3 channels.\n", + " image = np.stack((image,)*3, axis=-1)\n", + " image = image[..., :3] # Remove alpha layer.\n", + " assert image.shape[-1] == 3\n", + "\n", + " image = tf.constant(image)\n", + " image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n", + " return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n", + "\n", + "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n", + " # Model has been trained to handle tokenized text composed of a prefix with\n", + " # full attention and a suffix with causal attention.\n", + " separator = \"\\n\"\n", + " tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)\n", + " mask_ar = [0] * len(tokens) # 0 to use full attention for prefix.\n", + " mask_loss = [0] * len(tokens) # 0 to not use prefix tokens in the loss.\n", + "\n", + " if suffix:\n", + " suffix = tokenizer.encode(suffix, add_eos=True)\n", + " tokens += suffix\n", + " mask_ar += [1] * len(suffix) # 1 to use causal attention for suffix.\n", + " mask_loss += [1] * len(suffix) # 1 to use suffix tokens in the loss.\n", + "\n", + " mask_input = [1] * len(tokens) # 1 if it's a token, 0 if padding.\n", + " if seqlen:\n", + " padding = [0] * max(0, seqlen - len(tokens))\n", + " tokens = tokens[:seqlen] + padding\n", + " mask_ar = mask_ar[:seqlen] + padding\n", + " mask_loss = mask_loss[:seqlen] + padding\n", + " mask_input = mask_input[:seqlen] + padding\n", + "\n", + " return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))\n", + "\n", + "def postprocess_tokens(tokens):\n", + " tokens = tokens.tolist() # np.array to list[int]\n", + " try: # Remove tokens at and after EOS if any.\n", + " eos_pos = tokens.index(tokenizer.eos_id())\n", + " tokens = tokens[:eos_pos]\n", + " except ValueError:\n", + "
pass\n", + " return tokenizer.decode(tokens)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Create the training and validation iterators\n", + "\n", + "Create two iterators:\n", + "\n", + "* A **training iterator** to allow the training process to go through the data in chunks rather than processing it all at once. This allows you to do some data pre-processing before use.\n", + "* A **validation iterator** that allows the training process to iterate over the validation dataset to see how well the tuned model aligned with the provided results." + ], + "metadata": { + "id": "h4Lul8c3JDBQ" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "whzWOojGOtzi" + }, + "outputs": [], + "source": [ + "SEQLEN = 128\n", + "\n", + "train_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(dataset.location, \"dataset/_annotations.train.jsonl\"),\n", + " fopen_keys={\"image\": f\"{dataset.location}/dataset\"})\n", + "\n", + "val_dataset = big_vision.datasets.jsonl.DataSource(\n", + " os.path.join(dataset.location, \"dataset/_annotations.valid.jsonl\"),\n", + " fopen_keys={\"image\": f\"{dataset.location}/dataset\"})\n", + "\n", + "\n", + "def train_data_iterator():\n", + " \"\"\"Never ending iterator over training examples.\"\"\"\n", + " # Shuffle examples and repeat so one can train for many epochs.\n", + " dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()\n", + " for example in dataset.as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " prefix = example[\"prefix\"].decode().lower()\n", + " suffix = example[\"suffix\"].decode().lower()\n", + " tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)\n", + " label, _, _, _ = preprocess_tokens(suffix, seqlen=SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"label\": np.asarray(label),\n", 
+ " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_loss\": np.asarray(mask_loss),\n", + " }\n", + "\n", + "\n", + "def validation_data_iterator():\n", + " \"\"\"Single iterator over validation examples.\"\"\"\n", + " for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():\n", + " image = Image.open(io.BytesIO(example[\"image\"]))\n", + " image = preprocess_image(image)\n", + "\n", + " prefix = example[\"prefix\"].decode().lower()\n", + " suffix = example[\"suffix\"].decode().lower()\n", + " tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)\n", + " label, _, _, _ = preprocess_tokens(suffix, seqlen=SEQLEN)\n", + "\n", + " yield {\n", + " \"image\": np.asarray(image),\n", + " \"text\": np.asarray(tokens),\n", + " \"label\": np.asarray(label),\n", + " \"mask_ar\": np.asarray(mask_ar),\n", + " \"mask_input\": np.asarray(mask_input),\n", + " }\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### View training examples\n", + "\n", + "In this notebook, the training data contains 90 images, each paired with a detection annotation: four `<loc....>` location tokens followed by a class label.\n", + "\n", + "**Note:** Training data sets meant for practical use cases should contain more images, but this notebook limits the number of data points so that you can train the model in a reasonable amount of time.\n", + "\n", + "The code below shows a random selection of images with their annotations from the training data set so that you can see what your model is trained on. Each image is displayed as a 128x128 pixel JPEG, with the annotation printed to the right of the image." 
+ ], + "metadata": { + "id": "ml_wkTbJJj-N" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 167 + }, + "id": "BzJfb5t0nsLq", + "outputId": "65726c5b-2d7e-4fb8-b5f4-5183978f8118" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Training examples\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0204><loc0178><loc0972><loc0720> mult

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0128><loc0284><loc0752><loc0751> 3

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0277><loc0356><loc0685><loc0612> 6

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0128><loc0271><loc0765><loc0736> 4

\n", + "
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "def split_and_keep_second_part(s):\n", + " parts = s.split('\\n', 1)\n", + " if len(parts) > 1:\n", + " return parts[1]\n", + " return s\n", + "\n", + "def render_inline(image, resize=(128, 128)):\n", + " \"\"\"Convert image into inline html.\"\"\"\n", + " image = Image.fromarray(image)\n", + " image = image.resize(resize) # resize() returns a new image.\n", + " with io.BytesIO() as buffer:\n", + " image.save(buffer, format='jpeg')\n", + " image_b64 = str(base64.b64encode(buffer.getvalue()), \"utf-8\")\n", + " return f\"data:image/jpeg;base64,{image_b64}\"\n", + "\n", + "def render_example(image, caption):\n", + " image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n", + " h, w, _ = image.shape\n", + " try:\n", + " detections = sv.Detections.from_lmm(\n", + " lmm='paligemma',\n", + " result=caption,\n", + " resolution_wh=(w, h),\n", + " classes=CLASSES)\n", + " image = sv.BoundingBoxAnnotator().annotate(image, detections)\n", + " image = sv.LabelAnnotator().annotate(image, detections)\n", + " except Exception: # Fall back to printing the raw caption.\n", + " print(caption)\n", + " return f\"\"\"\n", + "
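<div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n", + " <img style=\"width:128px; height:128px;\" src=\"{render_inline(image)}\" />\n", + " <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n", + "</div>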
\n", + "\"\"\"\n", + "\n", + "html_out = \"\"\n", + "for idx, example in zip(range(4), train_data_iterator()):\n", + " caption = postprocess_tokens(example[\"text\"]) # detokenize model input.\n", + " caption = split_and_keep_second_part(caption)\n", + " html_out += render_example(example[\"image\"], caption)\n", + "\n", + "print(\"Training examples\")\n", + "display(HTML(html_out))" + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Define the training and evaluation loops\n", + "\n", + "Define the training loop to train the model on the provided dataset, and the evaluation loop to run the model on all of the examples in the validation dataset and collect its predictions.\n", + "\n", + "#### Defining the training loop\n", + "\n", + "The `update_fn` function defines the training step. During the training step, the loss per example is calculated and stochastic gradient descent (SGD) is applied to the trainable parameters.\n", + "\n", + "Recall that earlier in the notebook, the `preprocess_tokens` function returned a `mask_loss` flag for each token. You'll use the `mask_loss` flag here to exclude prefix and padded tokens from the loss. Without it, the loss calculation will be skewed. You also need to normalize the loss of each example by its number of tokens, since each example has a different number of tokens. After the prefix and padded tokens have been excluded and the per-example losses have been normalized, you can average them into the batch loss.\n", + "\n", + "The training step also defines a function that applies the SGD update to the trainable parameters and leaves the frozen parameters unchanged.\n", + "\n", + "#### Defining the evaluation loop\n", + "\n", + "The `make_predictions` function is your evaluation loop. The evaluation loop is fairly straightforward, with one notable detail. If you recall from the beginning of the notebook, you only have 90 examples in your training data set. 
This is a very small number of examples, and the data may run out before the last batch is full. 
This means that in the evaluation loop, you need to pad the batch by repeating examples.\n", + "\n", + "To make sure that your evaluation loop only counts actual examples and not the padded examples, you have to apply a mask to the padded examples that excludes them from the output." + ], + "metadata": { + "id": "hKFJ9rbLKoTa" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "dwUV_imW3WQJ" + }, + "outputs": [], + "source": [ + "# The main update_fn using a simple stochastic gradient descent (SGD).\n", + "@functools.partial(jax.jit, donate_argnums=(0,))\n", + "def update_fn(params, batch, learning_rate):\n", + "  imgs, txts, mask_ar = batch[\"image\"], batch[\"text\"], batch[\"mask_ar\"]\n", + "\n", + "  def loss_fn(params):\n", + "    text_logits, _ = model.apply({\"params\": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)\n", + "    logp = jax.nn.log_softmax(text_logits, axis=-1)\n", + "\n", + "    # The model takes as input txts[:, :-1] but the loss is defined as predicting\n", + "    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens\n", + "    # are part of the loss (e.g. prefix and padded tokens are not included).\n", + "    mask_loss = batch[\"mask_loss\"][:, 1:]\n", + "    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n", + "\n", + "    # Compute the loss per example. i.e. 
the mean of per token pplx.\n", + "    # Since each example has a different number of tokens we normalize it.\n", + "    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.\n", + "    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.\n", + "    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.\n", + "\n", + "    # batch_loss: mean of per example loss.\n", + "    return jnp.mean(example_loss)\n", + "\n", + "  loss, grads = jax.value_and_grad(loss_fn)(params)\n", + "\n", + "  # Apply gradients to trainable params using SGD.\n", + "  def apply_grad(param, gradient, trainable):\n", + "    if not trainable: return param\n", + "    return param - learning_rate * gradient\n", + "\n", + "  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)\n", + "\n", + "  return params, loss\n", + "\n", + "# Evaluation/inference loop.\n", + "def make_predictions(data_iterator, *, num_examples=None,\n", + "                     batch_size=4, seqlen=SEQLEN, sampler=\"greedy\"):\n", + "  outputs = []\n", + "  while True:\n", + "    # Construct a list of examples in the batch.\n", + "    examples = []\n", + "    try:\n", + "      for _ in range(batch_size):\n", + "        examples.append(next(data_iterator))\n", + "        examples[-1][\"_mask\"] = np.array(True)  # Indicates true example.\n", + "    except StopIteration:\n", + "      if len(examples) == 0:\n", + "        return outputs\n", + "\n", + "    # Not enough examples to complete a batch. 
Pad by repeating last example.\n", + "    while len(examples) % batch_size:\n", + "      examples.append(dict(examples[-1]))\n", + "      examples[-1][\"_mask\"] = np.array(False)  # Indicates padding example.\n", + "\n", + "    # Convert list of examples into a dict of np.arrays and load onto devices.\n", + "    batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + "    batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + "    # Make model predictions\n", + "    tokens = decode({\"params\": params}, batch=batch,\n", + "                    max_decode_len=seqlen, sampler=sampler)\n", + "\n", + "    # Fetch model predictions from device to host and detokenize.\n", + "    tokens, mask = jax.device_get((tokens, batch[\"_mask\"]))\n", + "    tokens = tokens[mask]  # remove padding examples.\n", + "    # Padding examples sit at the end of the batch, so zip() below drops their labels.\n", + "    labels = [postprocess_tokens(e[\"label\"]) for e in examples]\n", + "    responses = [postprocess_tokens(t) for t in tokens]\n", + "\n", + "    # Append (image, label, response) tuples to the outputs.\n", + "    for example, label, response in zip(examples, labels, responses):\n", + "      outputs.append((example[\"image\"], label, response))\n", + "      if num_examples and len(outputs) >= num_examples:\n", + "        return outputs" + ] + }, + { + "cell_type": "code", + "source": [ + "html_out = \"\"\n", + "for image, _, caption in make_predictions(validation_data_iterator(), num_examples=4, batch_size=4):\n", + "    html_out += render_example(image, caption)\n", + "display(HTML(html_out))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 375 + }, + "id": "GCXYnIdm4ILQ", + "outputId": "51ed1a0f-6580-4728-dfea-606310bc6056" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0233><loc0896><loc0921> 1 ; <loc0046><loc0233><loc0896><loc0921> 3 ; <loc0046><loc0505><loc0896><loc0696> 7 ; <loc0046><loc0603><loc0432><loc0687> 5 ; <loc0046><loc0233><loc0536><loc0912> 6 ; <loc0046><loc0233><loc0896><loc0921> 4 ; <loc0046><loc0233><loc0533><loc0912> 9 ; <loc0046><loc0233><loc0896><loc0912> 2 ; <loc0046><loc0233><loc0531><loc0912> 8 ; <loc0045><loc0233><loc0896><loc0912> 1 ; <loc0045><loc0000><loc1023><loc1013> mult ; <loc0045><loc0541><loc0896><loc0696> 7 ; <loc0045><loc0567><loc0456><loc0690> minus ; <loc0413><loc0233><loc0538><loc0912> div ; <loc0036><loc0000><loc1023><loc1015> 1 ; <loc0036><loc0000><loc1023><loc1023> 1 ; <loc0036><loc0000><loc1023>

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0217><loc0261><loc0957><loc0819> 4 ; <loc0207><loc0252><loc0954><loc0826> 7 ; <loc0207><loc0246><loc0949><loc0826> 5 ; <loc0207><loc0246><loc0949><loc0826> 1 ; <loc0207><loc0421><loc0778><loc0703> minus ; <loc0000><loc0000><loc1023><loc1017> mult ; <loc0207><loc0246><loc0949><loc0818> 3 ; <loc0207><loc0246><loc0944><loc0809> 6 ; <loc0754><loc0264><loc0936><loc0513> 8 ; <loc0754><loc0261><loc0944><loc0813> 4 ; <loc0754><loc0261><loc0936><loc0513> 5 ; <loc0754><loc0261><loc0944><loc0813> 7 ; <loc0754><loc0261><loc0936><loc0813> 9 ; <loc0747><loc0496><loc0956><loc0817> div ; <loc0000><loc0000><loc1023><loc1017> 1 ; <loc0000><loc0000><loc1023><loc1023> 2 ; <loc0000><loc0000><loc1023>

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0230><loc0920><loc0634> 1 ; <loc0116><loc0230><loc0531><loc0549> 7 ; <loc0116><loc0230><loc0920><loc0629> 3 ; <loc0116><loc0230><loc0531><loc0438> 6 ; <loc0116><loc0222><loc0925><loc0629> 4 ; <loc0116><loc0222><loc0925><loc0629> 5 ; <loc0116><loc0222><loc0541><loc0553> 8 ; <loc0000><loc0000><loc1018><loc1018> mult ; <loc0116><loc0222><loc0925><loc0629> 2 ; <loc0116><loc0230><loc0541><loc0557> minus ; <loc0866><loc0404><loc0926><loc0658> plus ; <loc0168><loc0505><loc0906><loc0580> 4 ; <loc0168><loc0505><loc0902><loc0575> 9 ; <loc0866><loc0403><loc0922><loc0658> div ; <loc0116><loc0230><loc0925><loc0635> 1 ; <loc0116><loc0230><loc0925><loc0626> 4 ; <loc0116><loc0230><loc0541><loc0549>

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0143><loc0374><loc0658><loc0442> 4 ; <loc0143><loc0374><loc0658><loc0441> 5 ; <loc0140><loc0366><loc0667><loc0442> 6 ; <loc0000><loc0000><loc1023><loc1017> mult ; <loc0140><loc0366><loc0667><loc0442> 7 ; <loc0134><loc0365><loc0667><loc0442> 8 ; <loc0000><loc0000><loc1023><loc1017> 3 ; <loc0000><loc0000><loc1023><loc1023> 1 ; <loc0000><loc0000><loc1023><loc1023> eq ; <loc0000><loc0000><loc1023><loc1023> 2 ; <loc0000><loc0000><loc0693><loc1023> 4 ; <loc0000><loc0000><loc1023><loc1023> 9 ; <loc0000><loc0000><loc1023><loc1023> minus ; <loc0000><loc0000><loc1023><loc1023> 10 ; <loc0000><loc0356><loc0703><loc0456> 1 ; <loc0000><loc0000><loc1023><loc1023> 3 ; <loc0000><loc0000>

\n", + "
\n" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Tune the model\n", + "\n", + "Now that you've set everything up and taken a look at the training data, it's time to finally tune the model. The code below runs the training loop for the model for 64 steps and prints the learning rate (`lr` in the printed output) and the loss for each step.\n", + "\n", + "Every 8 steps (`EVAL_STEPS`), the notebook prints the model's predictions at that point in the training. This code prints out predictions for the same set of images so that you can see the model's ability to predict bounding boxes and class labels improve over time.\n", + "\n", + "At earlier steps in the training, there are likely issues with the predictions, such as repeated detections as the model gets stuck in its predictive loop, or truncated output. The model's predictions become steadily more accurate as training progresses. By step 64, the model's predictions should closely resemble the annotations provided by the training data.\n", + "\n", + "This process takes around 15 minutes to complete on a T4 GPU." + ], + "metadata": { + "id": "fNigSP99MJFe" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "067wj_6bZAG3", + "outputId": "8671008b-9a90-4a36-f297-0facd27185b4", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 1/64 lr: 0.00083 loss: 5.6622\n", + "step: 2/64 lr: 0.00167 loss: 3.8568\n", + "step: 3/64 lr: 0.00250 loss: 2.8670\n", + "step: 4/64 lr: 0.00333 loss: 2.9897\n", + "step: 5/64 lr: 0.00417 loss: 4.1907\n", + "step: 6/64 lr: 0.00500 loss: 3.3373\n", + "step: 7/64 lr: 0.00500 loss: 2.8551\n", + "step: 8/64 lr: 0.00499 loss: 2.9553\n", + "Model predictions at step 8\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0039><loc0238><loc0912><loc0921> mult

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0264><loc0968><loc0826> mult

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0230><loc0936><loc0646> mult

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0137><loc0377><loc0694><loc0452> 9

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 9/64 lr: 0.00497 loss: 2.9513\n", + "step: 10/64 lr: 0.00494 loss: 3.1961\n", + "step: 11/64 lr: 0.00491 loss: 3.0491\n", + "step: 12/64 lr: 0.00487 loss: 2.8612\n", + "step: 13/64 lr: 0.00483 loss: 2.5327\n", + "step: 14/64 lr: 0.00478 loss: 2.4965\n", + "step: 15/64 lr: 0.00472 loss: 2.9966\n", + "step: 16/64 lr: 0.00465 loss: 2.9704\n", + "Model predictions at step 16\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0216><loc0887><loc0887> div

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0236><loc0235><loc0949><loc0789> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0123><loc0216><loc0912><loc0617> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0144><loc0363><loc0653><loc0432> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 17/64 lr: 0.00458 loss: 2.8967\n", + "step: 18/64 lr: 0.00451 loss: 2.5452\n", + "step: 19/64 lr: 0.00442 loss: 2.5070\n", + "step: 20/64 lr: 0.00434 loss: 2.5249\n", + "step: 21/64 lr: 0.00424 loss: 2.3752\n", + "step: 22/64 lr: 0.00415 loss: 2.6018\n", + "step: 23/64 lr: 0.00404 loss: 2.5676\n", + "step: 24/64 lr: 0.00394 loss: 2.3139\n", + "Model predictions at step 24\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0223><loc0906><loc0919> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0246><loc0961><loc0826> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0216><loc0928><loc0647> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0140><loc0366><loc0667><loc0452> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 25/64 lr: 0.00383 loss: 2.2908\n", + "step: 26/64 lr: 0.00371 loss: 2.1218\n", + "step: 27/64 lr: 0.00359 loss: 2.4826\n", + "step: 28/64 lr: 0.00347 loss: 2.5923\n", + "step: 29/64 lr: 0.00335 loss: 2.7243\n", + "step: 30/64 lr: 0.00322 loss: 2.4624\n", + "step: 31/64 lr: 0.00309 loss: 2.4676\n", + "step: 32/64 lr: 0.00296 loss: 2.2581\n", + "Model predictions at step 32\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0036><loc0230><loc0888><loc0912> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0211><loc0251><loc0961><loc0826> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0112><loc0222><loc0919><loc0662> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0132><loc0366><loc0662><loc0447> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 33/64 lr: 0.00283 loss: 2.2747\n", + "step: 34/64 lr: 0.00270 loss: 2.4034\n", + "step: 35/64 lr: 0.00257 loss: 2.1800\n", + "step: 36/64 lr: 0.00243 loss: 2.3051\n", + "step: 37/64 lr: 0.00230 loss: 2.2851\n", + "step: 38/64 lr: 0.00217 loss: 2.1452\n", + "step: 39/64 lr: 0.00204 loss: 2.3099\n", + "step: 40/64 lr: 0.00191 loss: 2.1190\n", + "Model predictions at step 40\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0238><loc0896><loc0912> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0261><loc0965><loc0814> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0230><loc0919><loc0653> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0140><loc0369><loc0662><loc0447> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 41/64 lr: 0.00178 loss: 2.3048\n", + "step: 42/64 lr: 0.00165 loss: 2.2890\n", + "step: 43/64 lr: 0.00153 loss: 2.4525\n", + "step: 44/64 lr: 0.00141 loss: 2.3838\n", + "step: 45/64 lr: 0.00129 loss: 2.2427\n", + "step: 46/64 lr: 0.00117 loss: 2.0810\n", + "step: 47/64 lr: 0.00106 loss: 2.2332\n", + "step: 48/64 lr: 0.00096 loss: 2.1453\n", + "Model predictions at step 48\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0238><loc0896><loc0919> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0261><loc0961><loc0824> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0230><loc0919><loc0653> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0141><loc0374><loc0662><loc0450> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 49/64 lr: 0.00085 loss: 2.1180\n", + "step: 50/64 lr: 0.00076 loss: 2.5212\n", + "step: 51/64 lr: 0.00066 loss: 2.3260\n", + "step: 52/64 lr: 0.00058 loss: 2.3060\n", + "step: 53/64 lr: 0.00049 loss: 2.3717\n", + "step: 54/64 lr: 0.00042 loss: 2.3142\n", + "step: 55/64 lr: 0.00035 loss: 1.9866\n", + "step: 56/64 lr: 0.00028 loss: 2.1863\n", + "Model predictions at step 56\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0233><loc0896><loc0919> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0261><loc0957><loc0824> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0121><loc0222><loc0919><loc0653> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0141><loc0374><loc0655><loc0450> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "step: 57/64 lr: 0.00022 loss: 2.1201\n", + "step: 58/64 lr: 0.00017 loss: 2.2197\n", + "step: 59/64 lr: 0.00013 loss: 2.1977\n", + "step: 60/64 lr: 0.00009 loss: 2.2539\n", + "step: 61/64 lr: 0.00006 loss: 2.0506\n", + "step: 62/64 lr: 0.00003 loss: 2.4619\n", + "step: 63/64 lr: 0.00001 loss: 2.2024\n", + "step: 64/64 lr: 0.00000 loss: 2.3148\n", + "Model predictions at step 64\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0233><loc0896><loc0912> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0261><loc0957><loc0824> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0222><loc0919><loc0653> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0141><loc0374><loc0658><loc0450> 1

\n", + "
\n" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "CPU times: user 11min 12s, sys: 463 ms, total: 11min 12s\n", + "Wall time: 11min 14s\n" + ] + } + ], + "source": [ + "# Run a short training loop with cosine learning rate schedule.\n", + "#\n", + "# Note: the first step can be quite slow on some machines (up to several minutes)\n", + "# due to XLA compilation of the jax.jit'd function.\n", + "#\n", + "%%time\n", + "\n", + "BATCH_SIZE = 8\n", + "TRAIN_EXAMPLES = 512\n", + "LEARNING_RATE = 0.005\n", + "\n", + "TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE\n", + "EVAL_STEPS = TRAIN_STEPS // 8\n", + "\n", + "train_data_it = train_data_iterator()\n", + "\n", + "sched_fn = big_vision.utils.create_learning_rate_schedule(\n", + "    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,\n", + "    decay_type=\"cosine\", warmup_percent=0.10)\n", + "\n", + "for step in range(1, TRAIN_STEPS+1):\n", + "  # Make list of N training examples.\n", + "  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]\n", + "\n", + "  # Convert list of examples into a dict of np.arrays and load onto devices.\n", + "  batch = jax.tree.map(lambda *x: np.stack(x), *examples)\n", + "  batch = big_vision.utils.reshard(batch, data_sharding)\n", + "\n", + "  # Training step and report training loss\n", + "  learning_rate = sched_fn(step)\n", + "  params, loss = update_fn(params, batch, learning_rate)\n", + "\n", + "  loss = jax.device_get(loss)\n", + "  print(f\"step: {step:2d}/{TRAIN_STEPS:2d} lr: {learning_rate:.5f} loss: {loss:.4f}\")\n", + "\n", + "  if (step % EVAL_STEPS) == 0:\n", + "    print(f\"Model predictions at step {step}\")\n", + "    html_out = \"\"\n", + "    for image, _, caption in make_predictions(\n", + "        validation_data_iterator(), num_examples=4, batch_size=4):\n", + "      html_out += render_example(image, caption)\n", + "    display(HTML(html_out))\n" + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Evaluate fine-tuned model" + ], + "metadata": { + "id": 
"rGjB_4Mo8RA_" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 548 + }, + "id": "hgUhEKjzPdMQ", + "outputId": "59b2fe9a-5ba4-48e1-a08d-fcc7c91049b9" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + "

<loc0046><loc0233><loc0896><loc0912> plus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0219><loc0261><loc0957><loc0824> 2

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0116><loc0222><loc0919><loc0653> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0141><loc0374><loc0658><loc0450> 1

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0112><loc0155><loc0927><loc0639> 3

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0127><loc0297><loc0785><loc0673> 4

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0058><loc0261><loc0831><loc0662> 8

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0222><loc0453><loc0735><loc0819> 7

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0449><loc0180><loc0533><loc0797> minus

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0266><loc0502><loc0936><loc0626> 1

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0172><loc0222><loc0912><loc0634> 5

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0230><loc0363><loc0894><loc0738> 7

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0266><loc0572><loc0827><loc0766> 1

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0127><loc0230><loc0901><loc0703> 6

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0317><loc0140><loc0936><loc0812> mult

\n", + "
\n", + "\n", + "
\n", + " \n", + "

<loc0161><loc0261><loc0838><loc0831> 6

\n", + "
\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "# @title Visualize results\n", + "html_out = \"\"\n", + "for image, _, caption in make_predictions(validation_data_iterator(), num_examples=16, batch_size=8):\n", + "    html_out += render_example(image, caption)\n", + "display(HTML(html_out))" + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Collect predictions\n", + "targets = []\n", + "predictions = []\n", + "\n", + "for image, label, prediction in make_predictions(validation_data_iterator(), num_examples=512, batch_size=8):\n", + "    h, w, _ = image.shape\n", + "    target = sv.Detections.from_lmm(\n", + "        lmm='paligemma',\n", + "        result=label,\n", + "        resolution_wh=(w, h),\n", + "        classes=CLASSES)\n", + "    targets.append(target)\n", + "    prediction = sv.Detections.from_lmm(\n", + "        lmm='paligemma',\n", + "        result=prediction,\n", + "        resolution_wh=(w, h),\n", + "        classes=CLASSES)\n", + "    prediction.confidence = np.ones(len(prediction))\n", + "    predictions.append(prediction)" + ], + "metadata": { + "id": "JD9l94a8pYRc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# @title Calculate mAP\n", + "mean_average_precision = sv.MeanAveragePrecision.from_detections(\n", + "    predictions=predictions,\n", + "    targets=targets,\n", + ")\n", + "\n", + "print(f\"map50_95: {mean_average_precision.map50_95:.2f}\")\n", + "print(f\"map50: {mean_average_precision.map50:.2f}\")\n", + "print(f\"map75: {mean_average_precision.map75:.2f}\")" + ], + "metadata": { + "id": "d62AueMC7Yp3", + "outputId": "487bb72c-c023-47c9-a33b-acdbf6ac1be0", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "map50_95: 0.83\n", + "map50: 0.94\n", + "map75: 0.90\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# @title Calculate Confusion Matrix\n", + "confusion_matrix = sv.ConfusionMatrix.from_detections(\n", + "    
predictions=predictions,\n", + "    targets=targets,\n", + "    classes=CLASSES\n", + ")\n", + "\n", + "_ = confusion_matrix.plot()" + ], + "metadata": { + "id": "lE7je-hL8Bj3", + "outputId": "86db3d53-0caa-4186-ae8f-0e0460081c61", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Save fine-tuned model locally" + ], + "metadata": { + "id": "Hr1gTKP8trRb" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "TARGET_MODEL_DIR = f\"{dataset.location}/model\"\n", + "TARGET_MODEL_PATH = f\"{TARGET_MODEL_DIR}/paligemma-3b-pt-224.f16.npz\"\n", + "\n", + "os.makedirs(TARGET_MODEL_DIR, exist_ok=True)" + ], + "metadata": { + "id": "N4Y43q4jKj7a" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "flat, _ = big_vision.utils.tree_flatten_with_names(params)\n", + "with open(TARGET_MODEL_PATH, \"wb\") as f:\n", + "    np.savez(f, **{k: v for k, v in flat})" + ], + "metadata": { + "id": "zyVxKr2FOxPe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## Deploy model on Roboflow" + ], + "metadata": { + "id": "7gj7L3BkMZrZ" + } + }, + { + "cell_type": "code", + "source": [ + "version.deploy(model_type=\"paligemma-3b-pt-224\", model_path=TARGET_MODEL_DIR)" + ], + "metadata": { + "id": "_YfNYT7qMMY2", + "outputId": "2cbcf204-22d0-45d3-fe29-87cdd663278e", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model files found in /content/number-ops-1/model: ['paligemma-3b-pt-224.f16.npz']\n", + "Found .npz file paligemma-3b-pt-224.f16.npz in model path. Deploying JAX PaliGemma model.\n", + "Zipping files for deploy: ['paligemma-3b-pt-224.f16.npz']\n", + "Uploading to Roboflow... 
May take several minutes.\n", + "View the status of your deployment at: https://app.roboflow.com/roboflow-jvuqo/number-ops-j1426/1\n", + "Share your model with the world at: https://universe.roboflow.com/roboflow-jvuqo/number-ops-j1426/model/1\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Congratulations\n", + "\n", + "⭐️ If you enjoyed this notebook, [**star the Roboflow Notebooks repo**](https://github.com/roboflow/notebooks) (and [**supervision**](https://github.com/roboflow/supervision) while you're at it) and let us know what tutorials you'd like to see us do next. ⭐️" + ], + "metadata": { + "id": "kR8llI4Qv0pR" + } + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [], + "machine_shape": "hm" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file