Allow FP16 support in tensor core.
dlibenzi committed Apr 20, 2020
1 parent f410267 commit b5df579
Showing 3 changed files with 72 additions and 8 deletions.
TROUBLESHOOTING.md: 3 additions & 0 deletions
@@ -192,6 +192,9 @@ only be enabled for debugging.
* ```XLA_USE_BF16```: If set to 1, transforms all the _PyTorch_ _Float_ values into _BFloat16_
when sending to the _TPU_ device.

* ```XLA_USE_F16```: If set to 1, transforms all the _PyTorch_ _Float_ values into _Float16_
(_PyTorch_ _Half_ type) when sending to devices which support them.

* ```XLA_USE_32BIT_LONG```: If set to 1, maps _PyTorch_ _Long_ types to _XLA_ 32bit type.
On the versions of the TPU HW at the time of writing, 64bit integer computations are
expensive, so setting this flag might help. It should be verified by the user that truncating
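The flags above are plain environment variables, normally set in the shell before launching the workload, e.g. `XLA_USE_F16=1 python script.py`. As a rough, self-contained illustration of how such a boolean flag could be read and cached, here is a minimal C++ sketch. `GetEnvBoolFlag` and `UseF16Flag` are hypothetical names standing in for `xla::sys_util::GetEnvBool` and the `UseF16()` helper added below; the exact parsing rules of the real helper are not asserted here, and the variable name follows the documentation above.

```cpp
#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical stand-in for xla::sys_util::GetEnvBool(); in this sketch any
// value other than "0" (or an unset/empty variable) enables the flag.
bool GetEnvBoolFlag(const char* name, bool defval) {
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') {
    return defval;
  }
  return std::strcmp(value, "0") != 0;
}

// Mirrors the ShouldUseF16()/UseF16() pattern from the diff below: consult the
// environment once and cache the result in a function-local static.
bool UseF16Flag() {
  static bool use_fp16 = GetEnvBoolFlag("XLA_USE_F16", false);
  return use_fp16;
}

int main() {
  std::cout << "F16 casting " << (UseF16Flag() ? "enabled" : "disabled")
            << std::endl;
  return 0;
}
```

The function-local static means the environment is only inspected the first time the flag is queried, which is the same caching pattern used in `tensor_util.cpp`.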
third_party/xla_client/xrt_computation_client.cc: 2 additions & 0 deletions
@@ -1690,6 +1690,8 @@ tensorflow::DataType XrtComputationClient::XlaTypeToDataType(
return tensorflow::DT_DOUBLE;
case PrimitiveType::BF16:
return tensorflow::DT_BFLOAT16;
case PrimitiveType::F16:
return tensorflow::DT_HALF;
case PrimitiveType::C64:
return tensorflow::DT_COMPLEX64;
case PrimitiveType::C128:
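The change above extends the XLA-to-TensorFlow type mapping with the new F16 case. A compilable sketch of the same shape of mapping, using simplified stand-in enums rather than the real `xla::PrimitiveType` and `tensorflow::DataType` (which live in the XLA and TensorFlow headers), might look like this:

```cpp
#include <iostream>
#include <stdexcept>

// Simplified stand-ins for xla::PrimitiveType and tensorflow::DataType.
enum class XlaType { F16, BF16, F32, F64 };
enum class TfType { DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE };

// Same structure as XrtComputationClient::XlaTypeToDataType(): a direct
// switch, with the new F16 -> DT_HALF case next to BF16 -> DT_BFLOAT16.
TfType XlaTypeToTfType(XlaType type) {
  switch (type) {
    case XlaType::F16:
      return TfType::DT_HALF;
    case XlaType::BF16:
      return TfType::DT_BFLOAT16;
    case XlaType::F32:
      return TfType::DT_FLOAT;
    case XlaType::F64:
      return TfType::DT_DOUBLE;
  }
  throw std::invalid_argument("unsupported XLA type");
}

int main() {
  std::cout << std::boolalpha
            << (XlaTypeToTfType(XlaType::F16) == TfType::DT_HALF) << std::endl;
  return 0;
}
```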
torch_xla/csrc/tensor_util.cpp: 67 additions & 8 deletions
@@ -23,10 +23,18 @@ namespace torch_xla {
namespace {

bool ShouldUseBF16() {
bool use_fp16 = xla::sys_util::GetEnvBool("XLA_USE_BF16", false);
if (use_fp16) {
bool use_bf16 = xla::sys_util::GetEnvBool("XLA_USE_BF16", false);
if (use_bf16) {
TF_LOG(INFO) << "Using BF16 data type for floating point values";
}
return use_bf16;
}

bool ShouldUseF16() {
bool use_fp16 = xla::sys_util::GetEnvBool("XLA_USE_FP16", false);
if (use_fp16) {
TF_LOG(INFO) << "Using F16 data type for floating point values";
}
return use_fp16;
}

@@ -39,7 +47,12 @@ bool ShouldUse32BitLong() {
}

bool UseBF16() {
static bool use_fp16 = ShouldUseBF16();
static bool use_bf16 = ShouldUseBF16();
return use_bf16;
}

bool UseF16() {
static bool use_fp16 = ShouldUseF16();
return use_fp16;
}

@@ -58,6 +71,8 @@ xla::PrimitiveType XlaTypeFromTensorType(at::ScalarType scalar_type,
return xla::PrimitiveType::F32;
case at::ScalarType::BFloat16:
return xla::PrimitiveType::BF16;
case at::ScalarType::Half:
return xla::PrimitiveType::F16;
case at::ScalarType::Bool:
return xla::PrimitiveType::PRED;
case at::ScalarType::Byte:
@@ -101,6 +116,20 @@ struct Caster<tensorflow::bfloat16> {
}
};
template <>
struct Caster<at::Half> {
template <typename D>
D cast(const at::Half& value) const {
return static_cast<D>(static_cast<float>(value));
}
};
template <>
struct Caster<xla::half> {
template <typename D>
D cast(const xla::half& value) const {
return static_cast<D>(static_cast<float>(value));
}
};
template <>
struct Caster<std::complex<float>> {
template <typename D>
D cast(const std::complex<float>& value) const {
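The `Caster` specializations added above widen an `at::Half` or `xla::half` value to `float` first and only then convert to the destination type. A self-contained sketch of the same technique follows; `ToyHalf` is a toy stand-in for the real 16-bit types, kept as a float payload purely to keep the example short.

```cpp
#include <cstdint>
#include <iostream>

// Toy stand-in for at::Half / xla::half. A real half type stores 16 bits; a
// float payload keeps the sketch short while preserving the "convertible
// to/from float" interface the Caster specializations rely on.
struct ToyHalf {
  explicit ToyHalf(float f) : value(f) {}
  explicit operator float() const { return value; }
  float value;
};

// Generic caster: a plain static_cast between arithmetic types.
template <typename S>
struct Caster {
  template <typename D>
  D cast(const S& value) const {
    return static_cast<D>(value);
  }
};

// Specialization following the pattern in the diff: widen the half value to
// float first, then convert to the destination type.
template <>
struct Caster<ToyHalf> {
  template <typename D>
  D cast(const ToyHalf& value) const {
    return static_cast<D>(static_cast<float>(value));
  }
};

int main() {
  ToyHalf h(3.5f);
  std::cout << Caster<ToyHalf>().cast<double>(h) << " "
            << Caster<ToyHalf>().cast<int32_t>(h) << std::endl;  // prints "3.5 3"
  return 0;
}
```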
@@ -157,6 +186,14 @@ struct NeedCast<at::BFloat16> {
static constexpr bool value = true;
};
template <>
struct NeedCast<xla::half> {
static constexpr bool value = true;
};
template <>
struct NeedCast<at::Half> {
static constexpr bool value = true;
};
template <>
struct NeedCast<std::complex<float>> {
static constexpr bool value = true;
};
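The `NeedCast` trait added above marks types whose host and device representations cannot simply be bit-copied, so the copy path falls back to element-wise conversion. Below is a rough sketch of that compile-time dispatch using C++17 `if constexpr` and the same toy half type; the real copy path is richer (it goes through the `Caster` functors shown earlier, among other things), so this only illustrates the trait-based branch selection.

```cpp
#include <cstddef>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <vector>

struct ToyHalf {  // toy stand-in for at::Half / xla::half
  explicit ToyHalf(float f) : value(f) {}
  explicit operator float() const { return value; }
  float value;
};

// Types that cannot be bit-copied between representations opt in via NeedCast.
template <typename T>
struct NeedCast {
  static constexpr bool value = false;
};
template <>
struct NeedCast<ToyHalf> {
  static constexpr bool value = true;
};

// The trait selects, at compile time, between a raw memcpy and an
// element-by-element converting loop.
template <typename S, typename D>
void CopyData(D* dest, const S* source, size_t count) {
  if constexpr (std::is_same<S, D>::value && !NeedCast<S>::value) {
    std::memcpy(dest, source, count * sizeof(S));
  } else {
    for (size_t i = 0; i < count; ++i) {
      // A real implementation would route this through the Caster functors.
      dest[i] = static_cast<D>(static_cast<float>(source[i]));
    }
  }
}

int main() {
  std::vector<ToyHalf> src = {ToyHalf(1.0f), ToyHalf(2.5f)};
  std::vector<float> dst(src.size());
  CopyData(dst.data(), src.data(), src.size());  // takes the converting path
  std::cout << dst[0] << " " << dst[1] << std::endl;  // prints "1 2.5"
  return 0;
}
```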
@@ -371,6 +408,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
TensorToBuffer<SType, tensorflow::bfloat16>(
tensor, dest_shape, dest_buffer, dest_buffer_size, device);
break;
case xla::PrimitiveType::F16:
TensorToBuffer<SType, xla::half>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F32:
TensorToBuffer<SType, float>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
@@ -453,6 +494,10 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case at::ScalarType::Half:
TensorToBufferSType<at::Half>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case at::ScalarType::Bool:
TensorToBufferSType<bool>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
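`PopulateTensorBuffer` switches on the tensor's `at::ScalarType` to pick the source C++ type, and `TensorToBufferSType` switches on the destination `xla::PrimitiveType` to pick the destination C++ type; the new `Half` and `F16` cases above plug into those two switches. A compact sketch of that nested dispatch follows, with stand-in enums and a print statement in place of the real buffer copy; every name here is illustrative only.

```cpp
#include <cstdint>
#include <iostream>

// Simplified stand-ins for at::ScalarType and xla::PrimitiveType.
enum class ScalarType { Float, Half };
enum class PrimitiveType { F32, F16 };

struct ToyHalf { uint16_t bits; };  // 2-byte stand-in for at::Half / xla::half

// Placeholder for TensorToBuffer<SType, DType>(): the real code copies and
// converts tensor elements into the destination buffer.
template <typename SType, typename DType>
void FillBuffer() {
  std::cout << "copy " << sizeof(SType) << "-byte source elements to "
            << sizeof(DType) << "-byte destination elements\n";
}

// Inner dispatch: destination C++ type from the XLA primitive type.
template <typename SType>
void FillBufferSType(PrimitiveType dest_type) {
  switch (dest_type) {
    case PrimitiveType::F16:
      FillBuffer<SType, ToyHalf>();
      break;
    case PrimitiveType::F32:
      FillBuffer<SType, float>();
      break;
  }
}

// Outer dispatch: source C++ type from the tensor's scalar type.
void Populate(ScalarType src_type, PrimitiveType dest_type) {
  switch (src_type) {
    case ScalarType::Half:
      FillBufferSType<ToyHalf>(dest_type);
      break;
    case ScalarType::Float:
      FillBufferSType<float>(dest_type);
      break;
  }
}

int main() {
  Populate(ScalarType::Float, PrimitiveType::F16);  // 4-byte -> 2-byte
  return 0;
}
```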
@@ -550,6 +595,8 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
case at::ScalarType::BFloat16:
return XlaLiteralToTensor<SType, at::BFloat16>(literal,
dest_element_type);
case at::ScalarType::Half:
return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
case at::ScalarType::ComplexFloat:
return XlaLiteralToTensor<SType, std::complex<float>>(literal,
dest_element_type);
@@ -590,6 +637,8 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
case xla::PrimitiveType::BF16:
return XlaLiteralToTensorHelper<tensorflow::bfloat16>(literal,
dest_element_type);
case xla::PrimitiveType::F16:
return XlaLiteralToTensorHelper<xla::half>(literal, dest_element_type);
case xla::PrimitiveType::F32:
return XlaLiteralToTensorHelper<float>(literal, dest_element_type);
case xla::PrimitiveType::F64:
@@ -711,6 +760,8 @@ xla::hash_t TensorHash(const at::Tensor& tensor) {
return xla::util::DataHash(ctensor.data_ptr<double>(), size);
case at::ScalarType::BFloat16:
return xla::util::DataHash(ctensor.data_ptr<at::BFloat16>(), size);
case at::ScalarType::Half:
return xla::util::DataHash(ctensor.data_ptr<at::Half>(), size);
case at::ScalarType::ComplexFloat:
return xla::util::DataHash(ctensor.data_ptr<std::complex<float>>(), size);
case at::ScalarType::ComplexDouble:
@@ -761,9 +812,9 @@ xla::Shape CreateComputationShapeFromTensor(const at::Tensor& tensor,
at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
switch (xla_type) {
case xla::PrimitiveType::BF16:
if (!UseBF16()) {
return at::ScalarType::BFloat16;
}
return UseBF16() ? at::ScalarType::Float : at::ScalarType::BFloat16;
case xla::PrimitiveType::F16:
return UseF16() ? at::ScalarType::Float : at::ScalarType::Half;
case xla::PrimitiveType::F32:
return at::ScalarType::Float;
case xla::PrimitiveType::F64:
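The rewritten `TensorTypeFromXlaType` decides what type device data surfaces as on the way back to PyTorch: with the F16 (or BF16) flag active, F16 literals come back as `Float`, presumably because the user's tensors were `Float` on the host to begin with; without the flag, F16 maps to `at::Half`. A small sketch of that decision with stand-in enums:

```cpp
#include <cassert>

// Simplified stand-ins for at::ScalarType and xla::PrimitiveType.
enum class ScalarType { Float, Half };
enum class PrimitiveType { F32, F16 };

// Mirrors the new logic: F16 surfaces as Float when the XLA_USE_F16 flag is
// on, and as Half otherwise.
ScalarType TensorTypeFromXla(PrimitiveType xla_type, bool use_f16) {
  switch (xla_type) {
    case PrimitiveType::F16:
      return use_f16 ? ScalarType::Float : ScalarType::Half;
    case PrimitiveType::F32:
      return ScalarType::Float;
  }
  return ScalarType::Float;
}

int main() {
  assert(TensorTypeFromXla(PrimitiveType::F16, /*use_f16=*/true) ==
         ScalarType::Float);
  assert(TensorTypeFromXla(PrimitiveType::F16, /*use_f16=*/false) ==
         ScalarType::Half);
  return 0;
}
```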
@@ -800,6 +851,8 @@ xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type) {
return xla::PrimitiveType::F32;
case at::ScalarType::BFloat16:
return xla::PrimitiveType::BF16;
case at::ScalarType::Half:
return xla::PrimitiveType::F16;
case at::ScalarType::Bool:
return xla::PrimitiveType::PRED;
case at::ScalarType::Byte:
@@ -826,14 +879,18 @@ xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
Device xla_device = GetDeviceOrCurrent(device);
switch (type) {
case xla::PrimitiveType::F64:
if (UseF16()) {
return xla::PrimitiveType::F16;
}
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::F64
: xla::PrimitiveType::F32;
case xla::PrimitiveType::F32:
// When PyTorch will support native BF16 type, the global configuration
// can be replaced (or augmented) with the proper mapping.
if (UseF16()) {
return xla::PrimitiveType::F16;
}
return UseBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32;
case xla::PrimitiveType::U8:
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::U8
@@ -868,6 +925,8 @@ xla::PrimitiveType MakeXlaPrimitiveType(at::ScalarType scalar_type,
return GetDevicePrimitiveType(xla::PrimitiveType::F32, device);
case at::ScalarType::BFloat16:
return GetDevicePrimitiveType(xla::PrimitiveType::BF16, device);
case at::ScalarType::Half:
return GetDevicePrimitiveType(xla::PrimitiveType::F16, device);
case at::ScalarType::Bool:
return GetDevicePrimitiveType(xla::PrimitiveType::PRED, device);
case at::ScalarType::Byte:
