Refactor docs. (pytorch#1095)
* Move sections around to make it ready for refactoring docs.

* Refactor troubleshooting.md.
ailzhang authored Sep 26, 2019
1 parent edfe84f commit fe7f594
Showing 4 changed files with 266 additions and 218 deletions.
49 changes: 4 additions & 45 deletions API_GUIDE.md
@@ -1,6 +1,6 @@
# Performance guideline for running on TPUs
# PyTorch/XLA API And Best Practice

## XLA tensors
## XLA Tensors

PyTorch/XLA adds a new device, similarly to CPU and GPU devices. The following snippet creates an XLA tensor filled with random values, then prints the device and the contents of the tensor:
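A minimal sketch of such a snippet, assuming the `xla_device()` helper from `torch_xla.core.xla_model`:

```python
import torch
import torch_xla.core.xla_model as xm

# Create a tensor directly on the XLA device and inspect it.
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
```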

@@ -132,47 +132,6 @@ The same multi-core API can be used to run on a single core as well by setting t

Check the [full example](https://github.com/pytorch/xla/blob/master/test/test_train_mnist.py) showing how to train MNIST on TPU using `torch_xla.distributed.data_parallel.DataParallel` (Python threading).
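A rough sketch of that pattern follows. The toy network, fake dataset, and hyper-parameters are illustrative stand-ins, and the API usage (`dp.DataParallel`, the `train_loop_fn` signature, the per-device loader iteration, `xm.optimizer_step`) follows the linked MNIST example; treat the details as a schematic rather than a definitive recipe.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

# Toy network standing in for the MNIST model of the linked example.
class ToyNet(nn.Module):
  def __init__(self):
    super(ToyNet, self).__init__()
    self.fc = nn.Linear(784, 10)

  def forward(self, x):
    return F.log_softmax(self.fc(x.view(x.size(0), -1)), dim=1)

# Tiny fake dataset so the sketch is self-contained; replace with real MNIST.
inputs = torch.randn(256, 1, 28, 28)
targets = torch.randint(0, 10, (256,))
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(inputs, targets), batch_size=32)

def train_loop_fn(model, loader, device, context):
  # Called once per XLA device, each in its own Python thread. `context` can
  # hold per-device state (e.g. a cached optimizer); a fresh optimizer is
  # created here to keep the sketch short.
  optimizer = optim.SGD(model.parameters(), lr=0.01)
  model.train()
  for x, (data, target) in loader:  # x is the batch index from the per-device loader
    optimizer.zero_grad()
    loss = F.nll_loss(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # also performs the cross-replica gradient reduction

devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(ToyNet, device_ids=devices)
for epoch in range(1, 3):
  model_parallel(train_loop_fn, train_loader)
```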

## Performance caveats

PyTorch/XLA behaves semantically like regular PyTorch, and XLA tensors implement the full tensor interface. However, constraints in XLA and in the underlying hardware, together with the lazy evaluation model, mean that some patterns must be avoided:

1. Tensor shapes should be the same between iterations, or a low number of shape variations should be used. PyTorch/XLA automatically recompiles the graph every time new shapes are encountered. This means that, if the shapes don’t stabilize during training, more time will be spent compiling than running the model. Pad tensors to fixed sizes when possible. Direct or indirect uses of `nonzero` introduce dynamic shapes; for example, masked indexing `base[index]` where `index` is a mask tensor.
1. Certain operations don’t have native translations to XLA and therefore require transfer to the CPU memory, evaluation on CPU, and transfer of the result back to the XLA device. This is automatically handled by PyTorch/XLA, but doing too many such operations during the training step can lead to significant slowdowns. The `item()` operation is one such example and it is used in [clip_grad_norm_](https://github.com/pytorch/pytorch/blob/de19eeee99a2a282fc441f637b23d8e50c75ecd1/torch/nn/utils/clip_grad.py#L33). Below is an alternative implementation which avoids the need for `item()`:

```python
...
else:
    # Accumulate the total norm as a tensor on the parameters' device,
    # instead of materializing it on the host with .item().
    device = parameters[0].device
    total_norm = torch.zeros([], device=device if parameters else None)
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type) ** norm_type
        total_norm.add_(param_norm)
    total_norm = total_norm ** (1. / norm_type)
# Keep the clipping coefficient as a tensor as well; torch.where replaces the
# Python-level `if clip_coef < 1` branch, so no device-to-host transfer occurs.
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
for p in parameters:
    p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
```


1. In order to avoid recompilations, not only must the shapes be constant, but so must the computations performed across XLA devices on all hosts. A special case of this is loops with a different number of iterations between steps. PyTorch/XLA handles them automatically, but each distinct iteration count is seen as a different execution graph and requires a recompilation.

1. Iterators in `torch_xla.distributed.data_parallel` may drop the
last few batches in the input iterator, in order to do the same amount of work
on all XLA devices. In the extreme case where the dataset is small and there are
too few steps, this may result in a no-op epoch. Therefore, it is better to use
small batch sizes in those cases.

1. Even when it's known that a PyTorch tensor holds a scalar, avoid calling
`tensor.item()` on it. Prefer keeping it as a tensor and using tensor
operations on it, with control-flow substitutes such as `torch.where`
(see the sketch after this list). Following this approach will likely result
in those operations being fully fused within a single XLA graph, without the
need to issue separate TPU computations. This can dramatically improve
performance of the model, by up to a factor of N, where N is the number of
`tensor.item()` calls per step.
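
For instance, a minimal sketch of replacing a Python-level branch on a scalar tensor with `torch.where` (the tensor names are illustrative):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.randn((), device=device)  # a scalar tensor produced by the model
y = torch.randn(4, device=device)

# Anti-pattern: x.item() forces a device-to-host transfer and a separate TPU
# computation on every step.
#   y = y * 2 if x.item() > 0 else y / 2
#
# Tensor-only alternative: both branches stay in the same fused XLA graph.
y = torch.where(x > 0, y * 2, y / 2)
```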

`print(torch_xla._XLAC._xla_metrics_report())` can be used to print metrics at the end of each step to collect information regarding the number of compilations and operators that are part of the model but don’t have native XLA implementations. The `XLA_METRICS_FILE=/PATH/TO/FILE` environment setting can also be used to export per step metrics to a file.

In this report, any counter that starts with `aten::`
indicates a context switch between the XLA device and CPU, which can be a
potential performance optimization area in the model code.
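
A sketch of wiring the per-step report into a training loop (`train_step` and the loop below are placeholders for the real training code):

```python
import torch_xla

def train_step(batch):
  pass  # stand-in for the real forward/backward/optimizer step

# Periodically dump the metrics report to watch for growing CompileTime
# samples or aten:: counters (both point at recompilations or CPU fallbacks).
for step, batch in enumerate(range(100)):
  train_step(batch)
  if step % 20 == 0:
    print(torch_xla._XLAC._xla_metrics_report())
```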
## Performance And Debugging

Is your model still running slowly after many iterations? Check out the [troubleshooting guide](TROUBLESHOOTING.md) for tips on how to debug it!
62 changes: 60 additions & 2 deletions CONTRIBUTING.md
@@ -1,4 +1,4 @@
## C++ style guide
## C++ Style Guide

`pytorch/xla` uses `clang-format-7` with a customized style config.
If your PR touches the C++ source files, please run the following command before submitting a PR.
@@ -10,7 +10,7 @@ clang-format-7 -i -style /PATH/TO/foo.cpp
find -name '*.cpp' -o -name '*.h' | xargs clang-format-7 -i -style=file
```

## Python style guide
## Python Style Guide

`pytorch/xla` uses `yapf` with a customized style config.
If your PR touches the Python source files, please run the following command before submitting a PR.
@@ -19,3 +19,61 @@ If your PR touches the Python source files, please run the following command bef
#TODO:
```

## Build From Source

* Clone `pytorch/xla` into `pytorch/pytorch` in the following structure:

```Shell
pytorch/ # pytorch/pytorch repo
xla/ # pytorch/xla repo
torch/
...
```

* Apply PyTorch patches:

```Shell
xla/scripts/apply_patches.sh
```

* Install the Lark parser used for automatic code generation:

```Shell
pip install lark-parser
```

* Currently _PyTorch_ does not build with _GCC_ 6.x, 7.x, or 8.x (various kinds of ICEs). _CLANG_ 7.x is known to be working, so install that in your VM:

```Shell
sudo apt-get install clang-7 clang++-7
export CC=clang-7 CXX=clang++-7
```

You may need to add the following line to your _/etc/apt/sources.list_ file:

```Shell
deb http://deb.debian.org/debian/ testing main
```

And run the following command before trying again to install _CLANG_:

```Shell
sudo apt-get update
```

* Build _PyTorch_ from source following the regular [instructions](https://github.com/pytorch/pytorch#from-source).

```Shell
python setup.py install
```

* Install Bazel following the [instructions](https://docs.bazel.build/versions/master/install.html). Install version 0.24.1 only, as neither older nor newer releases are able to build the required dependencies.

* Build the _PyTorch/XLA_ source:

```Shell
cd xla/
python setup.py install
```


180 changes: 9 additions & 171 deletions README.md
@@ -1,3 +1,5 @@
[![CircleCI](https://circleci.com/gh/pytorch/xla.svg?style=svg)](https://circleci.com/gh/pytorch/xla)

# How to Run PyTorch with TPUs

First, create your [TPU](https://pantheon.corp.google.com/compute/tpus) node with the corresponding release you wish to consume (TPU software version: `pytorch-0.1`):
@@ -172,59 +174,11 @@ To build from source:
xla/scripts/build_torch_wheels.sh
```

## Building manually

* If a file named xla/.torch_commit_id exists, use its content to check out the PyTorch commit ID:

```Shell
git checkout $(cat xla/.torch_commit_id)
```

* Apply PyTorch patches:
## Building Manually

```Shell
xla/scripts/apply_patches.sh
```
Please refer to [contribution guide](CONTRIBUTING.md) for instructions to build from source.

* Install the Lark parser used for automatic code generation:

```Shell
pip install lark-parser
```

* Currently _PyTorch_ does not build with _GCC_ 6.x, 7.x, or 8.x (various kinds of ICEs). _CLANG_ 7.x is known to be working, so install that in your VM:

```Shell
sudo apt-get install clang-7 clang++-7
export CC=clang-7 CXX=clang++-7
```

You may need to add the following line to your _/etc/apt/sources.list_ file:

```Shell
deb http://deb.debian.org/debian/ testing main
```

And run the following command before trying again to install _CLANG_:

```Shell
sudo apt-get update
```

* Build _PyTorch_ from source following the regular [instructions](https://github.com/pytorch/pytorch#from-source).

```Shell
python setup.py install
```

* Install Bazel following the [instructions](https://docs.bazel.build/versions/master/install.html). Install version 0.24.1 only, as neither older nor newer releases are able to build the required dependencies.

* Build the _PyTorch/XLA_ source:

```Shell
cd xla/
python setup.py install
```
## Tests

To run the tests, follow __one__ of the options below:

@@ -252,129 +206,13 @@ it is suggested for you to select the _Nightly_ builds when you create a Cloud T

Then run `test/run_tests.sh` and `test/cpp/run_tests.sh` to verify the setup is working.

## PyTorch/XLA API And Best Practice

[![CircleCI](https://circleci.com/gh/pytorch/xla.svg?style=svg)](https://circleci.com/gh/pytorch/xla)

# Debugging

Sometimes bad things happen and a deeper look into the _PyTorch/TPU_ stack is necessary.
In order to do that, _PyTorch/TPU_ has a series of environment variables and function calls
which can help in understanding its internal behavior.

Note that the information in this section is subject to removal in future releases of
the _PyTorch/TPU_ software, since much of it is tied to a given internal implementation
which might change.
Please check out the [API Guideline](API_GUIDE.md) for best practices on writing models that run on TPU and TPU Pod devices.

The _PyTorch/TPU_ stack keeps a series of metrics and counters during its execution, and
the following API returns a string representation of them:

```Python
torch_xla._XLAC._xla_metrics_report()
```

Printing out that information can help during the debug phases and while reporting issues.

The information included in the metrics report covers things like how many times we
issue _XLA_ compilations and how long they take, how many times we execute and for how long,
how many device data handles we create/destroy, etc.
This information is reported in terms of percentiles of the samples.
An example is:

```
Metric: CompileTime
TotalSamples: 202
Counter: 06m09s401ms746.001us
ValueRate: 778ms572.062us / second
Rate: 0.425201 / second
Percentiles: 1%=001ms32.778us; 5%=001ms61.283us; 10%=001ms79.236us; 20%=001ms110.973us; 50%=001ms228.773us; 80%=001ms339.183us; 90%=001ms434.305us; 95%=002ms921.063us; 99%=21s102ms853.173us
```

The _PyTorch/TPU_ stack also has counters, which are named integer variables that track
internal software status.
Example:

```
Counter: CachedSyncTensors
Value: 395
```

Counters are also useful to understand which operations the _PyTorch/TPU_ stack is routing
back to the CPU engine of _PyTorch_.
Names which look like a _C++_ namespace are part of this category:

```
Counter: aten::nonzero
Value: 33
```

There are also a number of environment variables which control the behavior of the _PyTorch/TPU_
software stack.
Setting such variables will cause different degrees of performance degradation, so they should
only be enabled for debugging (a usage sketch follows the list below).

* ```XLA_IR_DEBUG```: Enables the _Python_ stack trace to be captured when creating IR nodes,
hence making it possible to understand which _PyTorch_ operation was responsible for generating the IR.

* ```XLA_HLO_DEBUG```: Enables the _Python_ stack frame captured when _XLA_IR_DEBUG_ is active,
to be propagated to the _XLA_ _HLO_ metadata.

* ```XLA_SAVE_TENSORS_FILE```: The path to a file which will be used to dump the IR graphs during
execution. Note that the file can become really big if the option is left enabled and the
_PyTorch_ program is left running for a long time. The graphs are appended to the file, so to
start with a clean sheet from run to run, the file should be explicitly removed.

* ```XLA_SAVE_TENSORS_FMT```: The format of the graphs stored within the _XLA_SAVE_TENSORS_FILE_
file. Can be ```text``` (the default), ```dot``` (the _Graphviz_ format) or ```hlo```.

* ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be
saved at every step. Metrics will be appended to the file, if already existing.

* ```GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/TPU_ software tries to
fuse together many _PyTorch_ operations into a single computation graph, but sometimes, either
for debugging, or in case the _PyTorch_ code has a very dynamic nature (in shapes or graph
terms), it is better to force the execution in _OpByOp_ mode (every IR node is lowered into
a separate _XLA_ computation, and chain-executed). This environment variable, if set to 1,
enables _OpByOp_ during the "get tensors" operation (the operation used by _PyTorch/TPU_ to
fetch intermediate values back from the _TPU_ device into _PyTorch_ CPU tensors).

* ```SYNC_TENSORS_OPBYOP```: The same as _GET_TENSORS_OPBYOP_ but for "sync tensors" operation
(the operation used at the end of a step, to flush pending IR computations and materialize
them into _TPU_ device data).

* ```XLA_SYNC_WAIT```: Forces the XLA tensor sync operation to wait for its completion, before
moving to the next step.

* ```XLA_USE_BF16```: If set to 1, transforms all the _PyTorch_ _Float_ values into _BFloat16_
when sending them to the _TPU_ device.

* ```XLA_USE_32BIT_LONG```: If set to 1, maps the _PyTorch_ _Long_ type to the _XLA_ 32-bit type.
On the versions of the TPU hardware available at the time of writing, 64-bit integer computations
are expensive, so setting this flag might help. The user should verify that truncating to 32-bit
values is a valid operation given how _PyTorch_ _Long_ values are used in the model.
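
A minimal sketch of enabling a few of these variables for a debug run (the values are illustrative; the variables can equally be exported in the shell before launching the script):

```python
import os

# Debug-only settings: expect slower execution while these are active.
os.environ['XLA_IR_DEBUG'] = '1'
os.environ['XLA_HLO_DEBUG'] = '1'
os.environ['XLA_SAVE_TENSORS_FILE'] = '/tmp/xla_graphs.txt'
os.environ['XLA_SAVE_TENSORS_FMT'] = 'text'
os.environ['XLA_METRICS_FILE'] = '/tmp/xla_metrics.txt'

import torch_xla  # import only after the environment is configured
```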

## Retrieving Stack Traces

In the event that the _PyTorch_ process is hanging, it might be useful to include the stack
traces together with the _GitHub_ issue.

The first thing to do is to find out the PID the _PyTorch_ process is associated with; the ```ps```
command can provide that information. It will be a _python_ process running your
main _python_ file.

In order to allow _GDB_ to attach to a user process, the following command should be run as root:

```Shell
echo 0 > /proc/sys/kernel/yama/ptrace_scope
```

The above command remains active until the machine is rebooted.

Then, given the PID, it is possible to grab the stack traces with the following command:

```Shell
./scripts/dump_stacks.py PID > /tmp/stack-traces.log
```
## Troubleshooting

If you see bad performance when using PyTorch/XLA, please check out the [troubleshooting guide](TROUBLESHOOTING.md) for how to avoid common pitfalls and how to debug.

## Communication
