license: apache-2.0
datasets:
  - lambada
language:
  - en
library_name: transformers
pipeline_tag: text-generation
tags:
  - text-generation-inference
  - causal-lm
  - int8
  - tensorrt
  - ENOT-AutoDL

INT8 GPT-J 6B

GPT-J 6B is a transformer model trained using Ben Wang's Mesh Transformer JAX. "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.

This repository contains TensorRT engines in mixed INT8 + FP32 precision. Prebuilt engines are available for the following GPUs:

  • RTX 4090
  • RTX 3080 Ti
  • RTX 2080 Ti
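To make "INT8" concrete, here is generic per-tensor symmetric INT8 quantization. Each FP32 tensor is stored as 8-bit integer codes plus one FP32 scale. This is only an illustration of that general technique. This README does not describe the exact scheme ENOT-AutoDL applied to these engines, and the function names below are hypothetical.

```python
# Generic per-tensor symmetric INT8 quantization (illustrative sketch only;
# not necessarily the scheme used to build the engines in this repository).

def quantize_int8(values):
    """Map floats to int8 codes in [-127, 127] using one shared scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest code and clamp to the symmetric int8 range.
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale

def dequantize_int8(codes, scale):
    """Recover approximate float values from int8 codes."""
    return [c * scale for c in codes]

weights = [0.5, -1.27, 0.003, 1.0]
codes, scale = quantize_int8(weights)
restored = dequantize_int8(codes, scale)
print(codes, scale)
```

Each restored value differs from the original by at most half a quantization step (`scale / 2`). Small values such as 0.003 can collapse to zero, which is one reason precision-sensitive layers may be kept in FP32 in a mixed-precision engine.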

The ONNX model generated by ENOT-AutoDL and the engine build script will be published soon.

Metrics:

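|                  | TensorRT INT8+FP32 | torch FP16 | torch FP32 |
|------------------|--------------------|------------|------------|
| LAMBADA accuracy | 78.79%             | 79.17%     | -          |
| Model size (GB)  | 8.5                | 12.1       | 24.2       |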

Test environment

  • GPU: RTX 4090
  • CPU: 11th Gen Intel(R) Core(TM) i7-11700K
  • TensorRT: 8.5.3.1
  • PyTorch: 1.13.1+cu116
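The torch model sizes match a simple weights-only estimate: parameter count × bytes per parameter. The sketch below assumes GPT-J 6B has 6,053,381,344 parameters (the figure on EleutherAI's GPT-J 6B model card) and uses decimal gigabytes. The implied average bytes per parameter for the TensorRT engine is only a rough indicator. The engine file may also contain non-weight data, and this README does not document the INT8/FP32 split.

```python
# Assumed parameter count of GPT-J 6B (from EleutherAI's model card).
N_PARAMS = 6_053_381_344

def weight_size_gb(n_params, bytes_per_param):
    """Size of the raw weights in decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9

fp32_gb = weight_size_gb(N_PARAMS, 4)       # 4 bytes per FP32 weight
fp16_gb = weight_size_gb(N_PARAMS, 2)       # 2 bytes per FP16 weight
int8_only_gb = weight_size_gb(N_PARAMS, 1)  # lower bound if every weight were INT8

# Average storage per parameter implied by the 8.5 GB mixed-precision engine,
# assuming the engine file is dominated by weights.
avg_bytes_trt = 8.5e9 / N_PARAMS

print(f"FP32 {fp32_gb:.1f} GB, FP16 {fp16_gb:.1f} GB, all-INT8 {int8_only_gb:.1f} GB")
print(f"TensorRT INT8+FP32: ~{avg_bytes_trt:.2f} bytes/param on average")
```

Under these assumptions, the 8.5 GB engine falls between an all-INT8 model (about 6.1 GB) and the FP16 model (12.1 GB). That is consistent with most, but not all, weights being stored in INT8.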

Latency:

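| Input sequence length | Number of generated tokens | TensorRT INT8+FP32 (ms) | torch FP16 (ms) | Acceleration |
|-----------------------|----------------------------|-------------------------|-----------------|--------------|
| 64                    | 64                         | 1040                    | 1610            | 1.55         |
| 64                    | 128                        | 2089                    | 3224            | 1.54         |
| 64                    | 256                        | 4236                    | 6479            | 1.53         |
| 128                   | 64                         | 1060                    | 1619            | 1.53         |
| 128                   | 128                        | 2120                    | 3241            | 1.53         |
| 128                   | 256                        | 4296                    | 6510            | 1.52         |
| 256                   | 64                         | 1109                    | 1640            | 1.49         |
| 256                   | 128                        | 2204                    | 3276            | 1.49         |
| 256                   | 256                        | 4443                    | 6571            | 1.49         |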

Test environment

  • GPU: RTX 4090
  • CPU: 11th Gen Intel(R) Core(TM) i7-11700K
  • TensorRT: 8.5.3.1
  • PyTorch: 1.13.1+cu116
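The Acceleration column appears to be the torch FP16 time divided by the TensorRT time. The sketch below recomputes that ratio from the millisecond values in the latency table and also derives the average time per generated token. Recomputing from the rounded millisecond values reproduces most of the column. Two rows (256/64 and 256/256) come out at 1.48 rather than 1.49, which may simply reflect rounding of the underlying timings.

```python
# Rows copied from the latency table:
# (input sequence length, generated tokens, TensorRT INT8+FP32 ms, torch FP16 ms)
ROWS = [
    (64, 64, 1040, 1610), (64, 128, 2089, 3224), (64, 256, 4236, 6479),
    (128, 64, 1060, 1619), (128, 128, 2120, 3241), (128, 256, 4296, 6510),
    (256, 64, 1109, 1640), (256, 128, 2204, 3276), (256, 256, 4443, 6571),
]

def speedup(trt_ms, torch_ms):
    """Ratio of torch FP16 time to TensorRT time (higher means TensorRT is faster)."""
    return torch_ms / trt_ms

def ms_per_token(total_ms, new_tokens):
    """Average wall-clock time per generated token."""
    return total_ms / new_tokens

for seq_len, n_tokens, trt_ms, torch_ms in ROWS:
    print(f"{seq_len:>4} {n_tokens:>4}  "
          f"TRT {ms_per_token(trt_ms, n_tokens):6.2f} ms/token  "
          f"torch {ms_per_token(torch_ms, n_tokens):6.2f} ms/token  "
          f"speedup {speedup(trt_ms, torch_ms):.2f}x")
```

The per-token time stays roughly constant as the number of generated tokens grows (for example, about 16.3 to 16.5 ms/token for the TensorRT engine at input length 64). Total latency therefore scales almost linearly with output length.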

How to use

An example of inference and an accuracy test are published on GitHub:

git clone https://github.com/ENOT-AutoDL/gpt-j-6B-tensorrt-int8