forked from pytorch/xla
Merge pull request #13 from intelligent-machine-learning/pin_2014_01_08
Pin 2024 01 08
Showing 28 changed files with 243 additions and 165 deletions.
@@ -0,0 +1,41 @@
# CUDA PJRT plugin (experimental)

This directory contains an experimental implementation of the PJRT GPU client as
a plugin. The actual implementation of the PJRT C API lives in the main OpenXLA
repository (see `bazel build` command below).

## Building

```bash
# Build PJRT plugin
bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda
# Copy to package dir
cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin

# Build wheel
pip wheel plugins/cuda
# Or install directly
pip install plugins/cuda
```

## Usage

```python
import os

# Log device type
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5'

from torch_xla.experimental import plugins
import torch_xla_cuda_plugin
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Use dynamic plugin instead of built-in CUDA support
plugins.use_dynamic_plugins()
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin())
xr.set_device_type('CUDA')

print(xm.xla_device())
```
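A quick way to confirm that tensors are actually placed through the plugin is to run a trivial op after the registration above. The sketch below is illustrative only and assumes the usage snippet above has already run in the same process.

```python
# Minimal sketch (assumes the registration snippet above already ran): run a
# trivial op on the plugin-backed XLA device and copy the result back to host.
import torch

t = torch.randn(2, 2, device=xm.xla_device())
print((t + t).cpu())
```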
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torch_xla_cuda_plugin"
version = "0.0.1"
authors = [
    {name = "Will Cromar", email = "[email protected]"},
]
description = "CUDA Plugin"
requires-python = ">=3.8"

[tool.setuptools.package-data]
torch_xla_cuda_plugin = ["*.so"]

[project.entry-points."torch_xla.plugins"]
gpu = "torch_xla_cuda_plugin:GpuPlugin"
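The `[project.entry-points."torch_xla.plugins"]` table registers `GpuPlugin` under the `torch_xla.plugins` entry-point group, which is what makes the installed wheel discoverable. As a minimal sketch of the generic mechanism (torch_xla's own discovery code may differ), the entry point can be enumerated with the standard library:

```python
# Sketch of standard entry-point discovery for the group declared above.
# This only demonstrates importlib.metadata mechanics; torch_xla's internal
# plugin discovery may work differently.
from importlib.metadata import entry_points

# The group= keyword needs Python 3.10+; on 3.8/3.9 use
# entry_points().get('torch_xla.plugins', []) instead.
for ep in entry_points(group='torch_xla.plugins'):
  plugin_cls = ep.load()  # resolves "torch_xla_cuda_plugin:GpuPlugin"
  print(ep.name, plugin_cls)
```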
@@ -0,0 +1,11 @@
import os
from torch_xla.experimental import plugins
from torch_xla._internal import tpu


class GpuPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so')

  def physical_chip_count(self) -> int:
    # TODO: default to actual device count
    return int(os.getenv('GPU_NUM_DEVICES', 1))
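For a quick local check of the wrapper itself, independent of registering it with the runtime, something like the following can be used. It is only a sketch: it assumes the wheel built from `plugins/cuda` is installed, and the printed values depend on the install location and on `GPU_NUM_DEVICES`.

```python
# Sketch: inspect the plugin wrapper directly (no runtime registration).
import os
import torch_xla_cuda_plugin

os.environ.setdefault('GPU_NUM_DEVICES', '4')
plugin = torch_xla_cuda_plugin.GpuPlugin()
print(plugin.library_path())         # path to the bundled pjrt_c_api_gpu_plugin.so
print(plugin.physical_chip_count())  # 4, read from GPU_NUM_DEVICES
```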
@@ -0,0 +1,35 @@
from absl.testing import absltest, parameterized
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class TestDtypes(parameterized.TestCase):

  @parameterized.parameters(torch.float16, torch.float32, torch.float64,
                            torch.bfloat16, torch.complex64)
  def test_float_round_trip(self, dtype: torch.dtype):
    t = torch.randn((3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  @parameterized.parameters(
      torch.uint8,
      torch.int8,
      torch.int16,
      torch.int32,
      torch.int64,
  )
  def test_int_round_trip(self, dtype: torch.dtype):
    t = torch.randint(0, 128, (3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  def test_bool_round_trip(self):
    t = torch.randint(0, 2, (3, 3), dtype=torch.bool)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)


if __name__ == "__main__":
  absltest.main()
@@ -0,0 +1,26 @@
import sys
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.utils.utils as xu


def _mp_fn(index):
  device = xm.xla_device()
  if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'):
    train_loader = xu.SampleGenerator(
        data=torch.zeros(1, 12), sample_count=1024)
    train_loader = pl.MpDeviceLoader(train_loader, device)
    max_steps = 10
    for step, inputs in enumerate(train_loader):
      xm.all_reduce('sum', [inputs], scale=1.0 / xm.xrt_world_size())
      if step > max_steps:
        break
  else:
    print(f'{device} is not a TPU or GPU device', file=sys.stderr)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())