Skip to content

Commit

Permalink
update LLVM & rework runOnUpgradeableConstantCasts
Browse files Browse the repository at this point in the history
- fix GetIdxsForTyFromOffset when called on unsized types
- rework runOnUpgradeableConstantCasts to be more versatile/general
- add a case in runOnImplicitGEP for implicit cast between
  clspvResourceOrLocal and GEP
- fix runOnAllocaNotAliasing to avoid lowering alloca that do not need
  to be lowered
- Update tests
- Some tests need more work and are marked FAIL

Ref google#1292
  • Loading branch information
rjodinchr committed Feb 1, 2024
1 parent 709972b commit 7ecb4cc
Show file tree
Hide file tree
Showing 53 changed files with 231 additions and 165 deletions.
2 changes: 1 addition & 1 deletion deps.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"subrepo" : "llvm/llvm-project",
"branch" : "main",
"subdir" : "third_party/llvm",
"commit" : "6ec350b4834689af5192a970dc959017f732a8d8"
"commit" : "c105848fd29d3b46eeb794bb6b10dad04f903b09"
},
{
"name" : "SPIRV-Headers",
Expand Down
31 changes: 22 additions & 9 deletions lib/BitcastUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,19 @@ bool IsUnsizedType(const DataLayout &DL, Type *Ty) {
// Interface types are often something like: { [ 0 x Ty ] }.
// SizeInBits returns zero for such types. Try to avoid it by go through the
// type as long as SizeInBits returns zero to get the real type size for it.
Type *reworkUnsizedType(const DataLayout &DL, Type *Ty) {
Type *reworkUnsizedType(const DataLayout &DL, Type *Ty, unsigned *steps) {
unsigned s = 0;
auto size = SizeInBits(DL, Ty);
auto Ele = GetEleType(Ty);
while (size == 0 && Ty != Ele) {
s++;
Ty = Ele;
Ele = GetEleType(Ty);
size = SizeInBits(DL, Ty);
}
if (steps != nullptr) {
*steps = s;
}
return Ty;
}

Expand Down Expand Up @@ -1259,20 +1264,24 @@ uint64_t GoThroughTypeAtOffset(const DataLayout &DataLayout,
return Offset;
}

bool IsClspvResourceOrLocal(Value *val) {
if (auto call = dyn_cast<CallInst>(val)) {
auto builtin_type =
clspv::Builtins::Lookup(call->getCalledFunction()).getType();
return builtin_type == clspv::Builtins::kClspvResource ||
builtin_type == clspv::Builtins::kClspvLocal;
}
return false;
}

SmallVector<Value *, 2>
GetIdxsForTyFromOffset(const DataLayout &DataLayout, IRBuilder<> &Builder,
Type *SrcTy, Type *DstTy, uint64_t CstVal, Value *DynVal,
size_t SmallerBitWidths, Value *Src) {
SmallVector<Value *, 2> Idxs;

assert(Src->getType()->isPointerTy());
bool clspv_resource = false;
if (auto call = dyn_cast<CallInst>(Src)) {
auto builtin_type =
clspv::Builtins::Lookup(call->getCalledFunction()).getType();
clspv_resource = builtin_type == clspv::Builtins::kClspvResource ||
builtin_type == clspv::Builtins::kClspvLocal;
}
bool clspv_resource = IsClspvResourceOrLocal(Src);

unsigned startIdx = 0;
if ((isa<GlobalVariable>(Src) || clspv_resource || isa<AllocaInst>(Src)) &&
Expand All @@ -1289,8 +1298,12 @@ GetIdxsForTyFromOffset(const DataLayout &DataLayout, IRBuilder<> &Builder,
DstTy = Builder.getInt8Ty();
}

SrcTy = reworkUnsizedType(DataLayout, SrcTy);
unsigned steps;
SrcTy = reworkUnsizedType(DataLayout, SrcTy, &steps);
DstTy = reworkUnsizedType(DataLayout, DstTy);
for (unsigned i = 1; i < steps; i++) {
Idxs.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));
}

if (SizeInBits(DataLayout, DstTy) >= SizeInBits(DataLayout, SrcTy) &&
DstTy != SrcTy) {
Expand Down
5 changes: 4 additions & 1 deletion lib/BitcastUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ using namespace llvm;

namespace BitcastUtils {

Type *reworkUnsizedType(const DataLayout &DL, Type *Ty);
Type *reworkUnsizedType(const DataLayout &DL, Type *Ty,
unsigned *steps = nullptr);

size_t SizeInBits(const DataLayout &DL, Type *Ty);
size_t SizeInBits(IRBuilder<> &builder, Type *Ty);
Expand Down Expand Up @@ -76,6 +77,8 @@ uint64_t GoThroughTypeAtOffset(const DataLayout &DataLayout,
IRBuilder<> &Builder, Type *Ty, Type *TargetTy,
uint64_t Offset, SmallVector<Value *, 2> *Idxs);

bool IsClspvResourceOrLocal(Value *val);

SmallVector<Value *, 2>
GetIdxsForTyFromOffset(const DataLayout &DataLayout, IRBuilder<> &Builder,
Type *SrcTy, Type *DstTy, uint64_t CstVal, Value *DynVal,
Expand Down
132 changes: 72 additions & 60 deletions lib/SimplifyPointerBitcastPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ clspv::SimplifyPointerBitcastPass::run(Module &M, ModuleAnalysisManager &) {

changed |= runOnTrivialBitcast(M);
changed |= runOnBitcastFromBitcast(M);
changed |= runOnAllocaNotAliasing(M);
changed |= runOnGEPFromGEP(M);
changed |= runOnImplicitGEP(M);
changed |= runOnUpgradeableConstantCasts(M);
changed |= runOnUnneededIndices(M);
changed |= runOnImplicitCasts(M);
changed |= runOnAllocaNotAliasing(M);
changed |= runOnPHIFromGEP(M);
}

Expand Down Expand Up @@ -389,6 +389,7 @@ bool clspv::SimplifyPointerBitcastPass::runOnImplicitGEP(Module &M) const {
SmallVector<ImplicitGEPAliasing> GEPAliasingList;
SmallVector<ImplicitGEPBeforeStore> GEPBeforeStoreList;
SmallVector<LoadInst *> GEPBeforeLoadList;
SmallVector<GetElementPtrInst *> GEPCastList;
for (auto &F : M) {
for (auto &BB : F) {
for (auto &I : BB) {
Expand Down Expand Up @@ -424,6 +425,10 @@ bool clspv::SimplifyPointerBitcastPass::runOnImplicitGEP(Module &M) const {
} else if (isa<LoadInst>(&I) && isa<GetElementPtrInst>(source) &&
SizeInBits(DL, dest_ty) < SizeInBits(DL, source_ty)) {
GEPBeforeLoadList.push_back(dyn_cast<LoadInst>(&I));
} else if (auto gep = dyn_cast<GetElementPtrInst>(&I)) {
if (IsClspvResourceOrLocal(gep->getPointerOperand())) {
GEPCastList.push_back(dyn_cast<GetElementPtrInst>(&I));
}
}
}
}
Expand Down Expand Up @@ -516,6 +521,33 @@ bool clspv::SimplifyPointerBitcastPass::runOnImplicitGEP(Module &M) const {
LoadInst->getPointerOperand()->dump(););
LLVM_DEBUG(dbgs() << "of: "; LoadInst->dump(););
LoadInst->setOperand(PointerOperandNum, gep);

if (initial_gep->getNumUses() == 0) {
initial_gep->eraseFromParent();
}

changed = true;
}

for (auto gep : GEPCastList) {
auto ptr = gep->getPointerOperand();
auto ty = InferType(ptr, M.getContext(), &type_cache);
IRBuilder<> Builder{gep};
uint64_t cstVal;
Value *dynVal;
size_t smallerBitWidths;
ExtractOffsetFromGEP(DL, Builder, gep, cstVal, dynVal,
smallerBitWidths);
auto new_gep_idxs =
GetIdxsForTyFromOffset(DL, Builder, ty,
reworkUnsizedType(DL, ty), cstVal,
dynVal, smallerBitWidths, ptr);
auto new_gep = GetElementPtrInst::Create(ty, ptr, new_gep_idxs, "", gep);
LLVM_DEBUG(dbgs() << "\n##runOnImplicitGEP (gep cast):\nreplacing: ";
gep->dump());
LLVM_DEBUG(dbgs() << "by: "; new_gep->dump(););
gep->replaceAllUsesWith(new_gep);
gep->eraseFromParent();
changed = true;
}

Expand Down Expand Up @@ -631,13 +663,12 @@ bool clspv::SimplifyPointerBitcastPass::runOnUpgradeableConstantCasts(
bool changed = false;
DenseMap<Value *, Type *> type_cache;

DenseSet<GetElementPtrInst *> seen;
struct UpgradeInfo {
GetElementPtrInst *gep;
Instruction *inst;
ConstantInt *constant;
Type *source_ty;
uint64_t cst;
size_t smallerBitWidth;
Type *dest_ty;
Value *ptr;
};
SmallVector<UpgradeInfo, 8> Worklist;
for (auto &F : M) {
Expand All @@ -652,74 +683,55 @@ bool clspv::SimplifyPointerBitcastPass::runOnUpgradeableConstantCasts(
}

if (auto *gep = dyn_cast<GetElementPtrInst>(source)) {
if (!seen.insert(gep).second) {
if (SizeInBits(DL, source_ty) >= SizeInBits(DL, dest_ty) ||
IsClspvResourceOrLocal(gep->getPointerOperand())) {
continue;
}
auto isIntegerOrFloatingPointTy = [](Type *Ty) {
return Ty->isIntegerTy() || Ty->isFloatingPointTy();
};
if (!isIntegerOrFloatingPointTy(source_ty) ||
!isIntegerOrFloatingPointTy(dest_ty)) {
if (!gep->hasAllConstantIndices()) {
continue;
}

// For some reason, with opaque pointer, LLVM tends to transform
// memcpy/memset into a series of gep and load/store. But while the
// load/store are on i32 for example, it keeps the gep on i8 but
// with index multiples of sizeof(i32). To avoid such bitcast which
// leads to trying to store an i8 into a i32 element (which is not
// supported), upgrade those gep into gep on i32 with the
// appropriate indexes.
SmallVector<Value *, 2> Indices(gep->indices());
if (Indices.size() == 1) {
if (auto cst = dyn_cast<ConstantInt>(Indices[0])) {
Worklist.push_back({gep, &I, cst, source_ty, dest_ty});
}
// should not be used as all indices are constant
IRBuilder<> Builder{gep};

uint64_t cstVal;
Value *dynVal;
size_t smallerBitWidths;
ExtractOffsetFromGEP(DL, Builder, gep, cstVal, dynVal,
smallerBitWidths);
assert(dynVal == nullptr);
if (((cstVal * smallerBitWidths) % SizeInBits(DL, dest_ty)) != 0) {
continue;
}

Worklist.push_back({&I, cstVal, smallerBitWidths, dest_ty,
gep->getPointerOperand()});
}
}
}
}

for (auto GEPInfo : Worklist) {
auto *GEP = GEPInfo.gep;
Instruction *I = GEPInfo.inst;
ConstantInt *cst = GEPInfo.constant;
Type *source_ty = GEPInfo.source_ty;
uint64_t cst = GEPInfo.cst;
size_t smallerBitWidths = GEPInfo.smallerBitWidth;
Type *dest_ty = GEPInfo.dest_ty;
auto source_ty_size = SizeInBits(DL, source_ty);
auto dest_ty_size = SizeInBits(DL, dest_ty);
auto value = cst->getZExtValue();
unsigned new_source_ty_size = source_ty_size;
while (dest_ty_size > source_ty_size &&
dest_ty_size % source_ty_size == 0 && value > 0 && value % 2 == 0 &&
new_source_ty_size < 32) {
value /= 2;
new_source_ty_size *= 2;
}
if (source_ty_size != new_source_ty_size) {
SmallVector<Value *, 2> Indices;
Indices.clear();
Indices.push_back(
ConstantInt::get(Type::getInt32Ty(M.getContext()), value));
auto new_type = Type::getIntNTy(M.getContext(), new_source_ty_size);
auto new_gep = GetElementPtrInst::Create(
new_type, GEP->getPointerOperand(), Indices, "", I);

unsigned PointerOperandNum = BitcastUtils::PointerOperandNum(I);

LLVM_DEBUG(
dbgs() << "\n##runOnUpgradeableConstantCasts:\nreplace operand "
<< PointerOperandNum << " of: ";
I->dump(); dbgs() << "by: "; new_gep->dump());
I->setOperand(PointerOperandNum, new_gep);

if (GEP->getNumUses() == 0) {
GEP->eraseFromParent();
}
Value *ptr = GEPInfo.ptr;
IRBuilder Builder{I};

changed = true;
}
auto NewGEPIdxs =
GetIdxsForTyFromOffset(M.getDataLayout(), Builder, dest_ty, dest_ty,
cst, nullptr, smallerBitWidths, ptr);

auto new_gep = GetElementPtrInst::Create(dest_ty, ptr, NewGEPIdxs, "", I);

unsigned PointerOperandNum = BitcastUtils::PointerOperandNum(I);

LLVM_DEBUG(dbgs() << "\n##runOnUpgradeableConstantCasts:\nreplace operand "
<< PointerOperandNum << " of: ";
I->dump(); dbgs() << "by: "; new_gep->dump());
I->setOperand(PointerOperandNum, new_gep);

changed = true;
}

return changed;
Expand Down Expand Up @@ -907,7 +919,7 @@ bool clspv::SimplifyPointerBitcastPass::runOnAllocaNotAliasing(

auto alloca = dyn_cast<AllocaInst>(source);
auto gep = dyn_cast<GetElementPtrInst>(&I);
if (!alloca || !gep) {
if (!alloca || !gep || gep->getNumUses() == 0) {
continue;
}
int Steps;
Expand Down
2 changes: 1 addition & 1 deletion test/CPlusPlus/object-and-overload.cl
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@
// CHECK-64-DAG: %[[ulong_1:[0-9a-zA-Z_]+]] = OpConstant %[[ulong]] 1
// CHECK-DAG: %[[__original_id_27:[0-9]+]] = OpVariable %[[_ptr_StorageBuffer__struct_7]] StorageBuffer
// CHECK-DAG: %[[__original_id_1:[0-9]+]] = OpVariable %[[_ptr_Workgroup__arr_uint_2]] Workgroup
// CHECK: %[[__original_id_30:[0-9]+]] = OpAccessChain %[[_ptr_Workgroup_uint]] %[[__original_id_1]] %[[uint_0]]
// CHECK: %[[__original_id_31:[0-9]+]] = OpAccessChain %[[_ptr_StorageBuffer_uint]] %[[__original_id_27]] %[[uint_0]] %[[uint_0]]
// CHECK: OpStore %[[__original_id_31]] %[[uint_0]]
// CHECK: %[[__original_id_32:[0-9]+]] = OpAccessChain %[[_ptr_StorageBuffer_uint]] %[[__original_id_27]] %[[uint_0]] %[[uint_1]]
// CHECK: OpStore %[[__original_id_32]] %[[uint_46]]
// CHECK: %[[__original_id_33:[0-9]+]] = OpAccessChain %[[_ptr_StorageBuffer_uint]] %[[__original_id_27]] %[[uint_0]] %[[uint_2]]
// CHECK: OpStore %[[__original_id_33]] %[[uint_92]]
// CHECK: %[[__original_id_30:[0-9]+]] = OpAccessChain %[[_ptr_Workgroup_uint]] %[[__original_id_1]] %[[uint_0]]
// CHECK: OpStore %[[__original_id_30]] %[[uint_25]]
// CHECK-64: %[[__original_id_34:[0-9]+]] = OpAccessChain %[[_ptr_Workgroup_uint]] %[[__original_id_1]] %[[ulong_1]]
// CHECK-32: %[[__original_id_34:[0-9]+]] = OpAccessChain %[[_ptr_Workgroup_uint]] %[[__original_id_1]] %[[uint_1]]
Expand Down
3 changes: 3 additions & 0 deletions test/Coherent/coherent_multiple_subfunctions.cl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: FileCheck %s < %t.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// TODO(#1292)
// XFAIL: *

__attribute__((noinline))
int baz(global int* x) { return x[0]; }

Expand Down
3 changes: 3 additions & 0 deletions test/Coherent/coherent_subfunction_parameter.cl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: FileCheck %s < %t.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// TODO(#1292)
// XFAIL: *

__attribute__((noinline))
int bar(global int* x) { return x[0]; }

Expand Down
3 changes: 3 additions & 0 deletions test/Coherent/parameter_one_use_is_coherent.cl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: FileCheck %s < %t.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// TODO(#1292)
// XFAIL: *

__attribute__((noinline))
int bar(global int* x) { return x[0]; }

Expand Down
3 changes: 3 additions & 0 deletions test/Coherent/selection.cl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: FileCheck %s < %t.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// TODO(#1292)
// XFAIL: *

// Both x's should be coherent. y should not be coherent because it is not read.
__attribute__((noinline))
void bar(global int* x, int y) { *x = y; }
Expand Down
3 changes: 3 additions & 0 deletions test/DirectResourceAccess/partial_access_chain_global.cl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
// RUN: FileCheck %s < %t2.spvasm
// RUN: spirv-val --target-env vulkan1.0 %t.spv

// TODO(#1292)
// XFAIL: *

// Kernel |bar| does a non-trivial access chain before calling the helper.

__attribute__((noinline))
Expand Down
2 changes: 1 addition & 1 deletion test/HalfStorage/clspv_vloada_half2_global.cl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ kernel void foo(global float2* A, global uint* B, uint n) {
// CHECK-64-DAG: [[_ulong:%[0-9a-zA-Z_]+]] = OpTypeInt 64 0
// CHECK-DAG: [[_uint_0:%[0-9a-zA-Z_]+]] = OpConstant [[_uint]] 0
// CHECK-DAG: [[_uint_1:%[0-9a-zA-Z_]+]] = OpConstant [[_uint]] 1
// CHECK: [[_31:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[A:%[0-9a-zA-Z_]+]] [[_uint_0]] [[_uint_0]]
// CHECK: [[_32:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B:%[0-9a-zA-Z_]+]] [[_uint_0]] [[_uint_0]]
// CHECK: [[_34:%[0-9a-zA-Z_]+]] = OpCompositeExtract [[_uint]]
// CHECK-64: [[_offset_long:%[0-9a-zA-Z_]+]] = OpUConvert [[_ulong]] [[_34]]
// CHECK-64: [[_35:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B]] [[_uint_0]] [[_offset_long]]
// CHECK-32: [[_35:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B]] [[_uint_0]] [[_34]]
// CHECK: [[_36:%[0-9a-zA-Z_]+]] = OpLoad [[_uint]] [[_35]]
// CHECK: [[_37:%[0-9a-zA-Z_]+]] = OpExtInst [[_v2float]] {{.*}} UnpackHalf2x16 [[_36]]
// CHECK: [[_31:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[A:%[0-9a-zA-Z_]+]] [[_uint_0]] [[_uint_0]]
// CHECK: OpStore [[_31]] [[_37]]
// CHECK: [[_38:%[0-9a-zA-Z_]+]] = OpLoad [[_uint]] [[_32]]
// CHECK: [[_39:%[0-9a-zA-Z_]+]] = OpExtInst [[_v2float]] {{.*}} UnpackHalf2x16 [[_38]]
Expand Down
2 changes: 1 addition & 1 deletion test/HalfStorage/clspv_vloada_half2_local.cl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ kernel void foo(global float2* A, local uint* B, uint n) {
// CHECK-DAG: [[_uint_0:%[0-9a-zA-Z_]+]] = OpConstant [[_uint]] 0
// CHECK-DAG: [[_uint_1:%[0-9a-zA-Z_]+]] = OpConstant [[_uint]] 1
// CHECK: [[_5:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B:%[0-9a-zA-Z_]+]] [[_uint_0]]
// CHECK: [[_33:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[A:%[0-9a-zA-Z_]+]] [[_uint_0]] [[_uint_0]]
// CHECK: [[_35:%[0-9a-zA-Z_]+]] = OpCompositeExtract [[_uint]]
// CHECK-64: [[_offset_long:%[0-9a-zA-Z_]+]] = OpUConvert [[_ulong]] [[_35]]
// CHECK-64: [[_36:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B]] [[_offset_long]]
// CHECK-32: [[_36:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[B]] [[_35]]
// CHECK: [[_37:%[0-9a-zA-Z_]+]] = OpLoad [[_uint]] [[_36]]
// CHECK: [[_38:%[0-9a-zA-Z_]+]] = OpExtInst [[_v2float]] {{.*}} UnpackHalf2x16 [[_37]]
// CHECK: [[_33:%[0-9a-zA-Z_]+]] = OpAccessChain {{.*}} [[A:%[0-9a-zA-Z_]+]] [[_uint_0]] [[_uint_0]]
// CHECK: OpStore [[_33]] [[_38]]
// CHECK: [[_39:%[0-9a-zA-Z_]+]] = OpLoad [[_uint]] [[_5]]
// CHECK: [[_40:%[0-9a-zA-Z_]+]] = OpExtInst [[_v2float]] {{.*}} UnpackHalf2x16 [[_39]]
Expand Down
Loading

0 comments on commit 7ecb4cc

Please sign in to comment.