Support for loading fp8 checkpoint #68
Comments
Hi wenscarl, could you please take a look here? Since we did a bitcast during save, we need to use a bitcast during serving.
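For context, a bitcast only reinterprets the raw bytes without changing them, so an fp8 tensor can round-trip through an int8 container losslessly. A minimal JAX sketch of the idea (illustrative only, not the actual saxml save/serve code; `jnp.float8_e4m3fn` is assumed as the fp8 dtype):

```python
import jax
import jax.numpy as jnp

# fp8 weights as they exist in memory.
w_fp8 = jnp.ones((4, 4), dtype=jnp.float8_e4m3fn)

# Save path: reinterpret the raw fp8 bits as int8 (no numeric conversion).
w_stored = jax.lax.bitcast_convert_type(w_fp8, jnp.int8)

# Serving path: reinterpret the int8 bits back as fp8 before use.
w_restored = jax.lax.bitcast_convert_type(w_stored, jnp.float8_e4m3fn)
```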
Thanks. If the checkpoint is generated by
Hi wenscarl, we can enable running the fp8 model (stored as int8) by setting

There are several architecture variations, but for a basic transformer with MHA,

Could you please elaborate on "they do not go through quantized_einsum"?
For fp8 training/inference, those layers are replaced by Fp8EinsumOp. See the USE_FP8 option in PAXML here.
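For reference, here is a rough sketch of what an fp8-quantized einsum does conceptually (scale inputs, round-trip through float8, then contract). This is illustrative only and not the actual Fp8EinsumOp implementation:

```python
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

def fp8_einsum(eqn, x, w):
    # Per-tensor scales derived from the current amax (dynamic scaling).
    x_scale = jnp.max(jnp.abs(x)) / E4M3_MAX
    w_scale = jnp.max(jnp.abs(w)) / E4M3_MAX
    # Quantize: scale into fp8 range, round-trip through float8, scale back.
    x_q = (x / x_scale).astype(jnp.float8_e4m3fn).astype(x.dtype) * x_scale
    w_q = (w / w_scale).astype(jnp.float8_e4m3fn).astype(w.dtype) * w_scale
    # Contract in the original (higher) precision.
    return jnp.einsum(eqn, x_q, w_q)

# Example usage with made-up shapes.
x = jnp.ones((2, 16), dtype=jnp.bfloat16)
w = jnp.ones((16, 32), dtype=jnp.bfloat16)
y = fp8_einsum("bd,df->bf", x, w)
```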
From the shape [512], it looks like some bias, since the
to be in line with the shape of weights in each layer. The full model config is:
The shapes for the model weights are:
If the gptj config is set to
the checkpoint loading is able to make it through, but it will hit an error:
Is there any example showcasing how to run the offline_quantize tool to generate a quantized checkpoint that is then loaded by saxml for inference?
Hi wenscarl, thanks for all the details. I can confirm fp8 worked 100% when the infra was first added. There are no public examples around fp8 since it was experimental. I think your offline_quantize script is correct. The decorator should be
The sentence piece error is unrelated. If you are able to run the float model with that config (excluding the quantization stuff), the sentence piece should work for the quantized checkpoint/config as well. Did you run into any issues with float?
Updates:
such that the resulting
Is there an example showcasing how to do per-tensor scaling properly?
@jianlijianli for viz.
Hi wenscarl, sorry for the delay. Pax defaults everything to per-channel quantization so there is no API for per-tensor quantization, but it should be easy to hack a bit locally to run per-tensor. I think all we need is to set scale dims to [1] here. The runtime should be able to handle both per-tensor and per-channel scale, thanks to the broadcast behavior of multiply.
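To illustrate the broadcast point, a small sketch with hypothetical shapes (not Pax code): the same dequantize multiply works whether the scale is per-channel or per-tensor with dims [1]:

```python
import jax.numpy as jnp

# int8 weight of shape [in_dim, out_dim]; shapes here are illustrative.
w_q = jnp.ones((512, 2048), dtype=jnp.int8)

per_channel_scale = jnp.full((2048,), 0.02, dtype=jnp.float32)  # one scale per output channel
per_tensor_scale = jnp.full((1,), 0.02, dtype=jnp.float32)      # single scale, dims set to [1]

# The same dequantize expression handles either shape because multiply broadcasts.
w_per_channel = w_q.astype(jnp.float32) * per_channel_scale  # [512, 2048] * [2048]
w_per_tensor = w_q.astype(jnp.float32) * per_tensor_scale    # [512, 2048] * [1]
```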
Hi @jianlijianli, some updates.
Hi wenscarl, apologies for the delay, and really glad you have fp8 working now. Static activation quantization is less useful, so we haven't supported it yet, but it should not be too hard to add. How do you plan to do static activation quantization? Are the static activation scales collected from QAT training or from calibration? The comment in https://github.com/google/praxis/blob/main/praxis/layers/quantization/quantize.py#L286 is a bit confusing: the entire quantize.py is rewriting checkpoints, so from that point of view it is always weight-only, for both ffw and attention.
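For what it's worth, here is a minimal sketch of what static activation quantization could look like, assuming a per-tensor activation scale computed offline from calibration data. The names are illustrative, not existing Praxis APIs:

```python
import jax.numpy as jnp

def calibrate_act_scale(calibration_batches):
    # Per-tensor activation scale from the max |activation| seen offline.
    amax = max(float(jnp.max(jnp.abs(x))) for x in calibration_batches)
    return amax / 127.0  # map the observed range onto the int8 range

def static_quantized_matmul(x, w_q, w_scale, act_scale):
    # Quantize activations at serving time with the precomputed (static) scale.
    x_q = jnp.clip(jnp.round(x / act_scale), -128, 127).astype(jnp.int8)
    # int8 x int8 matmul accumulated in int32, dequantized once at the end.
    acc = jnp.matmul(x_q.astype(jnp.int32), w_q.astype(jnp.int32))
    return acc.astype(jnp.float32) * (act_scale * w_scale)
```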
There is a use_fp flag for the offline_quantize tool in saxml/tool to quantize the weights in fp8, but they still have to be stored in int8 (see praxis/praxis/layers/quantization/operations.py, line 776 at commit 3f4cbb4).