Creates PyTorch.org doc sources (pytorch#1140)
Updates API_GUIDE.md and TROUBLESHOOTING.md, adds placeholder doc strings on public classes and functions, adds documentation build resources and build script.
mruberry authored Oct 6, 2019
1 parent 751028d commit 3c06ab5
Showing 13 changed files with 572 additions and 68 deletions.
248 changes: 180 additions & 68 deletions API_GUIDE.md
@@ -1,76 +1,113 @@
# PyTorch on XLA Devices

PyTorch runs on XLA devices, like TPUs, with the
[torch_xla package](https://github.com/pytorch/xla/). This document describes
how to run your models on these devices.

## Creating an XLA Tensor

PyTorch/XLA adds a new `xla` device type to PyTorch. This device type works just
like other PyTorch device types. For example, here's how to create and
print an XLA tensor:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
```

This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
`xm.xla_device()` returns the current XLA device. This may be a CPU or TPU
depending on your environment.
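
If you want to check which XLA devices are available before picking one, a small sketch like the following should work; it relies on `xm.get_xla_supported_devices()`, the same helper used in the multithreading section below:

```python
import torch_xla.core.xla_model as xm

# List every XLA device visible in this environment (TPU cores, or a CPU fallback).
print(xm.get_xla_supported_devices())

# The default XLA device returned by xm.xla_device().
print(xm.xla_device())
```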

## XLA Tensors are PyTorch Tensors

PyTorch operations can be performed on XLA tensors just like CPU or CUDA tensors.

For example, XLA tensors can be added together:

```python
t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)
```

Or matrix multiplied:

```python
print(t0.mm(t1))
```

Or used with neural network modules:

```python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)
```

Like other device types, XLA tensors only work with other XLA tensors on the
same device. So code like

```python
l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor
```

will throw an error since the `torch.nn.Linear` module is on the CPU.
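
A minimal fix, continuing the snippet above, is to move the module to the same XLA device before calling it:

```python
# Move the module's parameters to the same XLA device as the input.
linear = linear.to(xm.xla_device())
l_out = linear(l_in)  # now both the input and the module live on the XLA device
print(l_out)
```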

## Running Models on XLA Devices

Building a new PyTorch network or converting an existing one to run on XLA
devices requires only a few lines of XLA-specific code. The following snippets
highlight these lines when running on a single device, multiple devices with XLA
multiprocessing, or multiple threads with XLA multithreading.

### Running on a Single XLA Device

The following snippet shows a network training on a single XLA device:

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  xm.optimizer_step(optimizer, barrier=True)
```

Note the `xm.optimizer_step(optimizer, barrier=True)` line which replaces the usual
`optimizer.step()`. This is required because of the way XLA tensors work:
operations are not executed immediately, but rather added to a graph of pending
operations which is only executed when its results are required. Using
`xm.optimizer_step(optimizer, barrier=True)` acts as an execution barrier which forces the
evaluation of the graph accumulated for a single step. Without this barrier, the
graph would only be evaluated when the model's accuracy is evaluated, which in
this example happens only at the end of an epoch. Even for small models, the
graph accumulated over an entire epoch would be too big to evaluate.

This snippet highlights how easy it is to switch your model to run on XLA. The
model definition, dataloader, optimizer and training loop can work on any device.
The only XLA-specific code is a couple lines that acquire the XLA device and
step the optimizer with a **barrier**. Calling
`xm.optimizer_step(optimizer, barrier=True)` at the end of each training
iteration causes XLA to execute its current graph and update the model's
parameters. See [XLA Tensor Deep Dive](#xla-tensor-deep-dive) for more on
how XLA creates graphs and runs operations.

### Running on Multiple XLA Devices with MultiProcessing

PyTorch/XLA makes it easy to accelerate training by running on multiple XLA
devices. The following snippet shows how:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  para_loader = pl.ParallelLoader(train_loader, [device])

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in para_loader.per_device_loader(device):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    # No barrier=True here; ParallelLoader inserts the barrier (see below).
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
```

There are three differences between this multidevice snippet and the previous
single device snippet:

- `xmp.spawn()` creates the processes that each run an XLA device.
- `ParallelLoader` loads the training data onto each device.
- `xm.optimizer_step(optimizer)` no longer needs a barrier. ParallelLoader
  automatically creates an XLA barrier that evaluates the graph.

The model definition, optimizer definition and training loop remain the same.

See the
[full multiprocessing example](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist.py)
for more on training a network on multiple XLA devices with multiprocessing.

### Running on Multiple XLA Devices with MultiThreading

Running on multiple XLA devices using processes (see above) is preferred to using
threads. If, however, you want to use threads then PyTorch/XLA has a
`DataParallel` interface. The following snippet shows the same network training
with multiple threads:

```python
import torch_xla.core.xla_model as xm
# The DataParallel wrapper and train_loop_fn are collapsed in this diff;
# see the sketch below for what they roughly look like.
for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)
```

The only differences between the multithreading and multiprocessing code are:

- Multiple devices are acquired in the same process with
`xm.get_xla_supported_devices()`.
- The model is wrapped in `dp.DataParallel` and passed both the training loop
and dataloader.

See the
[full multithreading example](https://github.com/pytorch/xla/blob/master/test/test_train_mnist.py)
for more on training a network on multiple XLA devices with multithreading.
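
Since the diff collapses most of the snippet above, here is a sketch of what the full multithreading example roughly looks like. The `dp.DataParallel` wrapper and `xm.get_xla_supported_devices()` call come from the differences listed above; the `train_loop_fn` signature and the `context.getattr_or` helper are assumptions modeled on the linked example:

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.data_parallel as dp

# Wrap the model so one thread drives each available XLA device.
devices = xm.get_xla_supported_devices()
model_parallel = dp.DataParallel(MNIST, device_ids=devices)

def train_loop_fn(model, loader, device, context):
  # One thread runs this loop per device; 'context' holds per-device state.
  loss_fn = nn.NLLLoss()
  optimizer = context.getattr_or(
      'optimizer',
      lambda: optim.SGD(model.parameters(), lr=lr, momentum=momentum))
  model.train()
  for data, target in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # barrier handled by the parallel data iterators

for epoch in range(1, num_epochs + 1):
  model_parallel(train_loop_fn, train_loader)
```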

## XLA Tensor Deep Dive

Using XLA tensors and devices requires changing only a few lines of code. But
even though XLA tensors act a lot like CPU and CUDA tensors, their internals are
different. This section describes what makes XLA tensors unique.

### XLA Tensors are Lazy

CPU and CUDA tensors launch operations immediately or **eagerly**. XLA tensors,
on the other hand, are **lazy**. They record operations in a graph until the
results are needed. Deferring execution like this lets XLA optimize the
computation. A graph of multiple separate operations might be fused into a
single optimized operation, for example.

Lazy execution is generally invisible to the caller. PyTorch/XLA automatically
constructs the graphs, sends them to XLA devices, and synchronizes when
copying data between an XLA device and the CPU. Inserting a barrier when
taking an optimizer step explicitly synchronizes the CPU and the XLA device.
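
For example, in the small sketch below nothing executes while operations are recorded, and copying a result to the CPU forces the pending graph to run:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(2, 2, device=device)
b = torch.randn(2, 2, device=device)

c = a + b       # recorded in the graph, not yet executed
d = c.mm(a)     # also recorded; XLA may fuse these operations

print(d.cpu())  # copying to the CPU forces the pending graph to execute
```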

### XLA Tensors and bFloat16

PyTorch/XLA can use the
[bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format)
datatype when running on TPUs. In fact, PyTorch/XLA handles float types
(`torch.float` and `torch.double`) differently on TPUs. This behavior is
controlled by the `XLA_USE_BF16` environment variable:

- By default both `torch.float` and `torch.double` are
`torch.float` on TPUs.
- If `XLA_USE_BF16` is set, then `torch.float` and `torch.double` are both
`bfloat16` on TPUs.

XLA tensors on TPUs will always report their PyTorch datatype regardless of
the actual datatype they're using. This conversion is automatic and opaque.
If an XLA tensor on a TPU is moved back to the CPU it will be converted
from its actual datatype to its PyTorch datatype.
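
As a sketch of this behavior, assuming a TPU environment and that the variable is set before PyTorch/XLA initializes:

```python
import os
os.environ['XLA_USE_BF16'] = '1'  # assumed to be set before torch_xla initializes

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.dtype)        # still reports torch.float32, even though the TPU stores bfloat16
print(t.cpu().dtype)  # torch.float32 after moving back to the CPU
```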

### Memory Layout

The internal data representation of XLA tensors is opaque to the user. They
do not expose their storage and they always appear to be contiguous, unlike
CPU and CUDA tensors. This allows XLA to adjust a tensor's memory layout for
better performance.

### Moving XLA Tensors to and from the CPU

XLA tensors can be moved from the CPU to an XLA device and from an XLA device
to the CPU. If a view is moved then the data it's viewing is copied to the
other device and the view relationship is not preserved. Put another way,
once data is copied to another device it has no relationship with its
previous device or any tensors on it.
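
For example, in the small sketch below a view copied to an XLA device becomes an independent tensor:

```python
a = torch.randn(3, 3)          # CPU tensor
b = a[0]                       # b is a view of a on the CPU

xla_b = b.to(xm.xla_device())  # copies the viewed data to the XLA device
a[0] = 0.                      # only affects the CPU tensors

print(b)                       # all zeros: still a view of a
print(xla_b.cpu())             # unchanged: the XLA copy has no relationship with a
```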

### Saving and Loading XLA Tensors

XLA tensors should be moved to the CPU before saving, as in the following
snippet:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)
```

This lets you put the loaded tensors on any available device.

Per the above note on moving XLA tensors to the CPU, care must be taken when
working with views. Instead of saving views, it's recommended that you recreate
them after the tensors have been loaded and moved to their destination device(s).
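
A small sketch of that pattern, reusing `device` from the snippet above and an illustrative `base.pt` file: save only the base tensor, then rebuild the view on the destination device.

```python
base = torch.randn(4, 4, device=device)
row0 = base[0]                     # a view; don't save it directly

torch.save(base.cpu(), 'base.pt')  # save only the base tensor

base = torch.load('base.pt').to(device)
row0 = base[0]                     # recreate the view after loading
```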

Directly saving XLA tensors is possible but not recommended. XLA
tensors are always loaded back to the device they were saved from, and if
that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch,
is under active development and this behavior may change in the future.

## Further Reading

Additional documentation is available at the
[PyTorch/XLA repo](https://github.com/pytorch/xla/). More examples of running
networks on TPUs are available
[here](https://github.com/pytorch-tpu/examples).
19 changes: 19 additions & 0 deletions TROUBLESHOOTING.md
@@ -119,6 +119,25 @@ If your model shows bad performance, keep in mind the following caveats:
* When the dataset is small and there are too few steps, this may result in a no-op epoch. Therefore, it is better to use
small batch sizes in those cases.

## XLA Tensor Quirks

1. **XLA tensor internals are opaque.** XLA tensors always appear to be
contiguous and without storage. Networks should not try to check the strides
of XLA tensors.

1. **XLA tensors should be moved to the CPU before saving them.** Saving
XLA tensors directly causes them to be loaded back on the device(s) they were
saved from. If a device is unavailable at load time then the load will fail.
Moving XLA tensors to the CPU before saving them lets you decide which
device(s) to put the loaded tensors on. This is necessary if you want to
load the tensors on a machine without XLA devices. Care should be taken
moving the XLA tensors to the CPU before saving them, however, as moving
tensors across device types does not preserve view relationships. Instead,
views should be reconstructed as necessary after the tensors are loaded.

1. **Copying an XLA Tensor with Python's copy.copy returns a deep copy, not a
shallow copy**. Use a view of an XLA tensor to get a shallow copy of it.
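
For instance, a minimal sketch of the difference:

```python
import copy

import torch
import torch_xla.core.xla_model as xm

t = torch.ones(2, 2, device=xm.xla_device())

c = copy.copy(t)     # a deep copy: c has its own data
v = t.view(t.shape)  # a view behaves like a shallow copy of t

t.add_(1)
print(c.cpu())       # still all ones; the deep copy did not follow the update
print(v.cpu())       # reflects the update; the view shares data with t
```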

## More Debugging Tools

We don't expect users to use tools in this section to debug their models. But we might ask for
3 changes: 3 additions & 0 deletions docs/docs_build.sh
@@ -0,0 +1,3 @@
# Installs requirements and builds HTML version of PyTorch/XLA docs.
pip install -r requirements.txt
sphinx-build -b html source build
3 changes: 3 additions & 0 deletions docs/requirements.txt
@@ -0,0 +1,3 @@
sphinx
m2r
-e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme
