Add warning on loss values when using XLA_USE_BF16 (pytorch#2425)
jysohn23 authored Aug 14, 2020
1 parent 9e9bc80 commit ef05476
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion TROUBLESHOOTING.md
@@ -193,7 +193,20 @@ only be enabled for debugging.
moving to the next step.

* ```XLA_USE_BF16```: If set to 1, transforms all the _PyTorch_ _Float_ values into _BFloat16_
when sending to the _TPU_ device. Note that when using `XLA_USE_BF16=1` tensor arithmetic is
done in reduced precision, so values accumulated over many steps will lose accuracy.
For example:

```
# In reduced bfloat16 precision
>>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16)
tensor(4096., dtype=torch.bfloat16)
# Whereas in full float32 precision
>>> torch.tensor(4096) + torch.tensor(1)
tensor(4097)
```
So to get accurate metrics such as the average loss value over many steps, use manual mixed
precision where the metrics stay in FP32, as in the sketch below.
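
As a minimal illustration (not from the original document, just a stand-in for a real training
loop), the sketch below accumulates a per-step loss both directly in bfloat16 and after upcasting
to float32; only the float32 accumulator stays accurate once the running sum grows large.

```
import torch

# Stand-in for accumulating a per-step loss over many training steps.
bf16_sum = torch.tensor(0.0, dtype=torch.bfloat16)  # accumulated in reduced precision
fp32_sum = torch.tensor(0.0, dtype=torch.float32)   # accumulated in full precision

for step in range(10000):
    step_loss = torch.tensor(1.0, dtype=torch.bfloat16)  # pretend this came from the model
    bf16_sum += step_loss                     # stays in bfloat16
    fp32_sum += step_loss.to(torch.float32)   # upcast before accumulating

# bfloat16 keeps only ~8 bits of significand precision, so once the running
# sum reaches 256 the +1 increments are rounded away and the sum stops growing.
print(bf16_sum)  # tensor(256., dtype=torch.bfloat16)
print(fp32_sum)  # tensor(10000.)
```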

* ```XLA_USE_F16```: If set to 1, transforms all the _PyTorch_ _Float_ values into _Float16_
(_PyTorch_ _Half_ type) when sending to devices that support it.