Skip to content

Commit

Permalink
Enable some dynamic shapes tests. (pytorch#1434)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlibenzi authored and ailzhang committed Dec 2, 2019
1 parent 66fa4cd commit f66bf94
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 74 deletions.
8 changes: 4 additions & 4 deletions TROUBLESHOOTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,17 @@ only be enabled for debugging.
* ```XLA_METRICS_FILE```: If set, the path to a local file where the internal metrics will be
saved at every step. If the file already exists, metrics are appended to it.

* ```GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/XLA_ software tries to
* ```XLA_GET_TENSORS_OPBYOP```: Enables pure _OpByOp_ dispatch. The _PyTorch/XLA_ software tries to
fuse together many _PyTorch_ operations into a single computation graph, but sometimes, either
for debugging, or in case the _PyTorch_ code has a very dynamic nature (in shapes or graph
terms), it is better to force the execution in _OpByOp_ mode (every IR node is lowered into
a separate _XLA_ computation, and chain-executed). This environment variable, if set to 1,
enables _OpByOp_ during the "get tensors" operation (the operation used by _PyTorch/XLA_ to
fetch intermediate values back from the _TPU_ device into _PyTorch_ CPU tensors).

* ```SYNC_TENSORS_OPBYOP```: The same as _GET_TENSORS_OPBYOP_ but for "sync tensors" operation
(the operation used at the end of a step, to flush pending IR computations and materialize
them into _TPU_ device data).
* ```XLA_SYNC_TENSORS_OPBYOP```: The same as _XLA_GET_TENSORS_OPBYOP_ but for "sync tensors"
operation (the operation used at the end of a step, to flush pending IR computations and
materialize them into _TPU_ device data).

* ```XLA_SYNC_WAIT```: Forces the XLA tensor sync operation to wait for its completion, before
moving to the next step.
Expand Down
2 changes: 2 additions & 0 deletions scripts/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class ArgTemplate(string.Template):
'kthvalue_out': FuncOpts(),
'index_select_out': FuncOpts(),
'log_out': FuncOpts(),
'masked_select_out': FuncOpts(),
'nonzero_out': FuncOpts(),
'take_out': FuncOpts(),
'topk_out': FuncOpts(),
}
Expand Down
50 changes: 44 additions & 6 deletions test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,20 @@ at::Tensor ToCpuTensor(const at::Tensor& t) {
return t.to(torch::kCPU);
}

// Returns the result of moving `t` to `device` through Tensor::to(), asking
// for a blocking transfer (non_blocking=false) and passing copy=true.
torch::Tensor CopyToDevice(torch::Tensor t, const torch::Device& device) {
  constexpr bool kNonBlocking = false;
  constexpr bool kCopy = true;
  return t.to(device, kNonBlocking, kCopy);
}

bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
tensor1 = ToCpuTensor(tensor1);
tensor2 = ToCpuTensor(tensor2);
if (tensor1.sizes() != tensor2.sizes() ||
tensor1.dtype() != tensor2.dtype()) {
std::cerr << "Different shape:\n"
<< tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
<< tensor2.dtype() << " " << tensor2.sizes() << "\n";
return false;
}
tensor1 = ToCpuTensor(tensor1);
tensor2 = ToCpuTensor(tensor2);

at::ScalarType type1 = tensor1.scalar_type();
at::ScalarType type2 = tensor2.scalar_type();
if (type1 != type2) {
Expand All @@ -83,15 +86,14 @@ bool EqualValues(at::Tensor tensor1, at::Tensor tensor2) {
}

bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2) {
tensor1 = ToCpuTensor(tensor1);
tensor2 = ToCpuTensor(tensor2);
if (tensor1.sizes() != tensor2.sizes()) {
std::cerr << "Different shape:\n"
<< tensor1.dtype() << " " << tensor1.sizes() << "\n-vs-\n"
<< tensor2.dtype() << " " << tensor2.sizes() << "\n";
return false;
}
tensor1 = ToCpuTensor(tensor1);
tensor2 = ToCpuTensor(tensor2);

at::ScalarType type1 = tensor1.scalar_type();
at::ScalarType type2 = tensor2.scalar_type();
if (type1 != type2) {
Expand Down Expand Up @@ -227,5 +229,41 @@ std::vector<at::Tensor> ExecuteAndFetch(
return Fetch(results);
}

// Evaluates `testfn` on two parallel sets of inputs -- detached clones of
// `inputs`, and detached copies of `inputs` placed on `device` -- and checks
// that both the forward outputs and the input gradients agree within
// `rtol`/`atol`.
//
// Undefined entries of `inputs` are forwarded to `testfn` as undefined tensors
// so that positions stay aligned. Gradients are only compared for inputs that
// are defined and have requires_grad() set.
void TestBackward(
    const std::vector<torch::Tensor>& inputs, const torch::Device& device,
    const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
        testfn,
    double rtol, double atol) {
  std::vector<torch::Tensor> ref_inputs;
  std::vector<torch::Tensor> dev_inputs;
  for (const torch::Tensor& source : inputs) {
    if (!source.defined()) {
      // Placeholder entries keep both lists index-compatible with `inputs`.
      ref_inputs.emplace_back();
      dev_inputs.emplace_back();
      continue;
    }
    const bool wants_grad = source.requires_grad();
    ref_inputs.push_back(
        source.clone().detach().set_requires_grad(wants_grad));
    dev_inputs.push_back(
        CopyToDevice(source, device).detach().set_requires_grad(wants_grad));
  }

  // Forward pass on both input sets.
  torch::Tensor ref_output = testfn(ref_inputs);
  torch::Tensor dev_output = testfn(dev_inputs);
  AllClose(ref_output, dev_output, rtol, atol);

  // Backward pass, seeded with all-ones gradients shaped like each output.
  ref_output.backward(torch::ones_like(ref_output));
  dev_output.backward(torch::ones_like(dev_output));

  for (size_t i = 0; i < inputs.size(); ++i) {
    const torch::Tensor& source = inputs[i];
    if (!source.defined() || !source.requires_grad()) {
      continue;
    }
    ASSERT_TRUE(dev_inputs[i].grad().defined());
    AllClose(ref_inputs[i].grad(), dev_inputs[i].grad(), rtol, atol);
  }
}

} // namespace cpp_test
} // namespace torch_xla
9 changes: 9 additions & 0 deletions test/cpp/cpp_test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ const std::unordered_set<std::string>* GetIgnoredCounters();
// on both sides.
at::Tensor ToCpuTensor(const at::Tensor& t);

// Helper function to copy a tensor to device.
torch::Tensor CopyToDevice(torch::Tensor t, const torch::Device& device);

bool EqualValues(at::Tensor tensor1, at::Tensor tensor2);

bool EqualValuesNoElementTypeCheck(at::Tensor tensor1, at::Tensor tensor2);
Expand Down Expand Up @@ -82,5 +85,11 @@ std::vector<at::Tensor> Fetch(
std::vector<at::Tensor> ExecuteAndFetch(
tensorflow::gtl::ArraySlice<const ir::Value> roots, const Device& device);

// Runs `testfn` on `inputs` and on copies of `inputs` placed on `device`,
// then compares the forward results and the gradients of the inputs which
// require grad, using the `rtol` (relative) and `atol` (absolute) tolerances.
void TestBackward(
const std::vector<torch::Tensor>& inputs, const torch::Device& device,
const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
testfn,
double rtol = 1e-5, double atol = 1e-8);

} // namespace cpp_test
} // namespace torch_xla
9 changes: 8 additions & 1 deletion test/cpp/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ FILTER=
BUILD_ONLY=0
RMBUILD=1
LOGFILE=/tmp/pytorch_cpp_test.log
XLA_EXPERIMENTAL="nonzero:masked_select"

if [ "$DEBUG" == "1" ]; then
BUILDTYPE="Debug"
fi

while getopts 'VLDKBF:' OPTION
while getopts 'VLDKBF:X:' OPTION
do
case $OPTION in
V)
Expand All @@ -34,10 +35,15 @@ do
F)
FILTER="--gtest_filter=$OPTARG"
;;
X)
XLA_EXPERIMENTAL="$OPTARG"
;;
esac
done
shift $(($OPTIND - 1))

export XLA_EXPERIMENTAL

rm -rf "$BUILDDIR"
mkdir "$BUILDDIR" 2>/dev/null
pushd "$BUILDDIR"
Expand All @@ -46,6 +52,7 @@ cmake "$RUNDIR" \
-DPYTHON_INCLUDE_DIR=$(python -c "from distutils.sysconfig import get_python_inc; print(get_python_inc())") \
-DPYTHON_LIBRARY=$(python -c "import distutils.sysconfig as sysconfig; print(sysconfig.get_config_var('LIBDIR') + '/' + sysconfig.get_config_var('LDLIBRARY'))")
make -j $VERB

if [ $BUILD_ONLY -eq 0 ]; then
if [ "$LOGFILE" != "" ]; then
./test_ptxla ${FILTER:+"$FILTER"} 2>$LOGFILE
Expand Down
72 changes: 20 additions & 52 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,18 @@
#include "cpp_test_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_client/metrics.h"
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/torch_util.h"
#include "torch_xla_test.h"

namespace torch_xla {
namespace cpp_test {
namespace {

class AtenXlaTensorTest : public AtenXlaTensorTestBase {};

// Helper function to copy a tensor to device.
torch::Tensor CopyToDevice(torch::Tensor t, torch::Device device) {
return t.to(device, /*non_blocking=*/false, /*copy=*/true);
}

void TestBackward(
const std::vector<torch::Tensor>& inputs, const torch::Device& device,
const std::function<torch::Tensor(const std::vector<torch::Tensor>&)>&
testfn,
double rtol = 1e-5, double atol = 1e-8) {
std::vector<torch::Tensor> input_vars;
std::vector<torch::Tensor> xinput_vars;
for (size_t i = 0; i < inputs.size(); ++i) {
const torch::Tensor& input = inputs[i];
if (input.defined()) {
input_vars.push_back(
input.clone().detach().set_requires_grad(input.requires_grad()));

torch::Tensor xinput = CopyToDevice(input, device)
.detach()
.set_requires_grad(input.requires_grad());
xinput_vars.push_back(xinput);
} else {
input_vars.emplace_back();
xinput_vars.emplace_back();
}
}

torch::Tensor output = testfn(input_vars);
torch::Tensor xoutput = testfn(xinput_vars);
AllClose(output, xoutput, rtol, atol);
output.backward(torch::ones_like(output));
xoutput.backward(torch::ones_like(xoutput));
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].defined() && inputs[i].requires_grad()) {
ASSERT_TRUE(xinput_vars[i].grad().defined());
AllClose(input_vars[i].grad(), xinput_vars[i].grad(), rtol, atol);
}
}
}
} // namespace

TEST_F(AtenXlaTensorTest, TestScalarTensor) {
torch::Tensor scalar_tensor =
Expand Down Expand Up @@ -3808,13 +3771,15 @@ TEST_F(AtenXlaTensorTest, TestNonzero) {
torch::Tensor xla_a = CopyToDevice(a, device);
torch::Tensor xla_b = torch::nonzero(xla_a);
AllClose(b, xla_b);
});

if (DebugUtil::ExperimentEnabled("nonzero")) {
// If the nonzero support is enabled, we must not see any aten:: calls.
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}
ExpectCounterChanged("xla::nonzero", cpp_test::GetIgnoredCounters());
if (DebugUtil::ExperimentEnabled("nonzero") &&
bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) {
// If the nonzero support is enabled, we must not see any aten:: calls.
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}
ExpectCounterChanged("xla::nonzero", cpp_test::GetIgnoredCounters());
ResetCounters();
});
}

TEST_F(AtenXlaTensorTest, TestMaskedSelect) {
Expand All @@ -3827,13 +3792,16 @@ TEST_F(AtenXlaTensorTest, TestMaskedSelect) {
torch::Tensor xla_b = CopyToDevice(b, device);
torch::Tensor xla_c = torch::masked_select(xla_a, xla_b);
AllClose(c, xla_c);
});

if (DebugUtil::ExperimentEnabled("masked_select")) {
// If the nonzero support is enabled, we must not see any aten:: calls.
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}
ExpectCounterChanged("xla::masked_select", cpp_test::GetIgnoredCounters());
if (DebugUtil::ExperimentEnabled("masked_select") &&
bridge::AtenDeviceToXlaDevice(device).hw_type == DeviceType::TPU) {
// If the masked_select support is enabled, we must not see any aten::
// calls.
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
}
ExpectCounterChanged("xla::masked_select", cpp_test::GetIgnoredCounters());
ResetCounters();
});
}

TEST_F(AtenXlaTensorTest, TestMultiIndexHeadNull) {
Expand Down
5 changes: 5 additions & 0 deletions test/cpp/torch_xla_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ void XlaTest::ExpectCounterChanged(
EXPECT_TRUE(!changed.empty());
}

// Promotes the current end snapshot to be the new start snapshot and leaves
// the end snapshot unset, so MakeEndSnapshot() will create a new one the next
// time it runs.
void XlaTest::ResetCounters() {
  start_msnap_ = std::move(end_msnap_);
  end_msnap_.reset();
}

void XlaTest::MakeEndSnapshot() {
if (end_msnap_ == nullptr) {
end_msnap_ = absl::make_unique<MetricsSnapshot>();
Expand Down
2 changes: 2 additions & 0 deletions test/cpp/torch_xla_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class XlaTest : public ::testing::Test {
void ExpectCounterChanged(const std::string& counter_regex,
const std::unordered_set<std::string>* ignore_set);

void ResetCounters();

private:
void MakeEndSnapshot();

Expand Down
22 changes: 13 additions & 9 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,33 @@ export PYTORCH_TEST_WITH_SLOW=1

function run_opbyop {
echo "Running in OpByOp mode ..."
GET_TENSORS_OPBYOP=1 SYNC_TENSORS_OPBYOP=1 "$@"
XLA_GET_TENSORS_OPBYOP=1 XLA_SYNC_TENSORS_OPBYOP=1 "$@"
}

# Runs the command given as arguments with XLA_EXPERIMENTAL set to
# "nonzero:masked_select" in that command's environment.
function run_dynamic {
XLA_EXPERIMENTAL="nonzero:masked_select" "$@"
}

if [ "$LOGFILE" != "" ]; then
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA 2>&1 | tee $LOGFILE
run_dynamic python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA 2>&1 | tee $LOGFILE
python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
run_dynamic python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA 2>&1 | tee $LOGFILE
run_dynamic python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA 2>&1 | tee $LOGFILE
run_dynamic python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY 2>&1 | tee $LOGFILE
python3 "$CDIR/test_mp_replication.py" "$@" 2>&1 | tee $LOGFILE
else
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA
run_dynamic python3 "$CDIR/../../test/test_torch.py" "$@" -v TestDevicePrecisionXLA
python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTensorDeviceOpsXLA
python3 "$CDIR/../../test/test_indexing.py" "$@" -v TestIndexingXLA
python3 "$CDIR/../../test/test_indexing.py" "$@" -v NumpyTestsXLA
python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA
python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA
python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_dynamic python3 "$CDIR/../../test/test_nn.py" "$@" -v TestNNDeviceTypeXLA
run_dynamic python3 "$CDIR/../../test/test_type_promotion.py" "$@" -v TestTypePromotionXLA
run_dynamic python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
run_opbyop python3 "$CDIR/test_operations.py" "$@" --verbosity=$VERBOSITY
python3 "$CDIR/test_mp_replication.py" "$@"
fi
4 changes: 2 additions & 2 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -855,7 +855,7 @@ std::vector<at::Tensor> XLATensor::GetTensorsOpByOp(

std::vector<at::Tensor> XLATensor::GetTensors(std::vector<XLATensor>* tensors) {
static const bool op_by_op =
xla::sys_util::GetEnvBool("GET_TENSORS_OPBYOP", false);
xla::sys_util::GetEnvBool("XLA_GET_TENSORS_OPBYOP", false);
return op_by_op ? GetTensorsOpByOp(tensors) : GetTensorsFused(tensors);
}

Expand Down Expand Up @@ -1148,7 +1148,7 @@ void XLATensor::SyncTensorsGraph(
tensorflow::gtl::ArraySlice<const std::string> devices, bool wait,
bool sync_xla_data) {
static const bool op_by_op =
xla::sys_util::GetEnvBool("SYNC_TENSORS_OPBYOP", false);
xla::sys_util::GetEnvBool("XLA_SYNC_TENSORS_OPBYOP", false);
SyncTensorsConfig config;
config.sync_xla_data = sync_xla_data;
if (op_by_op) {
Expand Down

0 comments on commit f66bf94

Please sign in to comment.