torch_quantizer is a Python package for efficient quantization of PyTorch models, focused on converting floating-point Linear and Conv2d modules to INT8 precision for faster inference on the CUDA backend. torch_quantizer also supports temporal quantization and is specially optimized for diffusion models.
Before installing torch_quantizer, ensure you have PyTorch installed in your environment.
To install a pre-built torch_quantizer wheel, use the following command:
```bash
pip install torch_quantizer-*.whl
```
To build from source, clone the repository and compile with NVCC:
```bash
pip install -e .
```
Import PyTorch before importing torch_quantizer:
```python
import torch
import torch_quantizer as tq
```
To benchmark and verify the inference speedup from INT8 operations, run the following commands:
```bash
python3 example/benchmark/benchmark.py --o linear --bs 512 --cin 960 --cout 960
python3 example/benchmark/benchmark.py --o conv2d --bs 1 --cin 512 --cout 512
```
Here are the results for the linear and 2D convolution operations, showing the average time taken by FP32, FP16, and INT8 (Quant+Dequant) operations on an RTX 3090.
Linear:
- Batch Size: 512
- Channels In: 960
- Channels Out: 960
| Operation Type | Average Time |
|---|---|
| FP32 | 8.414e-05 s |
| FP16 | 3.304e-05 s |
| INT8 (Quant+Dequant) | 2.908e-05 s |
Conv2d:
- Batch Size: 1
- Channels In: 512
- Channels Out: 512
| Operation Type | Average Time |
|---|---|
| FP32 | 0.000903 s |
| FP16 | 0.000413 s |
| INT8 (Quant+Dequant) | 0.000178 s |
These results highlight the performance improvements achieved through quantization to INT8, demonstrating significant reductions in inference time across both operations.
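For reference, timings like these can be reproduced with CUDA events. The sketch below times the FP16 linear case with the same shapes as above; it is an illustrative approximation, not the actual code in `example/benchmark/benchmark.py`.

```python
import torch

# Hypothetical re-measurement of the FP16 linear case (batch 512, 960 -> 960);
# benchmark.py may measure things differently.
layer = torch.nn.Linear(960, 960).cuda().half()
x = torch.randn(512, 960, device="cuda", dtype=torch.float16)

# Warm-up so CUDA kernels and caches are initialized before timing.
with torch.no_grad():
    for _ in range(10):
        layer(x)
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 100
start.record()
with torch.no_grad():
    for _ in range(iters):
        layer(x)
end.record()
torch.cuda.synchronize()

# elapsed_time() returns milliseconds; convert to seconds per call.
print(f"FP16 linear: {start.elapsed_time(end) / iters / 1e3:.3e} s")
```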
Ensure your model can run inference in FP16 precision. For instance:
```python
unet.half()
# do FP16 inference
```
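If you want a quick sanity check that FP16 inference works before quantizing, something like the following can be used. This is a sketch with a hypothetical toy model standing in for the real UNet.

```python
import torch

# Toy stand-in model (hypothetical); replace with your own network.
model = torch.nn.Sequential(
    torch.nn.Conv2d(4, 8, kernel_size=3, padding=1),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 4, kernel_size=3, padding=1),
).cuda().half()

x = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
with torch.no_grad():
    y = model(x)

# The forward pass should stay in FP16 and produce finite values.
assert y.dtype == torch.float16
assert torch.isfinite(y).all()
```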
Convert your model to a fake quantized model for calibration:
```python
wq_params = {'n_bits': 8, 'channel_wise': True, 'scale_method': 'max'}
aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse'}
ddim_steps = ...  # set the number of diffusion steps here; use 1 for non-temporal models
fakeq_model = tq.fake_quant(unet, wq_params, aq_params, num_steps=ddim_steps)
```
Run your fake quantized model with some input for calibration. For example:
```python
from some_diffusion_library import DiffusionPipeline

# Build the pipeline so that sampling runs through fakeq_model (the
# fake-quantized UNet), which lets it observe the calibration data.
pipe = DiffusionPipeline()
prompt = "a photo of a flying dog"
image = pipe(prompt, guidance_scale=7.5)["sample"][0]
```
Convert the fake quantized model to a real INT8 model:
```python
qunet = tq.fake2real(fakeq_model, save_dir='.')
```
The INT8 checkpoint will be saved in the specified directory.
Load the INT8 model directly from the checkpoint:
```python
ckpt_path = 'path_to_checkpoint'
n_bits = 8  # must match the bit-width used during calibration
qnn = tq.real_quant(unet, n_bits, ddim_steps, ckpt_path)
```
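Putting the steps together, a minimal end-to-end sketch might look like the following. It uses a hypothetical toy model rather than a diffusion UNet, assumes the `tq.fake_quant` / `tq.fake2real` signatures shown above, and uses random tensors where a real workflow would run representative calibration data.

```python
import torch
import torch_quantizer as tq

# Hypothetical toy model with Linear modules; a real workflow would use a UNet.
model = torch.nn.Sequential(
    torch.nn.Linear(960, 960),
    torch.nn.ReLU(),
    torch.nn.Linear(960, 960),
).cuda().half()

wq_params = {'n_bits': 8, 'channel_wise': True, 'scale_method': 'max'}
aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse'}
num_steps = 1  # non-temporal model

# 1. Build the fake-quantized model.
fakeq_model = tq.fake_quant(model, wq_params, aq_params, num_steps=num_steps)

# 2. Calibrate by running representative inputs through it.
with torch.no_grad():
    for _ in range(8):
        fakeq_model(torch.randn(64, 960, device="cuda", dtype=torch.float16))

# 3. Convert to a real INT8 model; the checkpoint is written to save_dir.
qmodel = tq.fake2real(fakeq_model, save_dir='.')

# 4. Later, reload the INT8 model from the saved checkpoint with
#    tq.real_quant(model, n_bits, num_steps, ckpt_path) as shown above.
```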
This repository is built upon QUIK. We thank the authors for open-sourcing their code.
torch_quantizer is released under the Apache-2.0 License.