diff --git a/TROUBLESHOOTING.md b/TROUBLESHOOTING.md index fa05dd927ae..15542a2a001 100644 --- a/TROUBLESHOOTING.md +++ b/TROUBLESHOOTING.md @@ -174,7 +174,7 @@ only be enabled for debugging. * ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be saved at every step. Metrics will be appended to the file, if already existing. -* ```GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/XLA_ software tries to +* ```XLA_GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/XLA_ software tries to fuse together many _PyTorch_ operations into a single computation graph, but sometimes, either for debugging, or in case the _PyTorch_ code have a very dynamic nature (in shapes or graph terms), it is better to force the execution in _OpByOp_ mode (every IR node is lowered into @@ -182,9 +182,9 @@ only be enabled for debugging. enables _OpByOp_ during the "get tensors" operation (the operation used by _PyTorch/XLA_ to fetch intermediate values back from the _TPU_ device into _PyTorch_ CPU tensors). -* ```SYNC_TENSORS_OPBYOP```: The same as _GET_TENSORS_OPBYOP_ but for "sync tensors" operation - (the operation used at the end of a step, to flush pending IR computations and materialize - them into _TPU_ device data). +* ```XLA_SYNC_TENSORS_OPBYOP```: The same as _XLA_GET_TENSORS_OPBYOP_ but for "sync tensors" + operation (the operation used at the end of a step, to flush pending IR computations and + materialize them into _TPU_ device data). * ```XLA_SYNC_WAIT```: Forces the XLA tensor sync operation to wait for its completion, before moving to the next step. 
diff --git a/scripts/gen.py b/scripts/gen.py index e05521c0c7e..69bf2d451ed 100755 --- a/scripts/gen.py +++ b/scripts/gen.py @@ -97,6 +97,8 @@ class ArgTemplate(string.Template): 'kthvalue_out': FuncOpts(), 'index_select_out': FuncOpts(), 'log_out': FuncOpts(), + 'masked_select_out': FuncOpts(), + 'nonzero_out': FuncOpts(), 'take_out': FuncOpts(), 'topk_out': FuncOpts(), } diff --git a/test/cpp/cpp_test_util.cpp b/test/cpp/cpp_test_util.cpp index 4b04ddf989b..70d390813ef 100644 --- a/test/cpp/cpp_test_util.cpp +++ b/test/cpp/cpp_test_util.cpp @@ -59,7 +59,13 @@ at::Tensor ToCpuTensor(const at::Tensor& t) { return t.to(torch::kCPU); } +torch::Tensor CopyToDevice(torch::Tensor t, const torch::Device& device) { + return t.to(device, /*non_blocking=*/false, /*copy=*/true); +} + bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) { + tensor1 = ToCpuTensor(tensor1); + tensor2 = ToCpuTensor(tensor2); if (tensor1.sizes() != tensor2.sizes() || tensor1.dtype() != tensor2.dtype()) { std::cerr << "Different shape:\n" @@ -67,9 +73,6 @@ bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) { << tensor2.dtype() << " " << tensor2.sizes() << "\n"; return false; } - tensor1 = ToCpuTensor(tensor1); - tensor2 = ToCpuTensor(tensor2); - at::ScalarType type1 = tensor1.scalar_type(); at::ScalarType type2 = tensor2.scalar_type(); if (type1 != type2) { @@ -83,15 +86,14 @@ bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) { } bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) { + tensor1 = ToCpuTensor(tensor1); + tensor2 = ToCpuTensor(tensor2); if (tensor1.sizes() != tensor2.sizes()) { std::cerr << "Different shape:\n" << tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n" << tensor2.dtype() << " " << tensor2.sizes() << "\n"; return false; } - tensor1 = ToCpuTensor(tensor1); - tensor2 = ToCpuTensor(tensor2); - at::ScalarType type1 = tensor1.scalar_type(); at::ScalarType type2 = tensor2.scalar_type(); if (type1 != type2) { @@ -227,5 +229,41 @@ 
std::vector ExecuteAndFetch( return Fetch(results); } +void TestBackward( + const std::vector<torch::Tensor>& inputs, const torch::Device& device, + const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>& + testfn, + double rtol, double atol) { + std::vector<torch::Tensor> input_vars; + std::vector<torch::Tensor> xinput_vars; + for (size_t i = 0; i < inputs.size(); ++i) { + const torch::Tensor& input = inputs[i]; + if (input.defined()) { + input_vars.push_back( + input.clone().detach().set_requires_grad(input.requires_grad())); + + torch::Tensor xinput = CopyToDevice(input, device) + .detach() + .set_requires_grad(input.requires_grad()); + xinput_vars.push_back(xinput); + } else { + input_vars.emplace_back(); + xinput_vars.emplace_back(); + } + } + + torch::Tensor output = testfn(input_vars); + torch::Tensor xoutput = testfn(xinput_vars); + AllClose(output, xoutput, rtol, atol); + output.backward(torch::ones_like(output)); + xoutput.backward(torch::ones_like(xoutput)); + for (size_t i = 0; i < inputs.size(); ++i) { + if (inputs[i].defined() && inputs[i].requires_grad()) { + ASSERT_TRUE(xinput_vars[i].grad().defined()); + AllClose(input_vars[i].grad(), xinput_vars[i].grad(), rtol, atol); + } + } +} + } // namespace cpp_test } // namespace torch_xla diff --git a/test/cpp/cpp_test_util.h b/test/cpp/cpp_test_util.h index cd0956a327a..567e5578693 100644 --- a/test/cpp/cpp_test_util.h +++ b/test/cpp/cpp_test_util.h @@ -34,6 +34,9 @@ const std::unordered_set* GetIgnoredCounters(); // on both sides. at::Tensor ToCpuTensor(const at::Tensor& t); +// Helper function to copy a tensor to device. 
+torch::Tensor CopyToDevice(torch::Tensor t, const torch::Device& device); + bool EqualValues(at::Tensor tensor1, at::Tensor tensor2); bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2); @@ -82,5 +85,11 @@ std::vector Fetch( std::vector ExecuteAndFetch( tensorflow::gtl::ArraySlice roots, const Device& device); +void TestBackward( + const std::vector<torch::Tensor>& inputs, const torch::Device& device, + const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>& + testfn, + double rtol = 1e-5, double atol = 1e-8); + } // namespace cpp_test } // namespace torch_xla diff --git a/test/cpp/run_tests.sh b/test/cpp/run_tests.sh index fac05c85f19..6fbfcedf3df 100755 --- a/test/cpp/run_tests.sh +++ b/test/cpp/run_tests.sh @@ -8,12 +8,13 @@ FILTER= BUILD_ONLY=0 RMBUILD=1 LOGFILE=/tmp/pytorch_cpp_test.log +XLA_EXPERIMENTAL="nonzero:masked_select" if [ "$DEBUG" == "1" ]; then BUILDTYPE="Debug" fi -while getopts 'VLDKBF:' OPTION +while getopts 'VLDKBF:X:' OPTION do case $OPTION in V) @@ -34,10 +35,15 @@ do F) FILTER="--gtest_filter=$OPTARG" ;; + X) + XLA_EXPERIMENTAL="$OPTARG" + ;; esac done shift $(($OPTIND - 1)) +export XLA_EXPERIMENTAL + rm -rf "$BUILDDIR" mkdir "$BUILDDIR" 2>/dev/null pushd "$BUILDDIR" @@ -46,6 +52,7 @@ cmake "$RUNDIR" \ -DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \ -DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))") make -j $VERB + if [ $BUILD_ONLY -eq 0 ]; then if [ "$LOGFILE" != "" ]; then ./test_ptxla ${FILTER:+"$FILTER"} 2>$LOGFILE diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp index 2ad9f064eb5..e93981d685f 100644 --- a/test/cpp/test_aten_xla_tensor.cpp +++ b/test/cpp/test_aten_xla_tensor.cpp @@ -6,55 +6,18 @@ #include "cpp_test_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_client/metrics.h" +#include 
"torch_xla/csrc/aten_xla_bridge.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/torch_util.h" #include "torch_xla_test.h" namespace torch_xla { namespace cpp_test { +namespace { class AtenXlaTensorTest : public AtenXlaTensorTestBase {}; -// Helper function to copy a tensor to device. -torch::Tensor CopyToDevice(torch::Tensor t, torch::Device device) { - return t.to(device, /*non_blocking=*/false, /*copy=*/true); -} - -void TestBackward( - const std::vector<torch::Tensor>& inputs, const torch::Device& device, - const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>& - testfn, - double rtol = 1e-5, double atol = 1e-8) { - std::vector<torch::Tensor> input_vars; - std::vector<torch::Tensor> xinput_vars; - for (size_t i = 0; i < inputs.size(); ++i) { - const torch::Tensor& input = inputs[i]; - if (input.defined()) { - input_vars.push_back( - input.clone().detach().set_requires_grad(input.requires_grad())); - - torch::Tensor xinput = CopyToDevice(input, device) - .detach() - .set_requires_grad(input.requires_grad()); - xinput_vars.push_back(xinput); - } else { - input_vars.emplace_back(); - xinput_vars.emplace_back(); - } - } - - torch::Tensor output = testfn(input_vars); - torch::Tensor xoutput = testfn(xinput_vars); - AllClose(output, xoutput, rtol, atol); - output.backward(torch::ones_like(output)); - xoutput.backward(torch::ones_like(xoutput)); - for (size_t i = 0; i < inputs.size(); ++i) { - if (inputs[i].defined() && inputs[i].requires_grad()) { - ASSERT_TRUE(xinput_vars[i].grad().defined()); - AllClose(input_vars[i].grad(), xinput_vars[i].grad(), rtol, atol); - } - } -} +} // namespace TEST_F(AtenXlaTensorTest, TestScalarTensor) { torch::Tensor scalar_tensor = @@ -3808,13 +3771,15 @@ TEST_F(AtenXlaTensorTest, TestNonzero) { torch::Tensor xla_a = CopyToDevice(a, device); torch::Tensor xla_b = torch::nonzero(xla_a); AllClose(b, xla_b); - }); - if (DebugUtil::ExperimentEnabled("nonzero")) { - // If the nonzero support is enabled, we must not see any aten:: calls. 
- ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - } - ExpectCounterChanged("xla::nonzero", cpp_test::GetIgnoredCounters()); + if (DebugUtil::ExperimentEnabled("nonzero") && + bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) { + // If the nonzero support is enabled, we must not see any aten:: calls. + ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + } + ExpectCounterChanged("xla::nonzero", cpp_test::GetIgnoredCounters()); + ResetCounters(); + }); } TEST_F(AtenXlaTensorTest, TestMaskedSelect) { @@ -3827,13 +3792,16 @@ TEST_F(AtenXlaTensorTest, TestMaskedSelect) { torch::Tensor xla_b = CopyToDevice(b, device); torch::Tensor xla_c = torch::masked_select(xla_a, xla_b); AllClose(c, xla_c); - }); - if (DebugUtil::ExperimentEnabled("masked_select")) { - // If the nonzero support is enabled, we must not see any aten:: calls. - ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); - } - ExpectCounterChanged("xla::masked_select", cpp_test::GetIgnoredCounters()); + if (DebugUtil::ExperimentEnabled("masked_select") && + bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) { + // If the masked_select support is enabled, we must not see any aten:: + // calls. 
+ ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters()); + } + ExpectCounterChanged("xla::masked_select", cpp_test::GetIgnoredCounters()); + ResetCounters(); + }); } TEST_F(AtenXlaTensorTest, TestMultiIndexHeadNull) { diff --git a/test/cpp/torch_xla_test.cpp b/test/cpp/torch_xla_test.cpp index d1008d3150c..0b0708d38b7 100644 --- a/test/cpp/torch_xla_test.cpp +++ b/test/cpp/torch_xla_test.cpp @@ -56,6 +56,11 @@ void XlaTest::ExpectCounterChanged( EXPECT_TRUE(!changed.empty()); } +void XlaTest::ResetCounters() { + start_msnap_ = std::move(end_msnap_); + end_msnap_ = nullptr; +} + void XlaTest::MakeEndSnapshot() { if (end_msnap_ == nullptr) { end_msnap_ = absl::make_unique(); diff --git a/test/cpp/torch_xla_test.h b/test/cpp/torch_xla_test.h index b76c83893d7..3ca97d2ff51 100644 --- a/test/cpp/torch_xla_test.h +++ b/test/cpp/torch_xla_test.h @@ -24,6 +24,8 @@ class XlaTest : public ::testing::Test { void ExpectCounterChanged(const std::string& counter_regex, const std::unordered_set* ignore_set); + void ResetCounters(); + private: void MakeEndSnapshot(); diff --git a/test/run_tests.sh b/test/run_tests.sh index 2a3cb367501..5f70909f586 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -32,29 +32,33 @@ export PYTORCH_TEST_WITH_SLOW=1 function run_opbyop { echo "Running in OpByOp mode ..." 
- GET_TENSORS_OPBYOP=1 SYNC_TENSORS_OPBYOP=1 "$@" + XLA_GET_TENSORS_OPBYOP=1 XLA_SYNC_TENSORS_OPBYOP=1 "$@" +} + +function run_dynamic { + XLA_EXPERIMENTAL="nonzero:masked_select" "$@" } if [ "$LOGFILE" != "" ]; then python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA 2>&1 | tee $LOGFILE - python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA 2>&1 | tee $LOGFILE + run_dynamic python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA 2>&1 | tee $LOGFILE python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA 2>&1 | tee $LOGFILE python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA 2>&1 | tee $LOGFILE python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA 2>&1 | tee $LOGFILE - python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA 2>&1 | tee $LOGFILE - python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA 2>&1 | tee $LOGFILE - python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE + run_dynamic python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA 2>&1 | tee $LOGFILE + run_dynamic python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA 2>&1 | tee $LOGFILE + run_dynamic python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE python3 "$CDIR/test_mp_replication.py" "$@" 2>&1 | tee $LOGFILE else python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA - python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA + run_dynamic python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA - python3 "$CDIR/../../test/test_nn.py" 
"$@" -v TestNNDeviceTypeXLA - python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA - python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY + run_dynamic python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA + run_dynamic python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA + run_dynamic python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY python3 "$CDIR/test_mp_replication.py" "$@" fi diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index d987343ba7b..8505a435f04 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -855,7 +855,7 @@ std::vector XLATensor::GetTensorsOpByOp( std::vector XLATensor::GetTensors(std::vector* tensors) { static const bool op_by_op = - xla::sys_util::GetEnvBool("GET_TENSORS_OPBYOP", false); + xla::sys_util::GetEnvBool("XLA_GET_TENSORS_OPBYOP", false); return op_by_op ? GetTensorsOpByOp(tensors) : GetTensorsFused(tensors); } @@ -1148,7 +1148,7 @@ void XLATensor::SyncTensorsGraph( tensorflow::gtl::ArraySlice devices, bool wait, bool sync_xla_data) { static const bool op_by_op = - xla::sys_util::GetEnvBool("SYNC_TENSORS_OPBYOP", false); + xla::sys_util::GetEnvBool("XLA_SYNC_TENSORS_OPBYOP", false); SyncTensorsConfig config; config.sync_xla_data = sync_xla_data; if (op_by_op) {