Note that the information in this section is subject to be removed in future releases of the PyTorch/XLA software, since many of them are peculiar to a given internal implementation which might change.
To diagnose issues, we can use the execution metrics and counters provided by PyTorch/XLA The first thing to check when model is slow is to generate a metrics report.
Metrics report is extremely helpful in diagonsing issues. Please try to include it in your bug report sent to us if you have it.
Put the following line in your program to generate a report:
import torch_xla.debug.metrics as met
print(met.metrics_report())
The report includes things like:
- how many time we issue XLA compilations and time spent on issuing.
- how many times we execute and time spent on execution
- how many device data handles we create/destroy etc.
This information is reported in terms of percentiles of the samples. An example is:
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
We also provide counters, which are named integer variables which track internal software status. For example:
Counter: CachedSyncTensors
Value: 395
In this report, any counter that starts with aten::
indicates a context switch between the XLA device and CPU, which can be a
potential performance optimization area in the model code.
Counters are useful to understand which operations are routed back to the CPU engine of PyTorch. They are fully qualified with their C++ namespace:
Counter: aten::nonzero
Value: 33
If you see aten::
ops other than nonzero
and _local_scalar_dense
, that usually means a missing
lowering in PyTorch/XLA. Feel free to open a feature request for it on GitHub issues.
PyTorch/XLA behaves semantically like regular PyTorch and XLA tensors share the full tensor interface with CPU & GPU tensors. However, constraints in XLA/hardware and the lazy evaluation model suggest certain patterns might result in bad performance.
If your model shows bad performance, keep in mind the following caveats:
-
XLA/TPU yield degraded performance with too many recompilations.
XLA compilation is expensive. PyTorch/XLA automatically recompiles the graph every time new shapes are encountered. Usually models should stabilize within a few steps and you can see huge speedup for the rest of training.
In order to avoid recompilations, not only shapes must be constant, but also computations accross XLA devices in all hosts.
Possible sources:
- Direct or indirect uses of
nonzero
introduce dynamic shapes; for example, masked indexingbase[index]
whereindex
is a mask tensor. - Loops with a different number of iterations between steps can result in different execution graphs, thus require recompilations.
Solution:
- Tensor shapes should be the same between iterations, or a low number of shape variations should be used.
- Pad tensors to fixed sizes when possible.
- Direct or indirect uses of
-
Certain operations don't have native translations to XLA.
For these operations PyTorch/XLA automatically transfer to the CPU memory, evaluate on CPU, and transfer the result back to XLA device. Doing too many such operations during the training step can lead to significant slowdowns.
Possible sources:
- The
item()
operation explicitly asks to evaluate the result. Don't use it unless it's necessary.
Solution:
- For most ops we can lower them to XLA to fix it. Checkout metrics report section to find out the missing ops and open a feature request on GitHub.
- Even when a PyTorch tensor is known as a scalar, avoid using
tensor.item()
. Keep it as a tensor and use tensor operations on it. - Use
torch.where
to substitute control flow when applicable. E.g. The control flow withitem()
used in clip_grad_norm_ can be simply replaced bytorch.where
with dramatical performance improvement.... else: device = parameters[0].device total_norm = torch.zeros([], device=device if parameters else None) for p in parameters: param_norm = p.grad.data.norm(norm_type) ** norm_type total_norm.add_(param_norm) total_norm = (total_norm ** (1. / norm_type)) clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6) for p in parameters: p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
- The
-
Iterators in
torch_xla.distributed.data_parallel
may drop the last few batches in the input iterator.This is to make sure we do the same amount of work on all XLA devices.
Solution:
- When dataset is small, and there are too few steps, this may result in a no-op epoch. Therefore, it is better to use small batch sizes in those cases.
-
XLA tensor internals are opaque. XLA tensors always appear to be contiguous and without storage. Networks should not try to check the strides of XLA tensors.
-
XLA tensors should be moved to the CPU before saving them. Saving XLA tensors directly causes them to be loaded back on the device(s) they were saved from. If a device is unavailable at load time then the load will fail. Moving XLA tensors to the CPU before saving them lets you decide which device(s) to put the loaded tensors on. This is necessary if you want to load the tensors on a machine without XLA devices. Care should be taken moving the XLA tensors to the CPU before saving them, however, as moving tensors across device types does not preserve view relationships. Instead, views should be reconstructed as necessary after the tensors are loaded.
-
Copying an XLA Tensor with Python's copy.copy returns a deep copy, not a shallow copy. Use a view of an XLA tensor to get a shallow copy of it.
We don't expect users to use tools in this section to debug their models. But we might ask for them when you submit a bug report since they provide additional information that metrics report doesn't have.
There are also a number of environment variables which control the behavior of the PyTorch/XLA software stack.
Setting such variables will cause different degrees of performance degradation, so they should only be enabled for debugging.
-
XLA_IR_DEBUG
: Enables the Python stack trace to be catpured where creating IR nodes, hence allowing to understand which PyTorch operation was responsible of generating such IR. -
XLA_HLO_DEBUG
: Enables the Python stack frame captured when XLA_IR_DEBUG is active, to be propagated to the XLA HLO metadata. -
XLA_SAVE_TENSORS_FILE
: The path to a file which will be used to dump the IR graphs during execution. Note that the file can become really big if the option is left enabled and the PyTorch program let run for long time. The graphs are appended to the file, so to have a clean sheet from run to run, the file should be explicitly removed. -
XLA_SAVE_TENSORS_FMT
: The format of the graphs stored within the XLA_SAVE_TENSORS_FILE file. Can betext
(the default),dot
(the Graphviz format) orhlo
. -
XLA_METRICS_FILE
: If set, the path to a local file where the internal metrics will be saved at every step. Metrics will be appended to the file, if already existing. -
GET_TENSORS_OPBYOP
: Enables pure OpByOp dispatch. The PyTorch/XLA software tries to fuse together many PyTorch operations into a single computation graph, but sometimes, either for debugging, or in case the PyTorch code have a very dynamic nature (in shapes or graph terms), it is better to force the execution in OpByOp mode (every IR node is lowered into a separate XLA computation, and chain-executed). This environment variable, if set to 1, enables OpByOp during the "get tensors" operation (the operation used by PyTorch/XLA to fetch intermediate values back from the TPU device into PyTorch CPU tensors). -
SYNC_TENSORS_OPBYOP
: The same as GET_TENSORS_OPBYOP but for "sync tensors" operation (the operation used at the end of a step, to flush pending IR computations and materialize them into TPU device data). -
XLA_SYNC_WAIT
: Forces the XLA tensor sync operation to wait for its completion, before moving to the next step. -
XLA_USE_BF16
: If set to 1, tranforms all the PyTorch Float values into BiFloat16 when sending to the TPU device. -
XLA_USE_32BIT_LONG
: If set to 1, maps PyTorch Long types to XLA 32bit type. On the versions of the TPU HW at the time of writing, 64bit integer computations are expensive, so setting this flag might help. It should be verified by the user that truncating to 32bit values is a valid operation according to the use of PyTorch Long values in it. -
TF_CPP_LOG_THREAD_ID
: If set to 1, the TF logs will show the thread ID helping with debugging multithreaded processes. -
TF_CPP_VMODULE
: Environment variable used for TF VLOGs and takes the form ofTF_CPP_VMODULE=name=value,...
. For PyTorch/XLA using a configuration likeTF_CPP_VMODULE=tensor=5
would enable logging such as:2019-10-03 17:23:56.419040: I 27891 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:3 done! 2019-10-03 17:23:56.419448: I 27890 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 15483856951158150605 on device TPU:5 done! 2019-10-03 17:23:56.419539: I 27896 torch_xla/csrc/tensor.cpp:1104] Executing IR graph hash 4211381954965020633 on device TPU:4 done! ...
In the event that the PyTorch process is hanging, it might be useful to include the stack traces together with the GitHub issue.
First thing is to find out which PID the PyTorch process is associated with. Using the ps
command it is possible to find that information. It will be a python process running your
main python file.
In order to allow GDB to attach a user process the following command should be run as root:
echo 0 > /proc/sys/kernel/yama/ptrace_scope
The above command remains active until the machine is rebooted.
The, given the PID, it is possible to grab the stack traces with the following command:
./scripts/dump_stacks.py PID > /tmp/stack-traces.log