From c8621a93b52aad3b4263618926b7853f3f86adb2 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 20 Nov 2024 11:49:45 -0800 Subject: [PATCH 1/5] Make as_const_* return a std::optional instead of a pointer To prevent use-after-free bugs --- src/AlignLoads.cpp | 2 +- src/BoundSmallAllocations.cpp | 2 +- src/Bounds.cpp | 12 +- src/CodeGen_ARM.cpp | 4 +- src/CodeGen_C.cpp | 16 +- src/CodeGen_D3D12Compute_Dev.cpp | 12 +- src/CodeGen_Hexagon.cpp | 4 +- src/CodeGen_Internal.cpp | 29 ++-- src/CodeGen_Internal.h | 2 +- src/CodeGen_LLVM.cpp | 40 ++--- src/CodeGen_Metal_Dev.cpp | 12 +- src/CodeGen_OpenCL_Dev.cpp | 2 +- src/CodeGen_PTX_Dev.cpp | 2 +- src/CodeGen_Vulkan_Dev.cpp | 7 +- src/CodeGen_WebGPU_Dev.cpp | 7 +- src/CodeGen_X86.cpp | 2 +- src/Deinterleave.cpp | 9 +- src/DerivativeUtils.cpp | 6 +- src/DistributeShifts.cpp | 4 +- src/FindIntrinsics.cpp | 14 +- src/FlattenNestedRamps.cpp | 3 +- src/FuseGPUThreadLoops.cpp | 4 +- src/HexagonOptimize.cpp | 4 +- src/IROperator.cpp | 86 +++++----- src/IROperator.h | 25 +-- src/LowerWarpShuffles.cpp | 16 +- src/Monotonic.cpp | 16 +- src/Profiling.cpp | 8 +- src/Random.cpp | 6 +- src/Simplify.cpp | 35 +--- src/Simplify_Call.cpp | 149 ++++++++---------- src/Simplify_Cast.cpp | 56 +++---- src/Simplify_Internal.h | 10 +- src/Simplify_Reinterpret.cpp | 10 +- src/Solve.cpp | 2 +- src/StageStridedLoads.cpp | 14 +- src/StorageFolding.cpp | 2 +- src/VectorizeLoops.cpp | 10 +- src/autoschedulers/adams2019/FunctionDAG.cpp | 29 ++-- src/autoschedulers/adams2019/State.cpp | 3 +- .../anderson2021/FunctionDAG.cpp | 29 ++-- .../li2018/GradientAutoscheduler.cpp | 8 +- test/correctness/bound_storage.cpp | 9 +- test/correctness/constant_expr.cpp | 6 +- test/correctness/fuse_gpu_threads.cpp | 4 +- 45 files changed, 329 insertions(+), 403 deletions(-) diff --git a/src/AlignLoads.cpp b/src/AlignLoads.cpp index ebcad5c6d156..263c6b4844de 100644 --- a/src/AlignLoads.cpp +++ b/src/AlignLoads.cpp @@ -71,7 +71,7 @@ class AlignLoads : public IRMutator { Expr index = mutate(op->index); const Ramp *ramp = index.as(); - const int64_t *const_stride = ramp ? as_const_int(ramp->stride) : nullptr; + auto const_stride = ramp ? as_const_int(ramp->stride) : std::nullopt; if (!ramp || !const_stride) { // We can't handle indirect loads, or loads with // non-constant strides. diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index f83a13d99614..a4fe341df603 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -125,7 +125,7 @@ class BoundSmallAllocations : public IRMutator { << "Try storing on the heap or stack instead."; } - const int64_t *size_ptr = bound.defined() ? as_const_int(bound) : nullptr; + auto size_ptr = bound.defined() ? as_const_int(bound) : std::nullopt; int64_t size = size_ptr ? *size_ptr : 0; if (size_ptr && size == 0 && !op->new_expr.defined()) { diff --git a/src/Bounds.cpp b/src/Bounds.cpp index fe72e6bedfdf..a158eb179216 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -355,18 +355,18 @@ class Bounds : public IRVisitor { // each other; however, if the bounds can be simplified to // constants, they might fit regardless of types. 
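Note on the motivation above: the old API returned a pointer that aliased the
IntImm/UIntImm/FloatImm node owned by the Expr it was called on, so the pointer
could silently outlive a temporary Expr. A minimal sketch of the hazard this
change removes (`e` and `use` are illustrative, not code from this patch):

    // Old API: p points into the temporary Expr returned by simplify(),
    // which is destroyed at the end of the full expression.
    const int64_t *p = as_const_int(simplify(e));   // dangling
    // New API: the constant is copied out, so nothing can dangle.
    std::optional<int64_t> v = as_const_int(simplify(e));
    if (v) {
        use(*v);  // *v is a plain int64_t value
    }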
a = simplify(a); - const auto *umin = as_const_uint(a.min); - const auto *umax = as_const_uint(a.max); + auto umin = as_const_uint(a.min); + auto umax = as_const_uint(a.max); if (umin && umax && to.can_represent(*umin) && to.can_represent(*umax)) { could_overflow = false; } else { - const auto *imin = as_const_int(a.min); - const auto *imax = as_const_int(a.max); + auto imin = as_const_int(a.min); + auto imax = as_const_int(a.max); if (imin && imax && to.can_represent(*imin) && to.can_represent(*imax)) { could_overflow = false; } else { - const auto *fmin = as_const_float(a.min); - const auto *fmax = as_const_float(a.max); + auto fmin = as_const_float(a.min); + auto fmax = as_const_float(a.max); if (fmin && fmax && to.can_represent(*fmin) && to.can_represent(*fmax)) { could_overflow = false; } diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 4cd05e3c7b5f..914f5b9d71f2 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1178,7 +1178,7 @@ void CodeGen_ARM::visit(const Cast *op) { if (expr_match(pattern.pattern, op, matches)) { if (pattern.intrin.find("shift_right_narrow") != string::npos) { // The shift_right_narrow patterns need the shift to be constant in [1, output_bits]. - const uint64_t *const_b = as_const_uint(matches[1]); + auto const_b = as_const_uint(matches[1]); if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) { continue; } @@ -2015,7 +2015,7 @@ void CodeGen_ARM::visit(const Call *op) { if (expr_match(pattern.pattern, op, matches)) { if (pattern.intrin.find("shift_right_narrow") != string::npos) { // The shift_right_narrow patterns need the shift to be constant in [1, output_bits]. - const uint64_t *const_b = as_const_uint(matches[1]); + auto const_b = as_const_uint(matches[1]); if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) { continue; } diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index b8dbf173d43e..4a4b12ca070b 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1325,9 +1325,8 @@ void CodeGen_C::visit(const Mul *op) { } void CodeGen_C::visit(const Div *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { - visit_binop(op->type, op->a, make_const(op->a.type(), bits), ">>"); + if (auto bits = is_const_power_of_two_integer(op->b)) { + visit_binop(op->type, op->a, make_const(op->a.type(), *bits), ">>"); } else if (op->type.is_int()) { print_expr(lower_euclidean_div(op->a, op->b)); } else { @@ -1336,9 +1335,8 @@ void CodeGen_C::visit(const Div *op) { } void CodeGen_C::visit(const Mod *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { - visit_binop(op->type, op->a, make_const(op->a.type(), (1 << bits) - 1), "&"); + if (auto bits = is_const_power_of_two_integer(op->b)) { + visit_binop(op->type, op->a, make_const(op->a.type(), ((uint64_t)1 << *bits) - 1), "&"); } else if (op->type.is_int()) { print_expr(lower_euclidean_mod(op->a, op->b)); } else if (op->type.is_float()) { @@ -1613,7 +1611,7 @@ void CodeGen_C::visit(const Call *op) { } else if (op->is_intrinsic(Call::alloca)) { internal_assert(op->args.size() == 1); internal_assert(op->type.is_handle()); - const int64_t *sz = as_const_int(op->args[0]); + auto sz = as_const_int(op->args[0]); if (op->type == type_of() && Call::as_intrinsic(op->args[0], {Call::size_of_halide_buffer_t})) { stream << get_indent(); @@ -1752,8 +1750,8 @@ void CodeGen_C::visit(const Call *op) { internal_assert(op->args.size() == 3); std::string struct_instance = print_expr(op->args[0]); std::string struct_prototype = print_expr(op->args[1]); - const 
int64_t *index = as_const_int(op->args[2]); - internal_assert(index != nullptr); + auto index = as_const_int(op->args[2]); + internal_assert(index); rhs << "((decltype(" << struct_prototype << "))" << struct_instance << ")->f_" << *index; } else if (op->is_intrinsic(Call::get_user_context)) { diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index 15ca77ab56a0..584391dff94a 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -268,10 +268,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Min *op) { } void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Div *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { + if (auto bits = is_const_power_of_two_integer(op->b)) { ostringstream oss; - oss << print_expr(op->a) << " >> " << bits; + oss << print_expr(op->a) << " >> " << *bits; print_assignment(op->type, oss.str()); } else if (op->type.is_int()) { print_expr(lower_euclidean_div(op->a, op->b)); @@ -281,10 +280,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Div *op) { } void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Mod *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { + if (auto bits = is_const_power_of_two_integer(op->b)) { ostringstream oss; - oss << print_expr(op->a) << " & " << ((1 << bits) - 1); + oss << print_expr(op->a) << " & " << (((uint64_t)1 << *bits) - 1); print_assignment(op->type, oss.str()); } else if (op->type.is_int()) { print_expr(lower_euclidean_mod(op->a, op->b)); @@ -349,7 +347,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) { if (op->is_intrinsic(Call::gpu_thread_barrier)) { internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; auto fence_type = *fence_type_ptr; diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index b98a4390a28b..83ce42bf5b51 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -1932,8 +1932,8 @@ void CodeGen_Hexagon::visit(const Call *op) { return; } else if (op->is_intrinsic(Call::dynamic_shuffle)) { internal_assert(op->args.size() == 4); - const int64_t *min_index = as_const_int(op->args[2]); - const int64_t *max_index = as_const_int(op->args[3]); + auto min_index = as_const_int(op->args[2]); + auto max_index = as_const_int(op->args[3]); internal_assert(min_index && max_index); Value *lut = codegen(op->args[0]); Value *idx = codegen(op->args[1]); diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index dc88960dde7f..56df50ce371f 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -116,25 +116,23 @@ bool can_allocation_fit_on_stack(int64_t size) { Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) { // Detect if it's a small int division internal_assert(a.type() == b.type()); - const int64_t *const_int_divisor = as_const_int(b); - const uint64_t *const_uint_divisor = as_const_uint(b); + auto const_int_divisor = as_const_int(b); + auto const_uint_divisor = as_const_uint(b); Type t = a.type(); internal_assert(!t.is_float()) << "lower_int_uint_div is not meant to handle floating-point case.\n"; - int shift_amount; - if (is_const_power_of_two_integer(b, &shift_amount) && - (t.is_int() || 
t.is_uint())) { + if (auto shift_amount = is_const_power_of_two_integer(b)) { if (round_to_zero) { Expr result = a; // Normally a right-shift isn't right for division rounding to // zero. It does the wrong thing for negative values. Add a fudge so // that a right-shift becomes correct. result += (result >> (t.bits() - 1)) & (b - 1); - return result >> shift_amount; + return result >> *shift_amount; } else { - return a >> make_const(UInt(a.type().bits()), shift_amount); + return a >> make_const(UInt(a.type().bits()), *shift_amount); } } else if (const_int_divisor && t.is_int() && @@ -262,15 +260,14 @@ Expr lower_int_uint_div(const Expr &a, const Expr &b, bool round_to_zero) { Expr lower_int_uint_mod(const Expr &a, const Expr &b) { // Detect if it's a small int modulus - const int64_t *const_int_divisor = as_const_int(b); - const uint64_t *const_uint_divisor = as_const_uint(b); + auto const_int_divisor = as_const_int(b); + auto const_uint_divisor = as_const_uint(b); Type t = a.type(); internal_assert(!t.is_float()) << "lower_int_uint_div is not meant to handle floating-point case.\n"; - int bits; - if (is_const_power_of_two_integer(b, &bits)) { + if (is_const_power_of_two_integer(b)) { return a & simplify(b - 1); } else if (const_int_divisor && t.is_int() && @@ -294,7 +291,7 @@ Expr lower_int_uint_mod(const Expr &a, const Expr &b) { namespace { std::pair unsigned_long_div_mod_round_to_zero(Expr &num, const Expr &den, - const uint64_t *upper_bound) { + std::optional upper_bound) { internal_assert(num.type() == den.type()); internal_assert(num.type().is_uint()); Type ty = num.type(); @@ -333,7 +330,7 @@ std::pair unsigned_long_div_mod_round_to_zero(Expr &num, const Expr } // namespace std::pair long_div_mod_round_to_zero(const Expr &num, const Expr &den, - const uint64_t *max_abs) { + std::optional max_abs) { debug(1) << "Using long div: (num: " << num << "); (den: " << den << ")\n"; internal_assert(num.type() == den.type()); Expr abs_num = (num.type().is_int()) ? abs(num) : num; @@ -476,8 +473,7 @@ Expr lower_euclidean_mod(Expr a, Expr b) { Expr lower_signed_shift_left(const Expr &a, const Expr &b) { internal_assert(b.type().is_int()); - const int64_t *const_int_b = as_const_int(b); - if (const_int_b) { + if (auto const_int_b = as_const_int(b)) { Expr val; const uint64_t b_unsigned = std::abs(*const_int_b); if (*const_int_b >= 0) { @@ -497,8 +493,7 @@ Expr lower_signed_shift_left(const Expr &a, const Expr &b) { Expr lower_signed_shift_right(const Expr &a, const Expr &b) { internal_assert(b.type().is_int()); - const int64_t *const_int_b = as_const_int(b); - if (const_int_b) { + if (auto const_int_b = as_const_int(b)) { Expr val; const uint64_t b_unsigned = std::abs(*const_int_b); if (*const_int_b >= 0) { diff --git a/src/CodeGen_Internal.h b/src/CodeGen_Internal.h index 3e48281aa05e..144c72b2c7c8 100644 --- a/src/CodeGen_Internal.h +++ b/src/CodeGen_Internal.h @@ -53,7 +53,7 @@ bool can_allocation_fit_on_stack(int64_t size); * max_abs is the maximum absolute value of (a/b). * Returns the pair {div_round_to_zero, mod_round_to_zero}. */ std::pair long_div_mod_round_to_zero(const Expr &a, const Expr &b, - const uint64_t *max_abs = nullptr); + std::optional max_abs = std::nullopt); /** Given a Halide Euclidean division/mod operation, do constant optimizations * and possibly call lower_euclidean_div/lower_euclidean_mod if necessary. 
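Note: the round-to-zero fudge in lower_int_uint_div above has a compact worked
instance. A standalone sketch for int8_t with divisor 4 (illustrative code, not
part of this patch):

    #include <cassert>
    #include <cstdint>

    // An arithmetic right shift rounds toward negative infinity (-7 >> 2 == -2),
    // but division rounds toward zero (-7 / 4 == -1). Adding
    // (a >> (bits - 1)) & (b - 1) adds b - 1 only when a is negative, which
    // biases the shift into round-to-zero behavior.
    int8_t div4_round_to_zero(int8_t a) {
        a += (a >> 7) & (4 - 1);
        return a >> 2;
    }

    int main() {
        assert(div4_round_to_zero(-7) == -1);  // a plain shift would give -2
        assert(div4_round_to_zero(-8) == -2);
        assert(div4_round_to_zero(7) == 1);
    }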
diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index f793a9110767..a94cb2380050 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1928,7 +1928,7 @@ Value *CodeGen_LLVM::codegen_buffer_pointer(Value *base_address, Halide::Type ty // aliasing analysis, especially for backends that do address // computation in 32 bits but use 64-bit pointers. if (const Add *add = index.as()) { - if (const int64_t *offset = as_const_int(add->b)) { + if (auto offset = as_const_int(add->b)) { Value *base = codegen_buffer_pointer(base_address, type, add->a); Value *off = codegen(make_const(Int(8 * d.getPointerSize()), *offset)); return CreateInBoundsGEP(builder.get(), llvm_type_of(type), base, off); @@ -1988,8 +1988,8 @@ void CodeGen_LLVM::add_tbaa_metadata(llvm::Instruction *inst, string buffer, con if (index.defined()) { if (const Ramp *ramp = index.as()) { - const int64_t *pstride = as_const_int(ramp->stride); - const int64_t *pbase = as_const_int(ramp->base); + auto pstride = as_const_int(ramp->stride); + auto pbase = as_const_int(ramp->base); if (pstride && pbase) { // We want to find the smallest aligned width and offset // that contains this ramp. @@ -2005,7 +2005,7 @@ void CodeGen_LLVM::add_tbaa_metadata(llvm::Instruction *inst, string buffer, con constant_index = true; } } else { - const int64_t *pbase = as_const_int(index); + auto pbase = as_const_int(index); if (pbase) { base = *pbase; constant_index = true; @@ -2915,7 +2915,7 @@ void CodeGen_LLVM::visit(const Call *op) { llvm::Value *struct_instance = codegen(op->args[0]); llvm::Value *struct_prototype = codegen(op->args[1]); llvm::Value *typed_struct_instance = builder->CreatePointerCast(struct_instance, struct_prototype->getType()); - const int64_t *index = as_const_int(op->args[2]); + auto index = as_const_int(op->args[2]); // make_struct can use a fixed-size struct, an array type, or a scalar llvm::Type *pointee_type; @@ -2928,7 +2928,7 @@ void CodeGen_LLVM::visit(const Call *op) { llvm::StructType *struct_type = llvm::dyn_cast(pointee_type); llvm::Type *array_type = llvm::dyn_cast(pointee_type); if (struct_type || array_type) { - internal_assert(index != nullptr); + internal_assert(index); llvm::Value *gep = CreateInBoundsGEP(builder.get(), pointee_type, typed_struct_instance, {ConstantInt::get(i32_t, 0), ConstantInt::get(i32_t, (int)*index)}); @@ -2936,7 +2936,7 @@ void CodeGen_LLVM::visit(const Call *op) { value = builder->CreateLoad(result_type, gep); } else { // The struct is actually just a scalar - internal_assert(index == nullptr || *index == 0); + internal_assert(!index || *index == 0); value = builder->CreateLoad(pointee_type, typed_struct_instance); } } else if (op->is_intrinsic(Call::get_user_context)) { @@ -3100,7 +3100,7 @@ void CodeGen_LLVM::visit(const Call *op) { // restrictions if we recognize the most common types we // expect to get alloca'd. 
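Note: several hunks in this patch use the form `cond ? as_const_int(x) :
std::nullopt` (e.g. AlignLoads, BoundSmallAllocations, and the IfThenElse case
below). This type-checks because std::nullopt_t converts implicitly to any
std::optional<T>, so both arms of the conditional share the optional type
returned by as_const_int. A general C++ sketch, not patch code:

    const Ramp *ramp = index.as<Ramp>();
    // Common type of the two arms is std::optional<int64_t>:
    std::optional<int64_t> stride =
        ramp ? as_const_int(ramp->stride) : std::nullopt;
    if (stride && *stride > 1) {
        // constant, non-unit stride
    }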
const Call *call = op->args[0].as(); - const int64_t *sz = as_const_int(op->args[0]); + auto sz = as_const_int(op->args[0]); if (op->type == type_of() && call && call->is_intrinsic(Call::size_of_halide_buffer_t)) { value = create_alloca_at_entry(halide_buffer_t_type, 1); @@ -3109,7 +3109,7 @@ void CodeGen_LLVM::visit(const Call *op) { sz && *sz == 16) { value = create_alloca_at_entry(semaphore_t_type, 1); } else { - internal_assert(sz != nullptr); + internal_assert(sz); if (op->type == type_of()) { value = create_alloca_at_entry(dimension_t_type, *sz / sizeof(halide_dimension_t)); } else { @@ -3828,10 +3828,10 @@ void CodeGen_LLVM::visit(const Store *op) { int store_lanes = value_type.lanes(); int native_lanes = maximum_vector_bits() / value_type.bits(); - Expr base = (ramp != nullptr) ? ramp->base : 0; - Expr stride = (ramp != nullptr) ? ramp->stride : 0; - Value *stride_val = (!is_dense && ramp != nullptr) ? codegen(stride) : nullptr; - Value *index = (ramp == nullptr) ? codegen(op->index) : nullptr; + Expr base = ramp ? ramp->base : 0; + Expr stride = ramp ? ramp->stride : 0; + Value *stride_val = (!is_dense && ramp) ? codegen(stride) : nullptr; + Value *index = !ramp ? codegen(op->index) : nullptr; for (int i = 0; i < store_lanes; i += native_lanes) { int slice_lanes = std::min(native_lanes, store_lanes - i); @@ -3849,7 +3849,7 @@ void CodeGen_LLVM::visit(const Store *op) { StoreInst *store = builder->CreateAlignedStore(slice_val, vec_ptr, llvm::Align(alignment)); annotate_store(store, slice_index); } - } else if (ramp != nullptr) { + } else if (ramp) { if (get_target().bits == 64 && !stride_val->getType()->isIntegerTy(64)) { stride_val = builder->CreateIntCast(stride_val, i64_t, true); } @@ -3996,7 +3996,7 @@ void CodeGen_LLVM::visit(const IfThenElse *op) { vector rhs; for (auto &block : blocks) { const EQ *eq = block.first.as(); - const int64_t *r = eq ? as_const_int(eq->b) : nullptr; + auto r = eq ? 
as_const_int(eq->b) : std::nullopt; if (eq && r && Int(32).can_represent(*r) && @@ -5190,12 +5190,12 @@ llvm::Value *CodeGen_LLVM::fixed_to_scalable_vector_type(llvm::Value *fixed_arg) internal_assert(effective_vscale != 0); internal_assert(isa(fixed_arg->getType())); const llvm::FixedVectorType *fixed_type = cast(fixed_arg->getType()); - internal_assert(fixed_type != nullptr); + internal_assert(fixed_type); auto lanes = fixed_type->getNumElements(); llvm::ScalableVectorType *scalable_type = cast(get_vector_type(fixed_type->getElementType(), lanes / effective_vscale, VectorTypeConstraint::VScale)); - internal_assert(fixed_type != nullptr); + internal_assert(fixed_type); internal_assert(fixed_type->getElementType() == scalable_type->getElementType()); internal_assert(lanes == (scalable_type->getMinNumElements() * effective_vscale)); @@ -5227,11 +5227,11 @@ llvm::Value *CodeGen_LLVM::scalable_to_fixed_vector_type(llvm::Value *scalable_a internal_assert(effective_vscale != 0); internal_assert(isa(scalable_arg->getType())); const llvm::ScalableVectorType *scalable_type = cast(scalable_arg->getType()); - internal_assert(scalable_type != nullptr); + internal_assert(scalable_type); llvm::FixedVectorType *fixed_type = cast(get_vector_type(scalable_type->getElementType(), scalable_type->getMinNumElements() * effective_vscale, VectorTypeConstraint::Fixed)); - internal_assert(fixed_type != nullptr); + internal_assert(fixed_type); internal_assert(fixed_type->getElementType() == scalable_type->getElementType()); internal_assert(fixed_type->getNumElements() == (scalable_type->getMinNumElements() * effective_vscale)); @@ -5444,7 +5444,7 @@ bool CodeGen_LLVM::try_vector_predication_intrinsic(const std::string &name, VPR if (!std::holds_alternative(mask)) { if (std::holds_alternative(mask)) { - internal_assert(base_vector_type != nullptr) << "Requested all enabled mask without any vector type to use for type/length.\n"; + internal_assert(base_vector_type) << "Requested all enabled mask without any vector type to use for type/length.\n"; llvm::ElementCount llvm_vector_ec; if (is_scalable) { const auto *vt = cast(base_vector_type); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index 194bfdc3e5dd..cdb3fa26bde7 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -241,10 +241,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const VectorReduce *op) { } void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Div *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { + if (auto bits = is_const_power_of_two_integer(op->b)) { ostringstream oss; - oss << print_expr(op->a) << " >> " << bits; + oss << print_expr(op->a) << " >> " << *bits; print_assignment(op->type, oss.str()); } else if (op->type.is_int()) { print_expr(lower_euclidean_div(op->a, op->b)); @@ -254,10 +253,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Div *op) { } void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Mod *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { + if (auto bits = is_const_power_of_two_integer(op->b)) { ostringstream oss; - oss << print_expr(op->a) << " & " << ((1 << bits) - 1); + oss << print_expr(op->a) << " & " << (((uint64_t)1 << *bits) - 1); print_assignment(op->type, oss.str()); } else if (op->type.is_int()) { print_expr(lower_euclidean_mod(op->a, op->b)); @@ -312,7 +310,7 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Call *op) { if (op->is_intrinsic(Call::gpu_thread_barrier)) { internal_assert(op->args.size() == 1) << 
"gpu_thread_barrier() intrinsic must specify memory fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; auto fence_type = *fence_type_ptr; diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index cc1fb060445e..a6dfc73da7c3 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -317,7 +317,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Call *op) { } else if (op->is_intrinsic(Call::gpu_thread_barrier)) { internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; auto fence_type = *fence_type_ptr; diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 5ad9ccf0fe5c..87a407b90eee 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -257,7 +257,7 @@ void CodeGen_PTX_Dev::visit(const Call *op) { // arguments internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; llvm::Function *barrier0 = module->getFunction("llvm.nvvm.barrier0"); diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 80b04ef05ff4..559f0bb932f4 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -755,12 +755,11 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Div *op) { void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Mod *op) { debug(2) << "CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(Mod): " << op->type << " ((" << op->a << ") % (" << op->b << "))\n"; - int bits = 0; - if (is_const_power_of_two_integer(op->b, &bits) && op->type.is_int_or_uint()) { + if (auto bits = is_const_power_of_two_integer(op->b)) { op->a.accept(this); SpvId src_a_id = builder.current_id(); - int bitwise_value = ((1 << bits) - 1); + int bitwise_value = ((1 << *bits) - 1); Expr expr = make_const(op->type, bitwise_value); expr.accept(this); SpvId src_b_id = builder.current_id(); @@ -1020,7 +1019,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Call *op) { if (op->is_intrinsic(Call::gpu_thread_barrier)) { internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; auto fence_type = *fence_type_ptr; diff --git a/src/CodeGen_WebGPU_Dev.cpp b/src/CodeGen_WebGPU_Dev.cpp index 4fc5346ac13d..37cb1cd2af9c 100644 --- a/src/CodeGen_WebGPU_Dev.cpp +++ b/src/CodeGen_WebGPU_Dev.cpp @@ -500,7 +500,7 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Call *op) { internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify fence type.\n"; - const auto *fence_type_ptr = as_const_int(op->args[0]); + auto fence_type_ptr = as_const_int(op->args[0]); internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n"; auto fence_type = *fence_type_ptr; @@ 
-549,11 +549,10 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Cast *op) { } void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Div *op) { - int bits; - if (is_const_power_of_two_integer(op->b, &bits)) { + if (auto bits = is_const_power_of_two_integer(op->b)) { // WGSL requires the RHS of a shift to be unsigned. Type uint_type = op->a.type().with_code(halide_type_uint); - visit_binop(op->type, op->a, make_const(uint_type, bits), ">>"); + visit_binop(op->type, op->a, make_const(uint_type, *bits), ">>"); } else { CodeGen_GPU_C::visit(op); } diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index a31ba15cca4b..90609e1477c6 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -594,7 +594,7 @@ void CodeGen_X86::visit(const Call *op) { op->type.element_of() == Int(16)) && op->is_intrinsic(Call::mul_shift_right)) { internal_assert(op->args.size() == 3); - const uint64_t *shift = as_const_uint(op->args[2]); + auto shift = as_const_uint(op->args[2]); if (shift && *shift < 16 && *shift >= 8) { Type narrow = op->type.with_bits(8); Expr narrow_a = lossless_cast(narrow, op->args[0]); diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index cf8652395bb7..fbb8077b0602 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -658,14 +658,14 @@ class Interleaver : public IRMutator { return Stmt(); } - const int64_t *stride_ptr = as_const_int(r0->stride); + auto optional_stride = as_const_int(r0->stride); // The stride isn't a constant or is <= 1 - if (!stride_ptr || *stride_ptr <= 1) { + if (!optional_stride || *optional_stride <= 1) { return Stmt(); } - const int64_t stride = *stride_ptr; + const int64_t stride = *optional_stride; const int lanes = r0->lanes; const int64_t expected_stores = stride; @@ -715,8 +715,7 @@ class Interleaver : public IRMutator { return Stmt(); } - Expr diff = simplify(ri->base - r0->base); - const int64_t *offs = as_const_int(diff); + auto offs = as_const_int(simplify(ri->base - r0->base)); // Difference between bases is not constant. if (!offs) { diff --git a/src/DerivativeUtils.cpp b/src/DerivativeUtils.cpp index 86f5902017ff..a9f51b4a4070 100644 --- a/src/DerivativeUtils.cpp +++ b/src/DerivativeUtils.cpp @@ -381,13 +381,13 @@ pair solve_inverse(Expr expr, Expr rmax = simplify(interval.max); Expr rextent = simplify(rmax - rmin + 1); - const int64_t *extent_int = as_const_int(rextent); - if (extent_int == nullptr) { + auto extent_int = as_const_int(rextent); + if (!extent_int) { return {false, Expr()}; } // For some reason interval.is_single_point() doesn't work - if (extent_int != nullptr && *extent_int == 1) { + if (extent_int && *extent_int == 1) { return {true, rmin}; } diff --git a/src/DistributeShifts.cpp b/src/DistributeShifts.cpp index 5d30d8b7b9c4..4d053e7d8dfe 100644 --- a/src/DistributeShifts.cpp +++ b/src/DistributeShifts.cpp @@ -121,7 +121,7 @@ class DistributeShiftsAsMuls : public IRMutator { Expr distribute_shift(const Call *op) { if (op->is_intrinsic(Call::shift_left)) { - if (const uint64_t *const_b = as_const_uint(op->args[1])) { + if (auto const_b = as_const_uint(op->args[1])) { Expr a = op->args[0]; // Only rewrite widening shifts. 
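Note: the power-of-two rewrites above all lean on the new
is_const_power_of_two_integer, whose implementation (in the IROperator.cpp hunk
later in this patch) replaces the old bit-counting loop with a single-set-bit
test plus count-trailing-zeros. A standalone sketch of that logic
(__builtin_ctzll stands in for Halide's ctz64):

    #include <cstdint>
    #include <optional>

    // A power of two has exactly one set bit, so v & (v - 1) == 0 for v != 0.
    // Its log base two is the number of trailing zero bits, e.g. 16 -> 4.
    std::optional<int> log2_if_pow2(uint64_t v) {
        if (v && (v & (v - 1)) == 0) {
            return __builtin_ctzll(v);  // GCC/Clang builtin
        }
        return std::nullopt;
    }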
const Cast *cast_a = a.as(); @@ -133,7 +133,7 @@ class DistributeShiftsAsMuls : public IRMutator { } } } else if (op->is_intrinsic(Call::widening_shift_left)) { - if (const uint64_t *const_b = as_const_uint(op->args[1])) { + if (auto const_b = as_const_uint(op->args[1])) { const uint64_t const_m = 1ull << *const_b; Expr b = make_const(op->type, const_m); Expr a = Cast::make(op->type, op->args[0]); diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index b72122460706..f3fe0f470f8d 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -385,11 +385,10 @@ class FindIntrinsics : public IRMutator { // Rewrite multiplies to shifts if possible. if (op->type.is_int() || op->type.is_uint()) { - int pow2 = 0; - if (is_const_power_of_two_integer(a, &pow2)) { - return mutate(b << cast(UInt(b.type().bits()), pow2)); - } else if (is_const_power_of_two_integer(b, &pow2)) { - return mutate(a << cast(UInt(a.type().bits()), pow2)); + if (auto pow2 = is_const_power_of_two_integer(a)) { + return mutate(b << cast(UInt(b.type().bits()), *pow2)); + } else if (auto pow2 = is_const_power_of_two_integer(b)) { + return mutate(a << cast(UInt(a.type().bits()), *pow2)); } } @@ -467,9 +466,8 @@ class FindIntrinsics : public IRMutator { Expr a = mutate(op->a); Expr b = mutate(op->b); - int shift_amount; - if (is_const_power_of_two_integer(b, &shift_amount) && op->type.is_int_or_uint()) { - return mutate(a >> make_const(UInt(a.type().bits()), shift_amount)); + if (auto shift_amount = is_const_power_of_two_integer(b)) { + return mutate(a >> make_const(UInt(a.type().bits()), *shift_amount)); } if (a.same_as(op->a) && b.same_as(op->b)) { diff --git a/src/FlattenNestedRamps.cpp b/src/FlattenNestedRamps.cpp index d44a89278b26..e50e8b3b5535 100644 --- a/src/FlattenNestedRamps.cpp +++ b/src/FlattenNestedRamps.cpp @@ -70,8 +70,7 @@ class FlattenRamps : public IRMutator { int max_constant_offset = 0; for (Expr &idx : indices) { idx = simplify(common_subexpression_elimination(idx - min_lane)); - const int64_t *i = as_const_int(idx); - if (i) { + if (auto i = as_const_int(idx)) { const_indices.push_back((int)(*i)); max_constant_offset = std::max((int)(*i), max_constant_offset); } else { diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 851a60c8fef8..3464a91ddb58 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -654,8 +654,8 @@ class ExtractSharedAndHeapAllocations : public IRMutator { const auto &candidate_group = mem_allocs[free_spaces[i]]; Expr size = alloc_size * alloc.type.bytes(); Expr dist = candidate_group.max_size * candidate_group.widest_type.bytes() - size; - const int64_t *current_diff = as_const_int(simplify(dist)); - internal_assert(current_diff != nullptr); + auto current_diff = as_const_int(simplify(dist)); + internal_assert(current_diff); int64_t abs_diff = std::abs(*current_diff); if ((free_idx == -1) || (abs_diff < diff)) { diff = abs_diff; diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 7414f0fd7225..35a1ef780fff 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -148,7 +148,7 @@ Expr as_mul(const Expr &a) { } else if (const Call *wm = Call::as_intrinsic(a, {Call::widening_mul})) { return simplify(Mul::make(cast(wm->type, wm->args[0]), cast(wm->type, wm->args[1]))); } else if (const Call *s = Call::as_intrinsic(a, {Call::shift_left, Call::widening_shift_left})) { - const uint64_t *log2_b = as_const_uint(s->args[1]); + auto log2_b = as_const_uint(s->args[1]); if (log2_b) { Expr b = make_one(s->type) << 
cast(UInt(s->type.bits()), (int)*log2_b); return simplify(Mul::make(cast(s->type, s->args[0]), b)); @@ -1076,7 +1076,7 @@ class OptimizePatterns : public IRMutator { // Run bounds analysis to estimate the range of result. Expr abs_result = op->type.is_int() ? abs(a / b) : a / b; Expr extent_upper = find_constant_bound(abs_result, Direction::Upper, bounds); - const uint64_t *upper_bound = as_const_uint(extent_upper); + auto upper_bound = as_const_uint(extent_upper); a = mutate(a); b = mutate(b); std::pair div_mod = long_div_mod_round_to_zero(a, b, upper_bound); diff --git a/src/IROperator.cpp b/src/IROperator.cpp index b090cd1aa1fb..68f9aa2d2747 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -179,79 +179,70 @@ bool is_pure(const Expr &e) { return pure.result; } -const int64_t *as_const_int(const Expr &e) { +std::optional as_const_int(const Expr &e) { if (!e.defined()) { - return nullptr; + return {}; } else if (const Broadcast *b = e.as()) { return as_const_int(b->value); } else if (const IntImm *i = e.as()) { - return &(i->value); + return i->value; } else { - return nullptr; + return {}; } } -const uint64_t *as_const_uint(const Expr &e) { +std::optional as_const_uint(const Expr &e) { if (!e.defined()) { - return nullptr; + return {}; } else if (const Broadcast *b = e.as()) { return as_const_uint(b->value); } else if (const UIntImm *i = e.as()) { - return &(i->value); + return i->value; } else { - return nullptr; + return {}; } } -const double *as_const_float(const Expr &e) { +std::optional as_const_float(const Expr &e) { if (!e.defined()) { - return nullptr; + return {}; } else if (const Broadcast *b = e.as()) { return as_const_float(b->value); } else if (const FloatImm *f = e.as()) { - return &(f->value); + return f->value; } else { - return nullptr; + return {}; } } -bool is_const_power_of_two_integer(const Expr &e, int *bits) { +std::optional is_const_power_of_two_integer(const Expr &e) { if (!(e.type().is_int() || e.type().is_uint())) { - return false; + return {}; } - const Broadcast *b = e.as(); - if (b) { - return is_const_power_of_two_integer(b->value, bits); - } - - const Cast *c = e.as(); - if (c) { - return is_const_power_of_two_integer(c->value, bits); - } - - uint64_t val = 0; - - if (const int64_t *i = as_const_int(e)) { - if (*i < 0) { - return false; - } - val = (uint64_t)(*i); - } else if (const uint64_t *u = as_const_uint(e)) { - val = *u; + if (const Broadcast *b = e.as()) { + return is_const_power_of_two_integer(b->value); + } else if (const Cast *c = e.as()) { + return is_const_power_of_two_integer(c->value); + } else if (auto i = as_const_int(e)) { + return is_const_power_of_two_integer(*i); + } else if (auto u = as_const_uint(e)) { + return is_const_power_of_two_integer(*u); + } else { + return {}; } +} +std::optional is_const_power_of_two_integer(uint64_t val) { if (val && ((val & (val - 1)) == 0)) { - *bits = 0; - for (; val; val >>= 1) { - if (val == 1) { - return true; - } - (*bits)++; - } + return ctz64(val); + } else { + return {}; } +} - return false; +std::optional is_const_power_of_two_integer(int64_t val) { + return val < 0 ? 
std::nullopt : is_const_power_of_two_integer((uint64_t)val);
 }
 
 bool is_positive_const(const Expr &e) {
@@ -2001,13 +1992,13 @@ Expr cast(Type t, Expr a) {
     }
 
     // Fold constants early
-    if (const int64_t *i = as_const_int(a)) {
+    if (auto i = as_const_int(a)) {
         return make_const(t, *i);
     }
-    if (const uint64_t *u = as_const_uint(a)) {
+    if (auto u = as_const_uint(a)) {
         return make_const(t, *u);
     }
-    if (const double *f = as_const_float(a)) {
+    if (auto f = as_const_float(a)) {
         return make_const(t, *f);
     }
 
@@ -2263,7 +2254,7 @@ Expr log(Expr x) {
 
 Expr pow(Expr x, Expr y) {
     user_assert(x.defined() && y.defined()) << "pow of undefined Expr\n";
-    if (const int64_t *i = as_const_int(y)) {
+    if (auto i = as_const_int(y)) {
         return raise_to_integer_power(std::move(x), *i);
     }
 
@@ -2287,7 +2278,7 @@ Expr erf(const Expr &x) {
 }
 
 Expr fast_pow(Expr x, Expr y) {
-    if (const int64_t *i = as_const_int(y)) {
+    if (auto i = as_const_int(y)) {
         return raise_to_integer_power(std::move(x), *i);
     }
 
@@ -2558,8 +2549,7 @@ Expr lerp(Expr zero_val, Expr one_val, Expr weight) {
     // Compilation error for constant weight that is out of range for integer use
     // as this seems like an easy to catch gotcha.
     if (!zero_val.type().is_float()) {
-        const double *const_weight = as_const_float(weight);
-        if (const_weight) {
+        if (auto const_weight = as_const_float(weight)) {
             user_assert(*const_weight >= 0.0 && *const_weight <= 1.0)
                 << "Floating-point weight for lerp with integer arguments is "
                 << *const_weight << ", which is not in the range [0.0, 1.0].\n";
diff --git a/src/IROperator.h b/src/IROperator.h
index 2a65769dd277..0db5606f011c 100644
--- a/src/IROperator.h
+++ b/src/IROperator.h
@@ -9,6 +9,7 @@
 
 #include 
 #include 
+#include <optional>
 
 #include "Expr.h"
 #include "Target.h"
@@ -27,21 +28,25 @@ bool is_const(const Expr &e);
 bool is_const(const Expr &e, int64_t v);
 
 /** If an expression is an IntImm or a Broadcast of an IntImm, return
- * a pointer to its value. Otherwise returns nullptr. */
-const int64_t *as_const_int(const Expr &e);
+ * its value. Otherwise returns std::nullopt. */
+std::optional<int64_t> as_const_int(const Expr &e);
 
 /** If an expression is a UIntImm or a Broadcast of a UIntImm, return
- * a pointer to its value. Otherwise returns nullptr. */
-const uint64_t *as_const_uint(const Expr &e);
+ * its value. Otherwise returns std::nullopt. */
+std::optional<uint64_t> as_const_uint(const Expr &e);
 
 /** If an expression is a FloatImm or a Broadcast of a FloatImm,
- * return a pointer to its value. Otherwise returns nullptr. */
-const double *as_const_float(const Expr &e);
+ * return its value. Otherwise returns std::nullopt. */
+std::optional<double> as_const_float(const Expr &e);
 
-/** Is the expression a constant integer power of two. Also returns
- * log base two of the expression if it is. Only returns true for
- * integer types. */
-bool is_const_power_of_two_integer(const Expr &e, int *bits);
+/** Is the expression a constant integer power of two. Returns log base two of
+ * the expression if it is, or std::nullopt if not. Also returns std::nullopt
+ * for non-integer types.
*/
+// @{
+std::optional<int> is_const_power_of_two_integer(const Expr &e);
+std::optional<int> is_const_power_of_two_integer(uint64_t);
+std::optional<int> is_const_power_of_two_integer(int64_t);
+// @}
 
 /** Is the expression a const (as defined by is_const), and also
  * strictly greater than zero (in all lanes, if a vector expression) */
diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp
index ad48c37db78f..fa3dfecb9a1f 100644
--- a/src/LowerWarpShuffles.cpp
+++ b/src/LowerWarpShuffles.cpp
@@ -379,7 +379,7 @@ class LowerWarpShuffles : public IRMutator {
         bool should_mask = false;
         ScopedValue<Expr> old_warp_size(warp_size);
         if (op->for_type == ForType::GPULane) {
-            const int64_t *loop_size = as_const_int(op->extent);
+            auto loop_size = as_const_int(op->extent);
             user_assert(loop_size && *loop_size <= 32)
                 << "CUDA gpu lanes loop must have constant extent of at most 32: "
                 << op->extent << "\n";
@@ -411,7 +411,7 @@ class LowerWarpShuffles : public IRMutator {
                 Expr new_size = (alloc->extents[0] + op->extent - 1) / op->extent;
                 new_size = simplify(new_size, true, bounds);
                 new_size = find_constant_bound(new_size, Direction::Upper, bounds);
-                const int64_t *sz = as_const_int(new_size);
+                auto sz = as_const_int(new_size);
                 user_assert(sz) << "Warp-level allocation with non-constant size: "
                                 << alloc->extents[0] << ". Use Func::bound_extent.";
                 DetermineAllocStride stride(alloc->name, op->name, warp_size);
@@ -587,7 +587,7 @@ class LowerWarpShuffles : public IRMutator {
         Expr wild = Variable::make(Int(32), "*");
         vector<Expr> result;
-        int bits = 0;
+        std::optional<int> bits;
 
         // Move this_lane as far left as possible in the expression to
         // reduce the number of cases to check below.
@@ -602,18 +602,18 @@
                                   shfl_args({membermask, base_val, result[0], 31}), Call::PureExtern);
             shuffled = down;
         } else if (expr_match((this_lane + wild) % wild, lane, result) &&
-                   is_const_power_of_two_integer(result[1], &bits) &&
-                   bits <= 5) {
+                   (bits = is_const_power_of_two_integer(result[1])) &&
+                   *bits <= 5) {
             result[0] = simplify(result[0] % result[1], true, bounds);
             // Rotate. Mux a shuffle up and a shuffle down. Uses fewer
             // intermediate registers than using a general gather for
             // this.
- Expr mask = (1 << bits) - 1; + Expr mask = (1 << *bits) - 1; Expr down = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".down" + intrin_suffix, shfl_args({membermask, base_val, result[0], mask}), Call::PureExtern); Expr up = Call::make(shuffle_type, "llvm.nvvm.shfl" + sync_suffix + ".up" + intrin_suffix, - shfl_args({membermask, base_val, (1 << bits) - result[0], 0}), Call::PureExtern); - Expr cond = (this_lane >= (1 << bits) - result[0]); + shfl_args({membermask, base_val, (1 << *bits) - result[0], 0}), Call::PureExtern); + Expr cond = (this_lane >= (1 << *bits) - result[0]); Expr equiv = select(cond, up, down); shuffled = simplify(equiv, true, bounds); } else { diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index 1450faade800..82934a31de5e 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -33,15 +33,15 @@ using std::string; namespace { -const int64_t *as_const_int_or_uint(const Expr &e) { - if (const int64_t *i = as_const_int(e)) { +std::optional as_const_int_or_uint(const Expr &e) { + if (auto i = as_const_int(e)) { return i; - } else if (const uint64_t *u = as_const_uint(e)) { + } else if (auto u = as_const_uint(e)) { if (*u <= (uint64_t)std::numeric_limits::max()) { - return (const int64_t *)u; + return (int64_t)(*u); } } - return nullptr; + return {}; } bool is_constant(const ConstantInterval &x) { @@ -154,9 +154,9 @@ class DerivativeBounds : public IRVisitor { // This is essentially the product rule: a*rb + b*ra // but only implemented for the case where a or b is constant. - if (const int64_t *b = as_const_int_or_uint(op->b)) { + if (auto b = as_const_int_or_uint(op->b)) { result = ra * (*b); - } else if (const int64_t *a = as_const_int_or_uint(op->a)) { + } else if (auto a = as_const_int_or_uint(op->a)) { result = rb * (*a); } else { result = ConstantInterval::everything(); @@ -168,7 +168,7 @@ class DerivativeBounds : public IRVisitor { void visit(const Div *op) override { if (op->type.is_scalar()) { - if (const int64_t *b = as_const_int_or_uint(op->b)) { + if (auto b = as_const_int_or_uint(op->b)) { op->a.accept(this); // We don't just want to divide by b. For the min we want to // take floor division, and for the max we want to use ceil diff --git a/src/Profiling.cpp b/src/Profiling.cpp index 9951b74f0af0..3c7b5d6e0090 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -286,8 +286,8 @@ class InjectProfiling : public IRMutator { } else { idx = get_func_id(op->name); } - const uint64_t *int_size = as_const_uint(size); - internal_assert(int_size != nullptr); // Stack size is always a const int + auto int_size = as_const_uint(size); + internal_assert(int_size); // Stack size is always a const int func_stack_current[idx] += *int_size; func_stack_peak[idx] = std::max(func_stack_peak[idx], func_stack_current[idx]); debug(3) << " Allocation on stack: " << op->name @@ -355,8 +355,8 @@ class InjectProfiling : public IRMutator { stmt = Block::make(tasks); } } else { - const uint64_t *int_size = as_const_uint(alloc.size); - internal_assert(int_size != nullptr); + auto int_size = as_const_uint(alloc.size); + internal_assert(int_size); int idx; Function func = lookup_function(op->name); diff --git a/src/Random.cpp b/src/Random.cpp index 111ec73ebb5e..57eeb69f9210 100644 --- a/src/Random.cpp +++ b/src/Random.cpp @@ -55,7 +55,7 @@ Expr rng32(const Expr &x) { // So I declare this good enough for image processing. 
// If it's just a const (which it often is), save the simplifier some work: - if (const uint64_t *i = as_const_uint(x)) { + if (auto i = as_const_uint(x)) { return make_const(UInt(32), ((C2 * (*i)) + C1) * (*i) + C0); } @@ -73,8 +73,8 @@ Expr random_int(const vector &e) { // Add in the next term and permute again string name = unique_name('R'); // If it's a const, save the simplifier some work - const uint64_t *ir = as_const_uint(result); - const uint64_t *ie = as_const_uint(e[i]); + auto ir = as_const_uint(result); + auto ie = as_const_uint(e[i]); if (ir && ie) { result = rng32(make_const(UInt(32), (*ir) + (*ie))); } else { diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 6bb34fc4db15..d430b53a965c 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -21,11 +21,11 @@ Simplify::Simplify(bool r, const Scope *bi, const Scopecbegin(); iter != bi->cend(); ++iter) { ExprInfo info; - if (const int64_t *i_min = as_const_int(iter.value().min)) { + if (auto i_min = as_const_int(iter.value().min)) { info.bounds.min_defined = true; info.bounds.min = *i_min; } - if (const int64_t *i_max = as_const_int(iter.value().max)) { + if (auto i_max = as_const_int(iter.value().max)) { info.bounds.max_defined = true; info.bounds.max = *i_max; } @@ -87,33 +87,6 @@ void Simplify::found_buffer_reference(const string &name, size_t dimensions) { } } -bool Simplify::const_float(const Expr &e, double *f) { - if (const double *p = as_const_float(e)) { - *f = *p; - return true; - } else { - return false; - } -} - -bool Simplify::const_int(const Expr &e, int64_t *i) { - if (const int64_t *p = as_const_int(e)) { - *i = *p; - return true; - } else { - return false; - } -} - -bool Simplify::const_uint(const Expr &e, uint64_t *u) { - if (const uint64_t *p = as_const_uint(e)) { - *u = *p; - return true; - } else { - return false; - } -} - void Simplify::ScopedFact::learn_false(const Expr &fact) { Simplify::VarInfo info; info.old_uses = info.new_uses = 0; @@ -211,8 +184,8 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { } else if (const EQ *eq = fact.as()) { const Variable *v = eq->a.as(); const Mod *m = eq->a.as(); - const int64_t *modulus = m ? as_const_int(m->b) : nullptr; - const int64_t *remainder = m ? as_const_int(eq->b) : nullptr; + auto modulus = m ? as_const_int(m->b) : std::nullopt; + auto remainder = m ? as_const_int(eq->b) : std::nullopt; if (v) { if (is_const(eq->b) || eq->b.as()) { // TODO: consider other cases where we might want to entirely substitute diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index db3fe526418c..0d5d9a5ffdda 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -82,22 +82,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } - uint64_t ua = 0; - if (const_int(a, (int64_t *)(&ua)) || const_uint(a, &ua)) { + auto ia = as_const_int(a); + auto ua = as_const_uint(a); + uint64_t u = ua.value_or(reinterpret_bits(ia.value_or(0))); + if (ia || ua) { const int bits = op->type.bits(); const uint64_t mask = std::numeric_limits::max() >> (64 - bits); - ua &= mask; + u &= mask; static_assert(sizeof(unsigned long long) >= sizeof(uint64_t), ""); int r = 0; if (op->is_intrinsic(Call::popcount)) { // popcount *is* well-defined for ua = 0 - r = popcount64(ua); + r = popcount64(u); } else if (op->is_intrinsic(Call::count_leading_zeros)) { // clz64() is undefined for 0, but Halide's count_leading_zeros defines clz(0) = bits - r = ua == 0 ? bits : (clz64(ua) - (64 - bits)); + r = u == 0 ? 
bits : (clz64(u) - (64 - bits)); } else /* if (op->is_intrinsic(Call::count_trailing_zeros)) */ { // ctz64() is undefined for 0, but Halide's count_trailing_zeros defines clz(0) = bits - r = ua == 0 ? bits : (ctz64(ua)); + r = u == 0 ? bits : (ctz64(u)); } return make_const(op->type, r); } @@ -140,16 +142,15 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { } // If the shift is by a constant, it should now be unsigned. - uint64_t ub = 0; - if (const_uint(b, &ub)) { + if (auto ub = as_const_uint(b)) { // LLVM shl and shr instructions produce poison for // shifts >= typesize, so we will follow suit in our simplifier. - if (ub >= (uint64_t)(t.bits())) { + if (*ub >= (uint64_t)(t.bits())) { clear_expr_info(info); return make_signed_integer_overflow(t); } - if (a.type().is_uint() || ub < ((uint64_t)t.bits() - 1)) { - b = make_const(t, ((int64_t)1LL) << ub); + if (a.type().is_uint() || *ub < ((uint64_t)t.bits() - 1)) { + b = make_const(t, ((int64_t)1LL) << *ub); if (result_op == Call::get_intrinsic_name(Call::shift_left)) { return mutate(Mul::make(a, b), info); } else { @@ -160,7 +161,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { // (-32768 >> (t.bits() - 1)) propagates the sign bit, making decomposition // into mul or div problematic, so just special-case them here. if (result_op == Call::get_intrinsic_name(Call::shift_left)) { - return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << ub), make_zero(t)), info); + return mutate(select((a & 1) != 0, make_const(t, ((int64_t)1LL) << *ub), make_zero(t)), info); } else { return mutate(select(a < 0, make_const(t, -1), make_zero(t)), info); } @@ -193,29 +194,24 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } - int64_t ia, ib = 0; - uint64_t ua, ub = 0; - int bits; - - if (const_int(a, &ia) && - const_int(b, &ib)) { - return make_const(op->type, ia & ib); - } else if (const_uint(a, &ua) && - const_uint(b, &ub)) { - return make_const(op->type, ua & ub); - } else if (const_int(b, &ib) && - !b.type().is_max(ib) && - is_const_power_of_two_integer(make_const(a.type(), ib + 1), &bits)) { - return Mod::make(a, make_const(a.type(), ib + 1)); - } else if (const_uint(b, &ub) && - b.type().is_max(ub)) { + auto ia = as_const_int(a), ib = as_const_int(b); + auto ua = as_const_uint(a), ub = as_const_uint(b); + + if (ia && ib) { + return make_const(op->type, *ia & *ib); + } else if (ua && ub) { + return make_const(op->type, *ua & *ub); + } else if (ib && + !b.type().is_max(*ib) && + is_const_power_of_two_integer(*ib + 1)) { + return Mod::make(a, make_const(a.type(), *ib + 1)); + } else if (ub && b.type().is_max(*ub)) { return a; - } else if (const_int(b, &ib) && - ib == -1) { + } else if (ib && *ib == -1) { return a; - } else if (const_uint(b, &ub) && - is_const_power_of_two_integer(make_const(a.type(), ub + 1), &bits)) { - return Mod::make(a, make_const(a.type(), ub + 1)); + } else if (ub && + is_const_power_of_two_integer(*ub + 1)) { + return Mod::make(a, make_const(a.type(), *ub + 1)); } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) { return op; } else { @@ -230,14 +226,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } - int64_t ia, ib; - uint64_t ua, ub; - if (const_int(a, &ia) && - const_int(b, &ib)) { - return make_const(op->type, ia | ib); - } else if (const_uint(a, &ua) && - const_uint(b, &ub)) { - return make_const(op->type, ua | ub); + auto ia = as_const_int(a), ib = as_const_int(b); + auto ua = as_const_uint(a), ub = 
as_const_uint(b); + if (ia && ib) { + return make_const(op->type, *ia | *ib); + } else if (ua && ub) { + return make_const(op->type, *ua | *ub); } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) { return op; } else { @@ -251,12 +245,10 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } - int64_t ia; - uint64_t ua; - if (const_int(a, &ia)) { - return make_const(op->type, ~ia); - } else if (const_uint(a, &ua)) { - return make_const(op->type, ~ua); + if (auto ia = as_const_int(a)) { + return make_const(op->type, ~(*ia)); + } else if (auto ua = as_const_uint(a)) { + return make_const(op->type, ~(*ua)); } else if (a.same_as(op->args[0])) { return op; } else { @@ -271,14 +263,12 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { return mutate(unbroadcast, info); } - int64_t ia, ib; - uint64_t ua, ub; - if (const_int(a, &ia) && - const_int(b, &ib)) { - return make_const(op->type, ia ^ ib); - } else if (const_uint(a, &ua) && - const_uint(b, &ub)) { - return make_const(op->type, ua ^ ub); + auto ia = as_const_int(a), ib = as_const_int(b); + auto ua = as_const_uint(a), ub = as_const_uint(b); + if (ia && ib) { + return make_const(op->type, *ia ^ *ib); + } else if (ua && ub) { + return make_const(op->type, *ua ^ *ub); } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) { return op; } else { @@ -300,21 +290,19 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { } Type ta = a.type(); - int64_t ia = 0; - double fa = 0; - if (ta.is_int() && const_int(a, &ia)) { - if (ia < 0 && !(Int(64).is_min(ia))) { - ia = -ia; + if (auto ia = as_const_int(a)) { + if (*ia < 0 && !(Int(64).is_min(*ia))) { + *ia = -(*ia); } - return make_const(op->type, ia); + return make_const(op->type, *ia); } else if (ta.is_uint()) { // abs(uint) is a no-op. return a; - } else if (const_float(a, &fa)) { - if (fa < 0) { - fa = -fa; + } else if (auto fa = as_const_float(a)) { + if (*fa < 0) { + *fa = -(*fa); } - return make_const(a.type(), fa); + return make_const(a.type(), *fa); } else if (a.type().is_int() && a_info.bounds >= 0) { return cast(op->type, a); } else if (a.type().is_int() && a_info.bounds <= 0) { @@ -339,19 +327,19 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { // absd() should enforce identical types for a and b when the node is created internal_assert(ta == b.type()); - int64_t ia = 0, ib = 0; - uint64_t ua = 0, ub = 0; - double fa = 0, fb = 0; - if (ta.is_int() && const_int(a, &ia) && const_int(b, &ib)) { + auto ia = as_const_int(a), ib = as_const_int(b); + auto ua = as_const_uint(a), ub = as_const_uint(b); + auto fa = as_const_float(a), fb = as_const_float(b); + if (ta.is_int() && ia && ib) { // Note that absd(int, int) always produces a uint result internal_assert(op->type.is_uint()); - const uint64_t d = ia > ib ? (uint64_t)(ia - ib) : (uint64_t)(ib - ia); + const uint64_t d = *ia > *ib ? (uint64_t)(*ia - *ib) : (uint64_t)(*ib - *ia); return make_const(op->type, d); - } else if (ta.is_uint() && const_uint(a, &ua) && const_uint(b, &ub)) { - const uint64_t d = ua > ub ? ua - ub : ub - ua; + } else if (ta.is_uint() && ua && ub) { + const uint64_t d = *ua > *ub ? *ua - *ub : *ub - *ua; return make_const(op->type, d); - } else if (const_float(a, &fa) && const_float(b, &fb)) { - const double d = fa > fb ? fa - fb : fb - fa; + } else if (fa && fb) { + const double d = *fa > *fb ? 
*fa - *fb : *fb - *fa; return make_const(op->type, d); } else if (a.same_as(op->args[0]) && b.same_as(op->args[1])) { return op; @@ -664,10 +652,9 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { auto it = pure_externs_f1b.find(op->name); if (it != pure_externs_f1b.end()) { Expr arg = mutate(op->args[0], nullptr); - double f = 0.0; - if (const_float(arg, &f)) { + if (auto f = as_const_float(arg)) { auto fn = it->second; - return make_bool(fn(f)); + return make_bool(fn(*f)); } else if (arg.same_as(op->args[0])) { return op; } else { @@ -703,7 +690,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { auto it = pure_externs_f1.find(op->name); if (it != pure_externs_f1.end()) { Expr arg = mutate(op->args[0], nullptr); - if (const double *f = as_const_float(arg)) { + if (auto f = as_const_float(arg)) { auto fn = it->second; return make_const(arg.type(), fn(*f)); } else if (arg.same_as(op->args[0])) { @@ -735,7 +722,7 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { Expr arg = mutate(op->args[0], nullptr); const Call *call = arg.as(); - if (const double *f = as_const_float(arg)) { + if (auto f = as_const_float(arg)) { auto fn = it->second; return make_const(arg.type(), fn(*f)); } else if (call && (call->call_type == Call::PureExtern || call->call_type == Call::PureIntrinsic) && @@ -765,8 +752,8 @@ Expr Simplify::visit(const Call *op, ExprInfo *info) { Expr arg0 = mutate(op->args[0], nullptr); Expr arg1 = mutate(op->args[1], nullptr); - const double *f0 = as_const_float(arg0); - const double *f1 = as_const_float(arg1); + auto f0 = as_const_float(arg0); + auto f1 = as_const_float(arg1); if (f0 && f1) { auto fn = it->second; return make_const(arg0.type(), fn(*f0, *f1)); diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 985707ce2cfb..ae08ea3944fd 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -27,74 +27,74 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { const Cast *cast = value.as(); const Broadcast *broadcast_value = value.as(); const Ramp *ramp_value = value.as(); - double f = 0.0; - int64_t i = 0; - uint64_t u = 0; + std::optional f; + std::optional i; + std::optional u; if (Call::as_intrinsic(value, {Call::signed_integer_overflow})) { clear_expr_info(info); return make_signed_integer_overflow(op->type); } else if (value.type() == op->type) { return value; } else if (op->type.is_int() && - const_float(value, &f) && - std::isfinite(f)) { + (f = as_const_float(value)) && + std::isfinite(*f)) { // float -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(f)), info); + return mutate(make_const(op->type, safe_numeric_cast(*f)), info); } else if (op->type.is_uint() && - const_float(value, &f) && - std::isfinite(f)) { + (f = as_const_float(value)) && + std::isfinite(*f)) { // float -> uint // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(f)), info); + return mutate(make_const(op->type, safe_numeric_cast(*f)), info); } else if (op->type.is_float() && - const_float(value, &f)) { + (f = as_const_float(value))) { // float -> float - return make_const(op->type, f); + return make_const(op->type, *f); } else if (op->type.is_int() && - const_int(value, &i)) { + (i = as_const_int(value))) { // int -> int // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, i), info); + return mutate(make_const(op->type, *i), info); } else if (op->type.is_uint() && - const_int(value, &i)) { + (i = 
as_const_int(value))) { // int -> uint - return make_const(op->type, safe_numeric_cast(i)); + return make_const(op->type, safe_numeric_cast(*i)); } else if (op->type.is_float() && - const_int(value, &i)) { + (i = as_const_int(value))) { // int -> float - return mutate(make_const(op->type, safe_numeric_cast(i)), info); + return mutate(make_const(op->type, safe_numeric_cast(*i)), info); } else if (op->type.is_int() && - const_uint(value, &u) && + (u = as_const_uint(value)) && op->type.bits() < value.type().bits()) { // uint -> int narrowing // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), info); + return mutate(make_const(op->type, safe_numeric_cast(*u)), info); } else if (op->type.is_int() && - const_uint(value, &u) && + (u = as_const_uint(value)) && op->type.bits() == value.type().bits()) { // uint -> int reinterpret // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), info); + return mutate(make_const(op->type, safe_numeric_cast(*u)), info); } else if (op->type.is_int() && - const_uint(value, &u) && + (u = as_const_uint(value)) && op->type.bits() > value.type().bits()) { // uint -> int widening - if (op->type.can_represent(u) || op->type.bits() < 32) { + if (op->type.can_represent(*u) || op->type.bits() < 32) { // If the type can represent the value or overflow is well-defined. // Recursively call mutate just to set the bounds - return mutate(make_const(op->type, safe_numeric_cast(u)), info); + return mutate(make_const(op->type, safe_numeric_cast(*u)), info); } else { return make_signed_integer_overflow(op->type); } } else if (op->type.is_uint() && - const_uint(value, &u)) { + (u = as_const_uint(value))) { // uint -> uint - return mutate(make_const(op->type, u), info); + return mutate(make_const(op->type, *u), info); } else if (op->type.is_float() && - const_uint(value, &u)) { + (u = as_const_uint(value))) { // uint -> float - return make_const(op->type, safe_numeric_cast(u)); + return make_const(op->type, safe_numeric_cast(*u)); } else if (cast && op->type.code() == cast->type.code() && op->type.bits() < cast->type.bits()) { diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index 19666cc77294..851b5d05c810 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -153,11 +153,11 @@ class Simplify : public VariadicVisitor { if (b) { debug(1) << spaces << "Bounds: " << b->bounds << " " << b->alignment << "\n"; - if (const int64_t *i = as_const_int(new_e)) { + if (auto i = as_const_int(new_e)) { internal_assert(b->bounds.contains(*i)) << e << "\n" << new_e << "\n" << b->bounds; - } else if (const uint64_t *i = as_const_uint(new_e)) { + } else if (auto i = as_const_uint(new_e)) { internal_assert(b->bounds.contains(*i)) << e << "\n" << new_e << "\n" << b->bounds; @@ -259,12 +259,6 @@ class Simplify : public VariadicVisitor { // symbols. void found_buffer_reference(const std::string &name, size_t dimensions = 0); - // Wrappers for as_const_foo that are more convenient to use in - // the large chains of conditions in the visit methods below. 
- bool const_float(const Expr &e, double *f); - bool const_int(const Expr &e, int64_t *i); - bool const_uint(const Expr &e, uint64_t *u); - // Put the args to a commutative op in a canonical order HALIDE_ALWAYS_INLINE bool should_commute(const Expr &a, const Expr &b) { diff --git a/src/Simplify_Reinterpret.cpp b/src/Simplify_Reinterpret.cpp index 51289aac9b87..259b7fb4f486 100644 --- a/src/Simplify_Reinterpret.cpp +++ b/src/Simplify_Reinterpret.cpp @@ -6,17 +6,15 @@ namespace Internal { Expr Simplify::visit(const Reinterpret *op, ExprInfo *info) { Expr a = mutate(op->value, nullptr); - int64_t ia; - uint64_t ua; bool vector = op->type.is_vector() || a.type().is_vector(); if (op->type == a.type()) { return a; - } else if (const_int(a, &ia) && op->type.is_uint() && !vector) { + } else if (auto ia = as_const_int(a); ia && op->type.is_uint() && !vector) { // int -> uint - return make_const(op->type, (uint64_t)ia); - } else if (const_uint(a, &ua) && op->type.is_int() && !vector) { + return make_const(op->type, reinterpret_bits(*ia)); + } else if (auto ua = as_const_uint(a); ua && op->type.is_int() && !vector) { // uint -> int - return make_const(op->type, (int64_t)ua); + return make_const(op->type, reinterpret_bits(*ua)); } else if (const Reinterpret *as_r = a.as()) { // Fold double-reinterprets. return mutate(reinterpret(op->type, as_r->value), info); diff --git a/src/Solve.cpp b/src/Solve.cpp index 10e6232e379d..3f124601345a 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -392,7 +392,7 @@ class SolveExpression : public IRMutator { const Mul *mul_a = a.as(); Expr expr; if (a_uses_var && !b_uses_var) { - const int64_t *ib = as_const_int(b); + auto ib = as_const_int(b); auto is_multiple_of_b = [&](const Expr &e) { if (ib && op->type.is_scalar()) { int64_t r = 0; diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index 723fc738ce51..dc4e5b9eed4f 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -87,16 +87,16 @@ class FindStridedLoads : public IRVisitor { // CSE). Expr idx = substitute_in_all_lets(simplify(common_subexpression_elimination(op->index))); if (const Ramp *r = idx.as()) { - const int64_t *stride_ptr = as_const_int(r->stride); - int64_t stride = stride_ptr ? *stride_ptr : 0; + int64_t stride = as_const_int(r->stride).value_or(0); Expr base = r->base; int64_t offset = 0; - const Add *base_add = base.as(); - const int64_t *offset_ptr = base_add ? as_const_int(base_add->b) : nullptr; - if (offset_ptr) { - base = base_add->a; - offset = *offset_ptr; + if (const Add *base_add = base.as()) { + if (auto off = as_const_int(base_add->b)) { + base = base_add->a; + offset = *off; + } } + // TODO: We do not yet handle nested vectorization here for // ramps which have not already collapsed. We could potentially // handle more interesting types of shuffle than simple flat slices. 
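The StageStridedLoads.cpp hunk above shows the idiom this series is converging on: absence collapses into a default via value_or() instead of a pointer null-check. A minimal sketch of the bug class the std::optional return type rules out; stride_or_zero is a hypothetical helper written for illustration, not code from this series:

    // Hypothetical helper, illustration only; assumes the post-patch
    // as_const_int() from src/IROperator.h and simplify() from src/Simplify.h.
    #include "Halide.h"
    #include <cstdint>

    using Halide::Expr;
    using Halide::Internal::as_const_int;
    using Halide::Internal::simplify;

    int64_t stride_or_zero(const Expr &stride) {
        // With the old pointer-returning API this line was a use-after-free:
        //     const int64_t *s = as_const_int(simplify(stride));
        // `s` pointed into the temporary Expr returned by simplify(), which is
        // destroyed at the end of the full-expression, leaving `s` dangling.
        // The optional copies the constant out, so its lifetime is independent:
        return as_const_int(simplify(stride)).value_or(0);
    }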
diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp
index a207b3ce63f5..04e743e33fbd 100644
--- a/src/StorageFolding.cpp
+++ b/src/StorageFolding.cpp
@@ -740,7 +740,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
         scope.pop(op->name);
 
         const int max_fold = 1024;
-        const int64_t *const_max_extent = as_const_int(max_extent);
+        auto const_max_extent = as_const_int(max_extent);
         if (const_max_extent && *const_max_extent <= max_fold) {
             factor = static_cast<int>(next_power_of_two(*const_max_extent));
         } else {
diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp
index 8bb3096f4c4f..576bc0f1ab41 100644
--- a/src/VectorizeLoops.cpp
+++ b/src/VectorizeLoops.cpp
@@ -261,7 +261,7 @@ bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRam
             return true;
         }
     } else if (const Mul *mul = e.as<Mul>()) {
-        const int64_t *b = nullptr;
+        std::optional<int64_t> b;
         if (is_interleaved_ramp(mul->a, scope, result) &&
             (b = as_const_int(mul->b))) {
             result->base = simplify(result->base * (int)(*b));
@@ -269,7 +269,7 @@ bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRam
             return true;
         }
     } else if (const Div *div = e.as<Div>()) {
-        const int64_t *b = nullptr;
+        std::optional<int64_t> b;
         if (is_interleaved_ramp(div->a, scope, result) &&
             (b = as_const_int(div->b)) &&
             is_const_one(result->stride) &&
@@ -284,7 +284,7 @@ bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRam
             return true;
         }
     } else if (const Mod *mod = e.as<Mod>()) {
-        const int64_t *b = nullptr;
+        std::optional<int64_t> b;
        if (is_interleaved_ramp(mod->a, scope, result) &&
             (b = as_const_int(mod->b)) &&
             (result->outer_repetitions == 1 ||
@@ -655,8 +655,8 @@ class VectorSubs : public IRMutator {
         if (!changed) {
             return op;
         } else if (op->name == Call::trace) {
-            const int64_t *event = as_const_int(op->args[6]);
-            internal_assert(event != nullptr);
+            auto event = as_const_int(op->args[6]);
+            internal_assert(event);
             if (*event == halide_trace_begin_realization || *event == halide_trace_end_realization) {
                 // Call::trace vectorizes uniquely for begin/end realization, because the coordinates
                 // for these are actually min/extent pairs; we need to maintain the proper dimensionality
diff --git a/src/autoschedulers/adams2019/FunctionDAG.cpp b/src/autoschedulers/adams2019/FunctionDAG.cpp
index 4bb17b265d46..35b872477150 100644
--- a/src/autoschedulers/adams2019/FunctionDAG.cpp
+++ b/src/autoschedulers/adams2019/FunctionDAG.cpp
@@ -210,7 +210,7 @@ class Featurizer : public IRVisitor {
             return a;
         } else if (const Mul *op = e.as<Mul>()) {
             auto a = differentiate(op->a, v);
-            if (const int64_t *ib = as_const_int(op->b)) {
+            if (auto ib = as_const_int(op->b)) {
                 a.numerator *= *ib;
                 return a;
             } else {
@@ -218,7 +218,7 @@ class Featurizer : public IRVisitor {
             }
         } else if (const Div *op = e.as<Div>
()) { auto a = differentiate(op->a, v); - if (const int64_t *ib = as_const_int(op->b)) { + if (auto ib = as_const_int(op->b)) { if (a.numerator != 0) { a.denominator *= *ib; } @@ -414,8 +414,8 @@ void FunctionDAG::Node::loop_nest_for_region(int stage_idx, const Span *computed } else { Expr min = simplify(substitute(computed_map, l.min)); Expr max = simplify(substitute(computed_map, l.max)); - const int64_t *imin = as_const_int(min); - const int64_t *imax = as_const_int(max); + auto imin = as_const_int(min); + auto imax = as_const_int(max); internal_assert(imin && imax) << min << ", " << max << "\n"; loop[i] = Span(*imin, *imax, false); } @@ -442,8 +442,8 @@ void FunctionDAG::Node::required_to_computed(const Span *required, Span *compute } else { Expr min = simplify(substitute(required_map, comp.in.min)); Expr max = simplify(substitute(required_map, comp.in.max)); - const int64_t *imin = as_const_int(min); - const int64_t *imax = as_const_int(max); + auto imin = as_const_int(min); + auto imax = as_const_int(max); internal_assert(imin && imax) << min << ", " << max << "\n"; computed[i] = Span(*imin, *imax, false); } @@ -542,7 +542,7 @@ void FunctionDAG::Edge::expand_footprint(const Span *consumer_loop, Span *produc } else { Expr substituted = substitute(s, b.expr); Expr e = simplify(substituted); - const int64_t *i = as_const_int(e); + auto i = as_const_int(e); internal_assert(i) << "Should be constant: " << b.expr << " -> " << substituted << " -> " << e << "\n"; bounds_are_constant = false; return *i; @@ -702,8 +702,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } else { const Min *min = comp.in.min.as(); const Max *max = comp.in.max.as(); - const int64_t *min_b = min ? as_const_int(min->b) : nullptr; - const int64_t *max_b = max ? as_const_int(max->b) : nullptr; + auto min_b = min ? as_const_int(min->b) : std::nullopt; + auto max_b = max ? 
as_const_int(max->b) : std::nullopt; if (min_b && max_b && equal(min->a, req.min) && equal(max->a, req.max)) { comp.equals_union_of_required_with_constants = true; comp.c_min = *min_b; @@ -758,7 +758,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } if (!l.equals_region_computed) { - const int64_t *c_min = as_const_int(l.min), *c_max = as_const_int(l.max); + auto c_min = as_const_int(l.min); + auto c_max = as_const_int(l.max); if (c_min && c_max) { l.bounds_are_constant = true; l.c_min = *c_min; @@ -876,8 +877,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) // Get the bounds estimate map estimates; for (const auto &b : consumer.schedule().estimates()) { - const int64_t *i_min = as_const_int(b.min); - const int64_t *i_extent = as_const_int(b.extent); + auto i_min = as_const_int(b.min); + auto i_extent = as_const_int(b.extent); user_assert(i_min && i_extent) << "Min/extent of estimate or bound is not constant in \"" << consumer.name() << "\", var:" << b.var << ", min:" << b.min << ", extent:" << b.extent; @@ -902,8 +903,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } } for (const auto &b : consumer.schedule().bounds()) { - const int64_t *i_min = as_const_int(b.min); - const int64_t *i_extent = as_const_int(b.extent); + auto i_min = as_const_int(b.min); + auto i_extent = as_const_int(b.extent); if (i_min && i_extent) { // It's a true bound, not just an estimate estimates[b.var] = Span(*i_min, *i_min + *i_extent - 1, true); diff --git a/src/autoschedulers/adams2019/State.cpp b/src/autoschedulers/adams2019/State.cpp index 7c4545fae57b..a12860d5d295 100644 --- a/src/autoschedulers/adams2019/State.cpp +++ b/src/autoschedulers/adams2019/State.cpp @@ -321,8 +321,7 @@ void State::generate_children(const FunctionDAG &dag, int num_dims = output.dimensions(); for (int i = 0; i < num_dims; ++i) { const Expr stride = output.stride_constraint(i); - const int64_t *s = as_const_int(stride); - if (s && *s == 1) { + if (stride.defined() && is_const_one(stride)) { vector_dims.push_back(i); } } diff --git a/src/autoschedulers/anderson2021/FunctionDAG.cpp b/src/autoschedulers/anderson2021/FunctionDAG.cpp index 1a057187dcbd..e127a02a7bd3 100644 --- a/src/autoschedulers/anderson2021/FunctionDAG.cpp +++ b/src/autoschedulers/anderson2021/FunctionDAG.cpp @@ -210,7 +210,7 @@ class Featurizer : public IRVisitor { return a; } else if (const Mul *op = e.as()) { auto a = differentiate(op->a, v); - if (const int64_t *ib = as_const_int(op->b)) { + if (auto ib = as_const_int(op->b)) { a.numerator *= *ib; return a; } else { @@ -218,7 +218,7 @@ class Featurizer : public IRVisitor { } } else if (const Div *op = e.as
()) { auto a = differentiate(op->a, v); - if (const int64_t *ib = as_const_int(op->b)) { + if (auto ib = as_const_int(op->b)) { if (a.numerator != 0) { a.denominator *= *ib; } @@ -414,8 +414,8 @@ void FunctionDAG::Node::loop_nest_for_region(int stage_idx, const Span *computed } else { Expr min = simplify(substitute(computed_map, l.min)); Expr max = simplify(substitute(computed_map, l.max)); - const int64_t *imin = as_const_int(min); - const int64_t *imax = as_const_int(max); + auto imin = as_const_int(min); + auto imax = as_const_int(max); internal_assert(imin && imax) << min << ", " << max << "\n"; loop[i] = Span(*imin, *imax, false); } @@ -442,8 +442,8 @@ void FunctionDAG::Node::required_to_computed(const Span *required, Span *compute } else { Expr min = simplify(substitute(required_map, comp.in.min)); Expr max = simplify(substitute(required_map, comp.in.max)); - const int64_t *imin = as_const_int(min); - const int64_t *imax = as_const_int(max); + auto imin = as_const_int(min); + auto imax = as_const_int(max); internal_assert(imin && imax) << min << ", " << max << "\n"; computed[i] = Span(*imin, *imax, false); } @@ -550,7 +550,7 @@ void FunctionDAG::Edge::expand_footprint(const Span *consumer_loop, Span *produc } else { Expr substituted = substitute(s, b.expr); Expr e = simplify(substituted); - const int64_t *i = as_const_int(e); + auto i = as_const_int(e); internal_assert(i) << "Should be constant: " << b.expr << " -> " << substituted << " -> " << e << "\n"; bounds_are_constant = false; return *i; @@ -693,8 +693,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } else { const Min *min = comp.in.min.as(); const Max *max = comp.in.max.as(); - const int64_t *min_b = min ? as_const_int(min->b) : nullptr; - const int64_t *max_b = max ? as_const_int(max->b) : nullptr; + auto min_b = min ? as_const_int(min->b) : std::nullopt; + auto max_b = max ? 
as_const_int(max->b) : std::nullopt; if (min_b && max_b && equal(min->a, req.min) && equal(max->a, req.max)) { comp.equals_union_of_required_with_constants = true; comp.c_min = *min_b; @@ -749,7 +749,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } if (!l.equals_region_computed) { - const int64_t *c_min = as_const_int(l.min), *c_max = as_const_int(l.max); + auto c_min = as_const_int(l.min); + auto c_max = as_const_int(l.max); if (c_min && c_max) { l.bounds_are_constant = true; l.c_min = *c_min; @@ -866,8 +867,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) // Get the bounds estimate map estimates; for (const auto &b : consumer.schedule().estimates()) { - const int64_t *i_min = as_const_int(b.min); - const int64_t *i_extent = as_const_int(b.extent); + auto i_min = as_const_int(b.min); + auto i_extent = as_const_int(b.extent); user_assert(i_min && i_extent) << "Min/extent of estimate or bound is not constant in \"" << consumer.name() << "\", var:" << b.var << ", min:" << b.min << ", extent:" << b.extent; @@ -892,8 +893,8 @@ FunctionDAG::FunctionDAG(const vector &outputs, const Target &target) } } for (const auto &b : consumer.schedule().bounds()) { - const int64_t *i_min = as_const_int(b.min); - const int64_t *i_extent = as_const_int(b.extent); + auto i_min = as_const_int(b.min); + auto i_extent = as_const_int(b.extent); if (i_min && i_extent) { // It's a true bound, not just an estimate estimates[b.var] = Span(*i_min, *i_min + *i_extent - 1, true); diff --git a/src/autoschedulers/li2018/GradientAutoscheduler.cpp b/src/autoschedulers/li2018/GradientAutoscheduler.cpp index 1f1ad91a1567..8379f561fbd0 100644 --- a/src/autoschedulers/li2018/GradientAutoscheduler.cpp +++ b/src/autoschedulers/li2018/GradientAutoscheduler.cpp @@ -40,8 +40,8 @@ std::vector get_int_bounds(const Box &bounds) { const Interval &interval = bounds[i]; Expr extent = simplify(interval.max - interval.min + 1); extent = simplify(substitute_var_estimates(extent)); - const int64_t *extent_int = as_const_int(extent); - user_assert(extent_int != nullptr) + auto extent_int = as_const_int(extent); + user_assert(extent_int) << "extent:" << extent << " is not constant.\n"; int_bounds.push_back(*extent_int); } @@ -53,8 +53,8 @@ std::vector get_rvar_bounds(const std::vector &rvars) { rvar_bounds.reserve(rvars.size()); for (const auto &rvar : rvars) { Expr extent = simplify(substitute_var_estimates(rvar.extent)); - const int64_t *extent_int = as_const_int(extent); - user_assert(extent_int != nullptr) + auto extent_int = as_const_int(extent); + user_assert(extent_int) << "extent:" << extent << " is not constant.\n"; rvar_bounds.push_back(*extent_int); } diff --git a/test/correctness/bound_storage.cpp b/test/correctness/bound_storage.cpp index 0e2d50ee332b..64d8af225e5a 100644 --- a/test/correctness/bound_storage.cpp +++ b/test/correctness/bound_storage.cpp @@ -12,12 +12,7 @@ class FindAllocations : public Internal::IRMutator { Internal::Stmt visit(const Internal::Allocate *op) override { int total_size = 1; for (const auto &e : op->extents) { - const auto *size = Internal::as_const_int(e); - if (size) { - total_size = total_size * (*size); - } else { - total_size = 0; - } + total_size *= Internal::as_const_int(e).value_or(0); } // Trim of the suffix. 
std::string name = op->name.substr(0, op->name.find("$")); @@ -130,4 +125,4 @@ int main(int argc, char **argv) { printf("Success!\n"); return 0; -} \ No newline at end of file +} diff --git a/test/correctness/constant_expr.cpp b/test/correctness/constant_expr.cpp index 9c9360f2dce8..b4619ed299c2 100644 --- a/test/correctness/constant_expr.cpp +++ b/test/correctness/constant_expr.cpp @@ -17,17 +17,17 @@ bool bit_flip(T a) { template bool scalar_from_constant_expr(Expr e, T *val) { if (type_of().is_int()) { - const int64_t *i = as_const_int(e); + auto i = as_const_int(e); if (!i) return false; *val = (T)(*i); return true; } else if (type_of().is_uint()) { - const uint64_t *u = as_const_uint(e); + auto u = as_const_uint(e); if (!u) return false; *val = (T)(*u); return true; } else if (type_of().is_float()) { - const double *f = as_const_float(e); + auto f = as_const_float(e); if (!f) return false; *val = (T)(*f); return true; diff --git a/test/correctness/fuse_gpu_threads.cpp b/test/correctness/fuse_gpu_threads.cpp index 63361e76b928..efd690c4ef4c 100644 --- a/test/correctness/fuse_gpu_threads.cpp +++ b/test/correctness/fuse_gpu_threads.cpp @@ -8,8 +8,8 @@ class CheckThreadExtent : public IRVisitor { void visit(const For *op) override { if (op->for_type == ForType::GPUThread) { // Assert the min and extent to be 0 and 16 for this particular test case - const int64_t *min = as_const_int(op->min); - const int64_t *extent = as_const_int(op->extent); + auto min = as_const_int(op->min); + auto extent = as_const_int(op->extent); assert(min && (*min == 0)); assert(extent && (*extent == 16)); } From f605ec8cfe92eac804fa6acb4d19cb28c1166dd9 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 20 Nov 2024 12:03:23 -0800 Subject: [PATCH 2/5] Also use std::optional for get_md_string and get_md_bool --- src/CodeGen_Internal.cpp | 97 ++++++++++++++++++---------------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index 56df50ce371f..4ca58285df16 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -555,47 +555,42 @@ Expr lower_round_to_nearest_ties_to_even(const Expr &x) { } namespace { -bool get_md_bool(llvm::Metadata *value, bool &result) { +std::optional get_md_bool(llvm::Metadata *value) { if (!value) { - return false; + return {}; } llvm::ConstantAsMetadata *cam = llvm::cast(value); if (!cam) { - return false; + return {}; } llvm::ConstantInt *c = llvm::cast(cam->getValue()); if (!c) { - return false; + return {}; } - result = !c->isZero(); - return true; + return !c->isZero(); } -bool get_md_string(llvm::Metadata *value, std::string &result) { +std::optional get_md_string(llvm::Metadata *value) { if (!value) { - result = ""; - return false; + return {}; } llvm::MDString *c = llvm::dyn_cast(value); if (c) { - result = c->getString().str(); - return true; + return c->getString().str(); } - return false; + return {}; } } // namespace void get_target_options(const llvm::Module &module, llvm::TargetOptions &options) { - bool use_soft_float_abi = false; - get_md_bool(module.getModuleFlag("halide_use_soft_float_abi"), use_soft_float_abi); - std::string mabi; - get_md_string(module.getModuleFlag("halide_mabi"), mabi); - bool use_pic = true; - get_md_bool(module.getModuleFlag("halide_use_pic"), use_pic); + bool use_soft_float_abi = + get_md_bool(module.getModuleFlag("halide_use_soft_float_abi")).value_or(false); + std::string mabi = + get_md_string(module.getModuleFlag("halide_mabi")).value_or(std::string{}); // 
FIXME: can this be migrated into `set_function_attributes_from_halide_target_options()`? - bool per_instruction_fast_math_flags = false; - get_md_bool(module.getModuleFlag("halide_per_instruction_fast_math_flags"), per_instruction_fast_math_flags); + bool per_instruction_fast_math_flags = + get_md_bool(module.getModuleFlag("halide_per_instruction_fast_math_flags")).value_or(false); options = llvm::TargetOptions(); options.AllowFPOpFusion = per_instruction_fast_math_flags ? llvm::FPOpFusion::Strict : llvm::FPOpFusion::Fast; @@ -622,29 +617,22 @@ void clone_target_options(const llvm::Module &from, llvm::Module &to) { llvm::LLVMContext &context = to.getContext(); - bool use_soft_float_abi = false; - if (get_md_bool(from.getModuleFlag("halide_use_soft_float_abi"), use_soft_float_abi)) { - to.addModuleFlag(llvm::Module::Warning, "halide_use_soft_float_abi", use_soft_float_abi ? 1 : 0); - } - - std::string mcpu_target; - if (get_md_string(from.getModuleFlag("halide_mcpu_target"), mcpu_target)) { - to.addModuleFlag(llvm::Module::Warning, "halide_mcpu_target", llvm::MDString::get(context, mcpu_target)); - } - - std::string mcpu_tune; - if (get_md_string(from.getModuleFlag("halide_mcpu_tune"), mcpu_tune)) { - to.addModuleFlag(llvm::Module::Warning, "halide_mcpu_tune", llvm::MDString::get(context, mcpu_tune)); + // Clone bool metadata + for (const char *s : {"halide_use_soft_float_abi", + "halide_use_pic"}) { + if (auto md = get_md_bool(from.getModuleFlag(s))) { + to.addModuleFlag(llvm::Module::Warning, s, *md ? 1 : 0); + } } - std::string mattrs; - if (get_md_string(from.getModuleFlag("halide_mattrs"), mattrs)) { - to.addModuleFlag(llvm::Module::Warning, "halide_mattrs", llvm::MDString::get(context, mattrs)); - } + // Clone string metadata + for (const char *s : {"halide_mcpu_target", + "halide_mcpu_tune", + "halide_mattrs"}) { - bool use_pic = true; - if (get_md_bool(from.getModuleFlag("halide_use_pic"), use_pic)) { - to.addModuleFlag(llvm::Module::Warning, "halide_use_pic", use_pic ? 1 : 0); + if (auto md = get_md_string(from.getModuleFlag(s))) { + to.addModuleFlag(llvm::Module::Warning, s, llvm::MDString::get(context, *md)); + } } } @@ -662,11 +650,11 @@ std::unique_ptr make_target_machine(const llvm::Module &mod llvm::TargetOptions options; get_target_options(module, options); - bool use_pic = true; - get_md_bool(module.getModuleFlag("halide_use_pic"), use_pic); + bool use_pic = + get_md_bool(module.getModuleFlag("halide_use_pic")).value_or(true); - bool use_large_code_model = false; - get_md_bool(module.getModuleFlag("halide_use_large_code_model"), use_large_code_model); + bool use_large_code_model = + get_md_bool(module.getModuleFlag("halide_use_large_code_model")).value_or(false); #if LLVM_VERSION >= 180 const auto opt_level = llvm::CodeGenOptLevel::Aggressive; @@ -675,10 +663,10 @@ std::unique_ptr make_target_machine(const llvm::Module &mod #endif // Get module mcpu_target and mattrs. 
- std::string mcpu_target; - get_md_string(module.getModuleFlag("halide_mcpu_target"), mcpu_target); - std::string mattrs; - get_md_string(module.getModuleFlag("halide_mattrs"), mattrs); + std::string mcpu_target = + get_md_string(module.getModuleFlag("halide_mcpu_target")).value_or(std::string{}); + std::string mattrs = + get_md_string(module.getModuleFlag("halide_mattrs")).value_or(std::string{}); auto *tm = llvm_target->createTargetMachine(module.getTargetTriple(), mcpu_target, @@ -693,11 +681,14 @@ std::unique_ptr make_target_machine(const llvm::Module &mod void set_function_attributes_from_halide_target_options(llvm::Function &fn) { llvm::Module &module = *fn.getParent(); - std::string mcpu_target, mcpu_tune, mattrs, vscale_range; - get_md_string(module.getModuleFlag("halide_mcpu_target"), mcpu_target); - get_md_string(module.getModuleFlag("halide_mcpu_tune"), mcpu_tune); - get_md_string(module.getModuleFlag("halide_mattrs"), mattrs); - get_md_string(module.getModuleFlag("halide_vscale_range"), vscale_range); + std::string mcpu_target = + get_md_string(module.getModuleFlag("halide_mcpu_target")).value_or(std::string{}); + std::string mcpu_tune = + get_md_string(module.getModuleFlag("halide_mcpu_tune")).value_or(std::string{}); + std::string mattrs = + get_md_string(module.getModuleFlag("halide_mattrs")).value_or(std::string{}); + std::string vscale_range = + get_md_string(module.getModuleFlag("halide_vscale_range")).value_or(std::string{}); fn.addFnAttr("target-cpu", mcpu_target); fn.addFnAttr("tune-cpu", mcpu_tune); From 0aa0c03fd810ebb00b1c317175c2cda86c003bd4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 20 Nov 2024 13:32:19 -0800 Subject: [PATCH 3/5] Fix some unchecked accesses, and turn off checking of the others (false positives) --- .clang-tidy | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.clang-tidy b/.clang-tidy index 283acd5f9bd3..82793ffae4b3 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -70,7 +70,7 @@ Checks: > bugprone-terminating-continue, bugprone-throw-keyword-missing, bugprone-too-small-loop-variable, - bugprone-unchecked-optional-access, + -bugprone-unchecked-optional-access, # Too many false-positives bugprone-undefined-memory-manipulation, bugprone-undelegated-constructor, bugprone-unhandled-exception-at-new, From af079f031184256a35957328a0f83ca301a3bfca Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 21 Nov 2024 09:04:40 -0800 Subject: [PATCH 4/5] Fix onnx converter, one unchecked use --- apps/onnx/model.cpp | 2 +- apps/onnx/onnx_converter.cc | 46 ++++++++++++++++++------------------- src/TargetQueryOps.cpp | 6 ++--- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/apps/onnx/model.cpp b/apps/onnx/model.cpp index 2d8676ed32bc..147841f60608 100644 --- a/apps/onnx/model.cpp +++ b/apps/onnx/model.cpp @@ -174,7 +174,7 @@ void prepare_random_input( const Tensor &t = pipeline.model->tensors.at(input_name); std::vector input_shape; for (int i = 0; i < t.shape.size(); ++i) { - const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]); + auto dim = Halide::Internal::as_const_int(t.shape[i]); if (!dim) { // The dimension isn't fixed: use the estimated typical value instead if // one was provided. 
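The get_md_bool()/get_md_string() rewrite in the second patch (CodeGen_Internal.cpp, above) gives module-flag queries the same property: absence is a value rather than an untouched out-parameter, and each caller states its default exactly once via value_or(). A sketch of the calling pattern; the flag name and helper below are invented for illustration and assume the file-static get_md_bool() shown above:

    // Sketch only: "halide_example_flag" and read_example_flag() are invented.
    #include "llvm/IR/Module.h"

    bool read_example_flag(const llvm::Module &m) {
        // std::nullopt (flag absent or malformed) is now distinguishable from
        // a flag that is explicitly set to false.
        return get_md_bool(m.getModuleFlag("halide_example_flag")).value_or(false);
    }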
diff --git a/apps/onnx/onnx_converter.cc b/apps/onnx/onnx_converter.cc index bad3747802d8..9902beaeae0f 100644 --- a/apps/onnx/onnx_converter.cc +++ b/apps/onnx/onnx_converter.cc @@ -774,8 +774,8 @@ Halide::Func generate_padding_expr( Halide::Expr pad_before = pads[i]; Halide::Expr pad_after = input_shape[i + skip_dims] + pad_before - 1; padding_extents.emplace_back(pad_before, pad_after); - const int64_t *p1 = Halide::Internal::as_const_int(pad_before); - const int64_t *p2 = + auto p1 = Halide::Internal::as_const_int(pad_before); + auto p2 = Halide::Internal::as_const_int(pads[rank - skip_dims - i]); if (!p1 || *p1 != 0 || !p2 || *p2 != 0) { maybe_has_padding = true; @@ -1089,7 +1089,7 @@ Node convert_conv_node( bool supported_shape = true; for (int i = 2; i < rank; ++i) { const Halide::Expr w_shape_expr = Halide::Internal::simplify(W.shape[i]); - const int64_t *dim = Halide::Internal::as_const_int(w_shape_expr); + auto dim = Halide::Internal::as_const_int(w_shape_expr); if (!dim || *dim != 3) { supported_shape = false; break; @@ -1912,7 +1912,7 @@ Node convert_split_node( axis += inputs[0].shape.size(); } Halide::Expr axis_dim = inputs[0].shape.at(axis); - const int64_t *axis_dim_size = Halide::Internal::as_const_int(axis_dim); + auto axis_dim_size = Halide::Internal::as_const_int(axis_dim); if (user_splits.size() == 0) { if (axis_dim_size && (*axis_dim_size % num_outputs != 0)) { @@ -2041,11 +2041,11 @@ Node convert_slice_node( Halide::Internal::simplify(starts_tensor.shape[0]); const Halide::Expr ends_shape_expr = Halide::Internal::simplify(ends_tensor.shape[0]); - const int64_t *starts_shape_dim_0 = + auto starts_shape_dim_0 = Halide::Internal::as_const_int(starts_shape_expr); - const int64_t *ends_shape_dim_0 = + auto ends_shape_dim_0 = Halide::Internal::as_const_int(ends_shape_expr); - if (starts_shape_dim_0 == nullptr && ends_shape_dim_0 == nullptr) { + if (!starts_shape_dim_0 && !ends_shape_dim_0) { throw std::invalid_argument( "Can't statisticaly infer slice dim size for slice node " + node.name()); @@ -2053,7 +2053,7 @@ Node convert_slice_node( result.requirements.push_back(starts_shape_expr == ends_shape_expr); } num_slice_dims = - starts_shape_dim_0 != nullptr ? *starts_shape_dim_0 : *ends_shape_dim_0; + starts_shape_dim_0 ? 
*starts_shape_dim_0 : *ends_shape_dim_0; if (num_slice_dims != *ends_shape_dim_0) { throw std::invalid_argument( "Starts and ends input tensor must have the same shape for " @@ -2074,9 +2074,9 @@ Node convert_slice_node( const Tensor &axes_tensor = inputs[3]; const Halide::Expr axes_shape_expr = Halide::Internal::simplify(axes_tensor.shape[0]); - const int64_t *axes_shape_dim_0 = + auto axes_shape_dim_0 = Halide::Internal::as_const_int(axes_shape_expr); - if (axes_shape_dim_0 != nullptr && *axes_shape_dim_0 != num_slice_dims) { + if (axes_shape_dim_0 && *axes_shape_dim_0 != num_slice_dims) { throw std::invalid_argument( "Axes tensor must have the same shape as starts and ends for slice " "node " + @@ -2099,9 +2099,9 @@ Node convert_slice_node( const Tensor &steps_tensor = inputs[4]; const Halide::Expr steps_shape_expr = Halide::Internal::simplify(steps_tensor.shape[0]); - const int64_t *steps_shape_dim_0 = + auto steps_shape_dim_0 = Halide::Internal::as_const_int(steps_shape_expr); - if (steps_shape_dim_0 != nullptr && *steps_shape_dim_0 != num_slice_dims) { + if (steps_shape_dim_0 && *steps_shape_dim_0 != num_slice_dims) { throw std::invalid_argument( "Steps tensor must have the same shape as starts and ends for slice " "node " + @@ -2414,7 +2414,7 @@ Node convert_squeeze_node( if (implicit) { for (int i = 0; i < rank; ++i) { const Halide::Expr dim_expr = Halide::Internal::simplify(input.shape[i]); - const int64_t *dim = Halide::Internal::as_const_int(dim_expr); + auto dim = Halide::Internal::as_const_int(dim_expr); if (!dim) { throw std::invalid_argument( "Unknown dimension for input dim " + std::to_string(i) + @@ -2471,7 +2471,7 @@ Node convert_constant_of_shape( Tensor &out = result.outputs[0]; const Halide::Expr shape_expr = Halide::Internal::simplify(inputs[0].shape[0]); - const int64_t *shape_dim_0 = Halide::Internal::as_const_int(shape_expr); + auto shape_dim_0 = Halide::Internal::as_const_int(shape_expr); if (!shape_dim_0) { throw std::invalid_argument( "Can't infer rank statically for ConstantOfShape node " + node.name()); @@ -2744,7 +2744,7 @@ Node convert_expand_node( const int in_rank = input.shape.size(); const Halide::Expr shape_expr = Halide::Internal::simplify(expand_shape.shape[0]); - const int64_t *shape_dim_0 = Halide::Internal::as_const_int(shape_expr); + auto shape_dim_0 = Halide::Internal::as_const_int(shape_expr); if (!shape_dim_0) { throw std::invalid_argument( "Can't infer rank statically for expand node " + node.name()); @@ -3098,7 +3098,7 @@ Node convert_reshape_node( } const Halide::Expr shape_expr = Halide::Internal::simplify(new_shape.shape[0]); - const int64_t *num_dims = Halide::Internal::as_const_int(shape_expr); + auto num_dims = Halide::Internal::as_const_int(shape_expr); if (!num_dims) { throw std::domain_error( "Couldn't statically infer the rank of the output of " + node.name()); @@ -3285,7 +3285,7 @@ Node convert_gru_node( } const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]); - const int64_t *dim = Halide::Internal::as_const_int(dim_expr); + auto dim = Halide::Internal::as_const_int(dim_expr); if (!dim) { throw std::domain_error("Unknown number of timesteps"); } @@ -3683,7 +3683,7 @@ Node convert_rnn_node( } const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]); - const int64_t *dim = Halide::Internal::as_const_int(dim_expr); + auto dim = Halide::Internal::as_const_int(dim_expr); if (!dim) { throw std::domain_error("Unknown number of timesteps"); } @@ -3925,7 +3925,7 @@ Node convert_lstm_node( throw 
std::domain_error("Invalid rank"); } const Halide::Expr dim_expr = Halide::Internal::simplify(inputs[0].shape[0]); - const int64_t *dim = Halide::Internal::as_const_int(dim_expr); + auto dim = Halide::Internal::as_const_int(dim_expr); if (!dim) { throw std::domain_error("Unknown number of timesteps"); } @@ -4722,7 +4722,7 @@ Model convert_model( throw std::domain_error("Invalid dimensions for output " + output.name()); } for (int i = 0; i < args.size(); ++i) { - const int64_t *dim_value = Halide::Internal::as_const_int(dims[i]); + auto dim_value = Halide::Internal::as_const_int(dims[i]); if (dim_value) { int dim = static_cast(*dim_value); f.set_estimate(args[i], 0, dim); @@ -4777,7 +4777,7 @@ static int64_t infer_dim_from_inputs( replacement.min, replacement.extent, result); } result = Halide::Internal::simplify(result); - const int64_t *actual_dim = Halide::Internal::as_const_int(result); + auto actual_dim = Halide::Internal::as_const_int(result); if (!actual_dim) { throw std::invalid_argument( "Couldn't statically infer one of the dimensions of output " + name); @@ -4812,7 +4812,7 @@ void compute_output_shapes( std::vector &output_shape = (*output_shapes)[name]; const int rank = t.shape.size(); for (int i = 0; i < rank; ++i) { - const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]); + auto dim = Halide::Internal::as_const_int(t.shape[i]); if (!dim) { output_shape.push_back( infer_dim_from_inputs(t.shape[i], replacements, name)); @@ -4833,7 +4833,7 @@ void extract_expected_input_shapes( const Tensor &t = model.tensors.at(input_name); std::vector input_shape; for (int i = 0; i < t.shape.size(); ++i) { - const int64_t *dim = Halide::Internal::as_const_int(t.shape[i]); + auto dim = Halide::Internal::as_const_int(t.shape[i]); if (!dim) { // The dimension isn't fixed: use the estimated typical value instead if // one was provided. 
diff --git a/src/TargetQueryOps.cpp b/src/TargetQueryOps.cpp index 337d90c29b70..f76d2058f7e4 100644 --- a/src/TargetQueryOps.cpp +++ b/src/TargetQueryOps.cpp @@ -16,16 +16,16 @@ class LowerTargetQueryOps : public IRMutator { Expr visit(const Call *call) override { if (call->is_intrinsic(Call::target_arch_is)) { - Target::Arch arch = (Target::Arch)*as_const_int(call->args[0]); + Target::Arch arch = (Target::Arch)as_const_int(call->args[0]).value(); return make_bool(t.arch == arch); } else if (call->is_intrinsic(Call::target_has_feature)) { - Target::Feature feat = (Target::Feature)*as_const_int(call->args[0]); + Target::Feature feat = (Target::Feature)as_const_int(call->args[0]).value(); return make_bool(t.has_feature(feat)); } else if (call->is_intrinsic(Call::target_natural_vector_size)) { Expr zero = call->args[0]; return Expr(t.natural_vector_size(zero.type())); } else if (call->is_intrinsic(Call::target_os_is)) { - Target::OS os = (Target::OS)*as_const_int(call->args[0]); + Target::OS os = (Target::OS)as_const_int(call->args[0]).value(); return make_bool(t.os == os); } else if (call->is_intrinsic(Call::target_bits)) { return Expr(t.bits); From ba1009731215866dc393930d29b3471ea4f3cf9f Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 22 Nov 2024 08:44:13 -0800 Subject: [PATCH 5/5] return std::nullopt --- src/CodeGen_Internal.cpp | 12 ++++++------ src/IROperator.cpp | 18 +++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/CodeGen_Internal.cpp b/src/CodeGen_Internal.cpp index c46ee612477f..c2147ff62b53 100644 --- a/src/CodeGen_Internal.cpp +++ b/src/CodeGen_Internal.cpp @@ -557,15 +557,15 @@ Expr lower_round_to_nearest_ties_to_even(const Expr &x) { namespace { std::optional get_md_int(llvm::Metadata *value) { if (!value) { - return {}; + return std::nullopt; } llvm::ConstantAsMetadata *cam = llvm::cast(value); if (!cam) { - return {}; + return std::nullopt; } llvm::ConstantInt *c = llvm::cast(cam->getValue()); if (!c) { - return {}; + return std::nullopt; } return c->getSExtValue(); } @@ -574,19 +574,19 @@ std::optional get_md_bool(llvm::Metadata *value) { if (auto r = get_md_int(value)) { return *r != 0; } else { - return {}; + return std::nullopt; } } std::optional get_md_string(llvm::Metadata *value) { if (!value) { - return {}; + return std::nullopt; } llvm::MDString *c = llvm::dyn_cast(value); if (c) { return c->getString().str(); } - return {}; + return std::nullopt; } } // namespace diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 68f9aa2d2747..41ea946e10f4 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -181,43 +181,43 @@ bool is_pure(const Expr &e) { std::optional as_const_int(const Expr &e) { if (!e.defined()) { - return {}; + return std::nullopt; } else if (const Broadcast *b = e.as()) { return as_const_int(b->value); } else if (const IntImm *i = e.as()) { return i->value; } else { - return {}; + return std::nullopt; } } std::optional as_const_uint(const Expr &e) { if (!e.defined()) { - return {}; + return std::nullopt; } else if (const Broadcast *b = e.as()) { return as_const_uint(b->value); } else if (const UIntImm *i = e.as()) { return i->value; } else { - return {}; + return std::nullopt; } } std::optional as_const_float(const Expr &e) { if (!e.defined()) { - return {}; + return std::nullopt; } else if (const Broadcast *b = e.as()) { return as_const_float(b->value); } else if (const FloatImm *f = e.as()) { return f->value; } else { - return {}; + return std::nullopt; } } std::optional 
<int> is_const_power_of_two_integer(const Expr &e) {
     if (!(e.type().is_int() || e.type().is_uint())) {
-        return {};
+        return std::nullopt;
     }
 
     if (const Broadcast *b = e.as<Broadcast>()) {
@@ -229,7 +229,7 @@ std::optional<int> is_const_power_of_two_integer(const Expr &e) {
     } else if (auto u = as_const_uint(e)) {
         return is_const_power_of_two_integer(*u);
     } else {
-        return {};
+        return std::nullopt;
     }
 }
 
@@ -237,7 +237,7 @@ std::optional<int> is_const_power_of_two_integer(uint64_t val) {
     if (val && ((val & (val - 1)) == 0)) {
         return ctz64(val);
     } else {
-        return {};
+        return std::nullopt;
     }
 }
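For comparison with the CodeGen_C::visit(const Div *) hunk in the first patch, a usage sketch of the final is_const_power_of_two_integer() signature; the lowering helper itself is hypothetical and not part of this series:

    // Sketch only: strength-reduce a division in the style of the CodeGen_C hunk.
    #include "Halide.h"

    using namespace Halide;
    using namespace Halide::Internal;

    Expr lower_div_by_power_of_two(const Expr &a, const Expr &b) {
        if (auto bits = is_const_power_of_two_integer(b)) {
            // *bits is log2(b); the pre-series API returned it through an
            // int* out-parameter that every caller had to pre-declare.
            // Halide's integer division rounds toward negative infinity, so
            // an arithmetic right shift is an exact replacement here.
            return a >> make_const(a.type(), *bits);
        }
        return a / b;  // not a constant power of two: keep the division
    }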