Please install the latest TorchAO to support the float8 dtype:
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
Launch the training job with the following command (or, alternatively, set the configs in the TOML files):
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
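As a rough illustration of the TOML alternative, the sketch below turns the three float8 flags from the command above into a TOML section. It assumes that each --section.key flag corresponds to key = true under [section] in the TOML file. The helper flags_to_toml is hypothetical and is not torchtitan's own config parser.

```python
import argparse

# The three float8 flags used in the launch command.
FLOAT8_FLAGS = [
    "--float8.enable_float8_linear",
    "--float8.enable_fsdp_float8_all_gather",
    "--float8.precompute_float8_dynamic_scale_for_fsdp",
]


def flags_to_toml(argv):
    """Render dotted boolean flags as TOML sections (hypothetical helper).

    Assumption: a flag --section.key maps to `key = true` under [section].
    """
    parser = argparse.ArgumentParser()
    for flag in FLOAT8_FLAGS:
        # argparse keeps the dot in dest, e.g. "float8.enable_float8_linear".
        parser.add_argument(flag, action="store_true")
    args = vars(parser.parse_args(argv))

    # Group "section.key" entries into {section: {key: value}}.
    sections = {}
    for dotted, value in args.items():
        section, key = dotted.split(".", 1)
        sections.setdefault(section, {})[key] = value

    lines = []
    for section, entries in sections.items():
        lines.append(f"[{section}]")
        for key, value in entries.items():
            lines.append(f"{key} = {str(value).lower()}")  # TOML booleans are lowercase
    return "\n".join(lines)


print(flags_to_toml(FLOAT8_FLAGS))
```

Flags that are not passed render as false, so the same helper shows the default section as well.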
--float8.enable_float8_linear: swap nn.Linear with Float8Linear to perform float8 matmuls.
--float8.enable_fsdp_float8_all_gather: cast Float8Linear.weight from high precision to float8 before the FSDP all-gather, so that weights are communicated in float8 to save bandwidth.
--float8.precompute_float8_dynamic_scale_for_fsdp (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.
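To make the last flag concrete, here is a pure-Python simulation (no torch.distributed, made-up shard values) of computing each parameter's AMAX across data-parallel ranks. Both variants yield the same AMAX values; the per-parameter variant issues one collective per parameter, while the fused variant packs every local AMAX into one vector and issues a single collective. The helper names and the scale formula 448 / amax (448 being the largest finite float8_e4m3fn value) are illustrative assumptions, not torchao's implementation.

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def local_amax(shard):
    """Max absolute value of the locally held shard of one parameter."""
    return max(abs(v) for v in shard)


def per_parameter_allreduce(ranks):
    """One simulated MAX all-reduce per parameter: P tiny collectives."""
    amaxes, calls = [], 0
    for p in range(len(ranks[0])):
        amaxes.append(max(local_amax(rank[p]) for rank in ranks))
        calls += 1
    return amaxes, calls


def fused_allreduce(ranks):
    """Pack all local amaxes into one vector per rank, then reduce once."""
    local_vectors = [[local_amax(shard) for shard in rank] for rank in ranks]
    # Simulated single MAX all-reduce: elementwise max across the ranks.
    amaxes = [max(column) for column in zip(*local_vectors)]
    return amaxes, 1


def scales_from_amax(amaxes):
    """Tensor-wise scale per parameter, guarding against an all-zero tensor."""
    return [FP8_E4M3_MAX / max(a, 1e-12) for a in amaxes]


# Two simulated ranks, each holding a shard of three parameters (made-up values).
ranks = [
    [[0.5, -2.0], [0.1, 0.2], [3.0, -1.0]],
    [[1.5, 0.25], [-0.4, 0.05], [-6.0, 2.0]],
]
slow, n_slow = per_parameter_allreduce(ranks)
fast, n_fast = fused_allreduce(ranks)
print(slow == fast, n_slow, n_fast, scales_from_amax(fast))
```

With many parameters, the fused form replaces many latency-bound tiny collectives with one slightly larger one.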
For parallelism, we support float8 all-gather for FSDP (optional) and for TP (enabled by default for Float8Linear).
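The bandwidth saving from float8 all-gather can be estimated with simple arithmetic. The sketch below assumes an even sharding over world_size ranks, a bfloat16 baseline, and that every all-gathered element belongs to a Float8Linear weight. It ignores the small per-tensor scale that travels with each float8 weight and the layers that stay in high precision, so it overstates the saving for a real model. The 8e9-element, 8-GPU sizes are hypothetical.

```python
BYTES_PER_ELEMENT = {"bfloat16": 2, "float8": 1}


def allgather_recv_bytes(num_elements, world_size, dtype):
    """Bytes one rank receives when all-gathering an evenly sharded tensor.

    Each rank already holds 1/world_size of the elements, so it receives
    the remaining (world_size - 1) / world_size of them.
    """
    assert num_elements % world_size == 0, "this sketch assumes an even split"
    received = num_elements - num_elements // world_size
    return received * BYTES_PER_ELEMENT[dtype]


# Hypothetical sizes: 8e9 weight elements sharded over 8 GPUs.
n, world = 8_000_000_000, 8
bf16 = allgather_recv_bytes(n, world, "bfloat16")
fp8 = allgather_recv_bytes(n, world, "float8")
print(f"bf16: {bf16 / 1e9:.1f} GB, float8: {fp8 / 1e9:.1f} GB per all-gather")
```

Because float8 uses one byte per element versus two for bfloat16, the received volume per all-gather halves regardless of world size.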
For the scaling strategy, we currently support tensor-wise scaling with dynamic scales and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
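The sketch below emulates tensor-wise scaling with dynamic scales in pure Python. For each matmul operand, it computes one scale from the tensor's current amax, scales the tensor into the float8_e4m3fn range (largest finite value 448), and rounds each element to the nearest e4m3 value. It then multiplies the quantized operands and divides the result by both scales. The rounding helper is a software emulation written for illustration, not a real float8 kernel, and real kernels accumulate in higher precision. Delayed scaling would instead derive the scale from amaxes recorded in earlier iterations, and row-wise scaling would keep one scale per row rather than one per tensor.

```python
import math

E4M3_MAX = 448.0       # largest finite float8_e4m3fn value
E4M3_MIN_EXP = -6      # smallest normal float8_e4m3fn value is 2**-6
E4M3_MANTISSA_BITS = 3


def round_to_e4m3(x):
    """Round a float to the nearest float8_e4m3fn value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)  # saturate instead of overflowing
    # frexp returns mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) == e - 1.
    _, e = math.frexp(mag)
    # Below 2**-6 the spacing stops shrinking (subnormal range).
    exp = max(e - 1, E4M3_MIN_EXP)
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)  # gap between neighbouring values
    q = round(mag / step) * step  # Python's round() is round-half-to-even
    return math.copysign(min(q, E4M3_MAX), x)


def tensorwise_dynamic_scale(t):
    """One scale for the whole tensor, derived from its current amax."""
    amax = max(abs(v) for row in t for v in row)
    return E4M3_MAX / max(amax, 1e-12)  # guard against an all-zero tensor


def quantize(t, scale):
    """Scale into the float8 range, then round every element to e4m3."""
    return [[round_to_e4m3(v * scale) for v in row] for row in t]


def fp8_matmul(a, b):
    """Emulated float8 matmul: quantize both operands, multiply, unscale."""
    sa, sb = tensorwise_dynamic_scale(a), tensorwise_dynamic_scale(b)
    qa, qb = quantize(a, sa), quantize(b, sb)
    # Python floats stand in for the higher-precision accumulator of a real kernel.
    out = [[sum(x * y for x, y in zip(row, col)) for col in zip(*qb)] for row in qa]
    return [[v / (sa * sb) for v in row] for row in out]


a = [[1.0, -2.0], [0.5, 3.0]]
b = [[0.25, 1.0], [-1.5, 0.75]]
print(fp8_matmul(a, b))  # close to the exact product [[3.25, -0.5], [-4.375, 2.75]]
```

Because each scale is recomputed from the tensor being cast, an outlier in the current step immediately shrinks that tensor's scale; this is the property that distinguishes dynamic from delayed scaling.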