Skip to content

Commit

Permalink
[Bug] Fix axis attribute for reduction instrs
Browse files Browse the repository at this point in the history
When the axis op is constant, it should update the axis attribute.

(cherry picked from commit 4fe94fa)
  • Loading branch information
Weiming Zhao committed Sep 7, 2021
1 parent bfa366c commit 0825fd2
Show file tree
Hide file tree
Showing 21 changed files with 28 additions and 22 deletions.
10 changes: 8 additions & 2 deletions lib/transforms/type_legalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>
#include <unordered_set>

#include "halo/api/halo_data.h"
Expand Down Expand Up @@ -475,8 +476,8 @@ static void RunOnInstruction(Conv2DTransposeInst* inst) {
}
}

static void RunOnCommonReductionInstruction(Instruction* inst,
std::vector<int32_t> axis,
template <typename T>
static void RunOnCommonReductionInstruction(T* inst, std::vector<int32_t> axis,
bool keep_dims) {
const auto& input_type = inst->GetOperand(0).GetType();
if (!input_type.IsValid()) {
Expand Down Expand Up @@ -528,6 +529,11 @@ static void RunOnCommonReductionInstruction(Instruction* inst,
dt = DataType::INT32;
}

constexpr bool is_arg_inst =
std::is_same<T, ArgmaxInst>() || std::is_same<T, ArgminInst>();
if constexpr (!is_arg_inst) { // NOLINT
inst->SetAxis(axis);
}
inst->GetResultsTypes()[0] = halo::Type{dt, ret_shape};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_example_dnnl.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_random_dnnl.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_example_dnnl.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_random_dnnl.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_example_popart.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_random_popart.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_example_popart.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_random_popart.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l1_negative_axes_keep_dims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_l1_negative_axes_keep_dims_example_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l1_negative_axes_keep_dims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_l1_negative_axes_keep_dims_random_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l2_negative_axes_keep_dims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_l2_negative_axes_keep_dims_example_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l2_negative_axes_keep_dims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_l2_negative_axes_keep_dims_random_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_exp_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_log_sum_exp_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_exp_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_log_sum_exp_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_negative_axes | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_log_sum_negative_axes_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_mean_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_min_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_sum_square_negative_axes_keepdims_example | FileCheck %s
// CHECK: Result Pass
// clang-format on
// XFAIL: *

#include "test_reduce_sum_square_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"

0 comments on commit 0825fd2

Please sign in to comment.