# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Source: https://github.com/facebookresearch/ToMe/blob/main/tome/utils.py
# --------------------------------------------------------

import time
from typing import Tuple

import torch
from tqdm import tqdm


def benchmark(
    model: torch.nn.Module,
    device: torch.device = 0,
    input_size: Tuple[int, ...] = (3, 224, 224),
    batch_size: int = 64,
    runs: int = 40,
    throw_out: float = 0.25,
    use_fp16: bool = False,
    verbose: bool = False,
) -> float:
    """
    Benchmark the given model with random inputs at the given batch size.

    Args:
     - model: the module to benchmark
     - device: the device to use for benchmarking
     - input_size: the input size to pass to the model (channels, h, w)
     - batch_size: the batch size to use for evaluation
     - runs: the total number of runs to do
     - throw_out: the fraction of runs to discard at the start as warm-up
     - use_fp16: whether or not to benchmark with float16 and autocast
     - verbose: whether or not to use tqdm to print progress / print throughput at end

    Returns:
     - the throughput measured in images / second
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)
    is_cuda = device.type == "cuda"

    model = model.eval().to(device)
    inp = torch.rand(batch_size, *input_size, device=device)
    if use_fp16:
        inp = inp.half()

    warm_up = int(runs * throw_out)
    total = 0
    start = time.time()

    with torch.autocast(device.type, enabled=use_fp16):
        with torch.no_grad():
            for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
                if i == warm_up:
                    # Warm-up is over: reset the counters so the discarded
                    # runs don't contribute to the measured throughput.
                    if is_cuda:
                        torch.cuda.synchronize()
                    total = 0
                    start = time.time()

                model(inp)
                total += batch_size

    if is_cuda:
        # CUDA kernel launches are asynchronous; wait for all queued work
        # to finish before stopping the clock, or timing will undercount.
        torch.cuda.synchronize()

    end = time.time()
    elapsed = end - start
    throughput = total / elapsed

    if verbose:
        print(f"Throughput: {throughput:.2f} im/s")

    return throughput
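

# Minimal usage sketch (an assumption, not part of the original utility):
# it assumes torchvision is installed and uses resnet50 as a stand-in
# workload; any torch.nn.Module that accepts a (batch, channels, h, w)
# tensor works the same way.
if __name__ == "__main__":
    import torchvision

    model = torchvision.models.resnet50()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    throughput = benchmark(
        model,
        device=device,  # benchmark() also accepts strings/ints, not just torch.device
        batch_size=32,
        use_fp16=torch.cuda.is_available(),  # fp16 autocast mainly helps on GPU
        verbose=True,
    )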