From bcf4c99e30582c5b3988a4da7434bdf098cd2246 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sat, 11 Jun 2022 06:57:55 -0400 Subject: [PATCH] Revert "codegen: explicitly handle Float16 intrinsics (#45249)" This reverts commit f2c627ef8af37c3cf94c19a5403bc6cd796d5031. --- src/APInt-C.cpp | 6 +- src/julia.expmap | 6 + src/julia_internal.h | 14 +- src/llvm-demote-float16.cpp | 296 +++++++----------------------------- src/runtime_intrinsics.c | 64 ++++---- 5 files changed, 92 insertions(+), 294 deletions(-) diff --git a/src/APInt-C.cpp b/src/APInt-C.cpp index f06d4362bf958..bc0a62e21dd3e 100644 --- a/src/APInt-C.cpp +++ b/src/APInt-C.cpp @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) { void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) { double Val; if (numbits == 16) - Val = julia__gnu_h2f_ieee(*(uint16_t*)pa); + Val = __gnu_h2f_ieee(*(uint16_t*)pa); else if (numbits == 32) Val = *(float*)pa; else if (numbits == 64) @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(true); } if (onumbits == 16) - *(uint16_t*)pr = julia__gnu_f2h_ieee(val); + *(uint16_t*)pr = __gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(false); } if (onumbits == 16) - *(uint16_t*)pr = julia__gnu_f2h_ieee(val); + *(uint16_t*)pr = __gnu_f2h_ieee(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) diff --git a/src/julia.expmap b/src/julia.expmap index 6e373798102b2..13de1b873f7c3 100644 --- a/src/julia.expmap +++ b/src/julia.expmap @@ -37,6 +37,12 @@ environ; __progname; + /* compiler run-time intrinsics */ + __gnu_h2f_ieee; + __extendhfsf2; + __gnu_f2h_ieee; + __truncdfhf2; + local: *; }; diff --git a/src/julia_internal.h b/src/julia_internal.h index 072c1141de653..6055bbfd8a922 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1544,18 +1544,8 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; #define JL_GC_ASSERT_LIVE(x) (void)(x) #endif -JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; -JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; -JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; -//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT; +float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; +uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; #ifdef __cplusplus } diff --git a/src/llvm-demote-float16.cpp b/src/llvm-demote-float16.cpp index 7701bb26508a8..054ec46162160 100644 --- a/src/llvm-demote-float16.cpp +++ b/src/llvm-demote-float16.cpp @@ -45,194 +45,15 @@ INST_STATISTIC(FCmp); namespace { -inline AttributeSet getFnAttrs(const AttributeList &Attrs) -{ -#if JL_LLVM_VERSION >= 140000 - return Attrs.getFnAttrs(); -#else - return Attrs.getFnAttributes(); -#endif -} - -inline AttributeSet getRetAttrs(const AttributeList &Attrs) -{ -#if JL_LLVM_VERSION >= 140000 - return Attrs.getRetAttrs(); -#else - return Attrs.getRetAttributes(); -#endif -} - -static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef args) -{ - Intrinsic::ID ID = call->getIntrinsicID(); - assert(ID); - auto oldfType = call->getFunctionType(); - auto nargs = oldfType->getNumParams(); - assert(args.size() > nargs); - SmallVector argTys(nargs); - for (unsigned i = 0; i < nargs; i++) - argTys[i] = args[i]->getType(); - auto newfType = FunctionType::get(RetTy, argTys, oldfType->isVarArg()); - - // Accumulate an array of overloaded types for the given intrinsic - // and compute the new name mangling schema - SmallVector overloadTys; - { - SmallVector Table; - getIntrinsicInfoTableEntries(ID, Table); - ArrayRef TableRef = Table; - auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys); - assert(res == Intrinsic::MatchIntrinsicTypes_Match); - (void)res; - bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef); - assert(matchvararg); - (void)matchvararg; - } - auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys); - assert(newF->getFunctionType() == newfType); - newF->setCallingConv(call->getCallingConv()); - assert(args.back() == call->getCalledFunction()); - auto newCall = CallInst::Create(newF, args.drop_back(), "", call); - newCall->setTailCallKind(call->getTailCallKind()); - auto old_attrs = call->getAttributes(); - newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs), - getRetAttrs(old_attrs), {})); // drop parameter attributes - return newCall; -} - - -static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder) -{ - Type *SrcTy = V->getType(); - Type *RetTy = DestTy; - if (auto *VC = dyn_cast(V)) { - // The input IR often has things of the form - // fcmp olt half %0, 0xH7C00 - // and we would like to avoid turning that constant into a call here - // if we can simply constant fold it to the new type. - VC = ConstantExpr::getCast(opcode, VC, DestTy, true); - if (VC) - return VC; - } - assert(SrcTy->isVectorTy() == DestTy->isVectorTy()); - if (SrcTy->isVectorTy()) { - unsigned NumElems = cast(SrcTy)->getNumElements(); - assert(cast(DestTy)->getNumElements() == NumElems && "Mismatched cast"); - Value *NewV = UndefValue::get(DestTy); - RetTy = RetTy->getScalarType(); - for (unsigned i = 0; i < NumElems; ++i) { - Value *I = builder.getInt32(i); - Value *Vi = builder.CreateExtractElement(V, I); - Vi = CreateFPCast(opcode, Vi, RetTy, builder); - NewV = builder.CreateInsertElement(NewV, Vi, I); - } - return NewV; - } - auto &M = *builder.GetInsertBlock()->getModule(); - auto &ctx = M.getContext(); - // Pick the Function to call in the Julia runtime - StringRef Name; - switch (opcode) { - case Instruction::FPExt: - // this is exact, so we only need one conversion - assert(SrcTy->isHalfTy()); - Name = "julia__gnu_h2f_ieee"; - RetTy = Type::getFloatTy(ctx); - break; - case Instruction::FPTrunc: - assert(DestTy->isHalfTy()); - if (SrcTy->isFloatTy()) - Name = "julia__gnu_f2h_ieee"; - else if (SrcTy->isDoubleTy()) - Name = "julia__truncdfhf2"; - break; - // All F16 fit exactly in Int32 (-65504 to 65504) - case Instruction::FPToSI: JL_FALLTHROUGH; - case Instruction::FPToUI: - assert(SrcTy->isHalfTy()); - Name = "julia__gnu_h2f_ieee"; - RetTy = Type::getFloatTy(ctx); - break; - case Instruction::SIToFP: JL_FALLTHROUGH; - case Instruction::UIToFP: - assert(DestTy->isHalfTy()); - Name = "julia__gnu_f2h_ieee"; - SrcTy = Type::getFloatTy(ctx); - break; - default: - errs() << Instruction::getOpcodeName(opcode) << ' '; - V->getType()->print(errs()); - errs() << " to "; - DestTy->print(errs()); - errs() << " is an "; - llvm_unreachable("invalid cast"); - } - if (Name.empty()) { - errs() << Instruction::getOpcodeName(opcode) << ' '; - V->getType()->print(errs()); - errs() << " to "; - DestTy->print(errs()); - errs() << " is an "; - llvm_unreachable("illegal cast"); - } - // Coerce the source to the required size and type - auto T_int16 = Type::getInt16Ty(ctx); - if (SrcTy->isHalfTy()) - SrcTy = T_int16; - if (opcode == Instruction::SIToFP) - V = builder.CreateSIToFP(V, SrcTy); - else if (opcode == Instruction::UIToFP) - V = builder.CreateUIToFP(V, SrcTy); - else - V = builder.CreateBitCast(V, SrcTy); - // Call our intrinsic - if (RetTy->isHalfTy()) - RetTy = T_int16; - auto FT = FunctionType::get(RetTy, {SrcTy}, false); - FunctionCallee F = M.getOrInsertFunction(Name, FT); - Value *I = builder.CreateCall(F, {V}); - // Coerce the result to the expected type - if (opcode == Instruction::FPToSI) - I = builder.CreateFPToSI(I, DestTy); - else if (opcode == Instruction::FPToUI) - I = builder.CreateFPToUI(I, DestTy); - else if (opcode == Instruction::FPExt) - I = builder.CreateFPCast(I, DestTy); - else - I = builder.CreateBitCast(I, DestTy); - return I; -} - static bool demoteFloat16(Function &F) { auto &ctx = F.getContext(); + auto T_float16 = Type::getHalfTy(ctx); auto T_float32 = Type::getFloatTy(ctx); SmallVector erase; for (auto &BB : F) { for (auto &I : BB) { - // extend Float16 operands to Float32 - bool Float16 = I.getType()->getScalarType()->isHalfTy(); - for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) { - Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) - Float16 = true; - } - if (!Float16) - continue; - - if (auto CI = dyn_cast(&I)) { - if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL) - ++TotalChanged; - IRBuilder<> builder(&I); - Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder); - I.replaceAllUsesWith(NewI); - erase.push_back(&I); - } - continue; - } - switch (I.getOpcode()) { case Instruction::FNeg: case Instruction::FAdd: @@ -243,9 +64,6 @@ static bool demoteFloat16(Function &F) case Instruction::FCmp: break; default: - if (auto intrinsic = dyn_cast(&I)) - if (intrinsic->getIntrinsicID()) - break; continue; } @@ -257,78 +75,72 @@ static bool demoteFloat16(Function &F) IRBuilder<> builder(&I); // extend Float16 operands to Float32 - // XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct? + bool OperandsChanged = false; SmallVector Operands(I.getNumOperands()); for (size_t i = 0; i < I.getNumOperands(); i++) { Value *Op = I.getOperand(i); - if (Op->getType()->getScalarType()->isHalfTy()) { + if (Op->getType() == T_float16) { ++TotalExt; - Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder); + Op = builder.CreateFPExt(Op, T_float32); + OperandsChanged = true; } Operands[i] = (Op); } // recreate the instruction if any operands changed, // truncating the result back to Float16 - Value *NewI; - ++TotalChanged; - switch (I.getOpcode()) { - case Instruction::FNeg: - assert(Operands.size() == 1); - ++FNegChanged; - NewI = builder.CreateFNeg(Operands[0]); - break; - case Instruction::FAdd: - assert(Operands.size() == 2); - ++FAddChanged; - NewI = builder.CreateFAdd(Operands[0], Operands[1]); - break; - case Instruction::FSub: - assert(Operands.size() == 2); - ++FSubChanged; - NewI = builder.CreateFSub(Operands[0], Operands[1]); - break; - case Instruction::FMul: - assert(Operands.size() == 2); - ++FMulChanged; - NewI = builder.CreateFMul(Operands[0], Operands[1]); - break; - case Instruction::FDiv: - assert(Operands.size() == 2); - ++FDivChanged; - NewI = builder.CreateFDiv(Operands[0], Operands[1]); - break; - case Instruction::FRem: - assert(Operands.size() == 2); - ++FRemChanged; - NewI = builder.CreateFRem(Operands[0], Operands[1]); - break; - case Instruction::FCmp: - assert(Operands.size() == 2); - ++FCmpChanged; - NewI = builder.CreateFCmp(cast(&I)->getPredicate(), - Operands[0], Operands[1]); - break; - default: - if (auto intrinsic = dyn_cast(&I)) { - // XXX: this is not correct in general - // some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.* - Type *RetTy = I.getType(); - if (RetTy->getScalarType()->isHalfTy()) - RetTy = RetTy->getWithNewType(T_float32); - NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands); + if (OperandsChanged) { + Value *NewI; + ++TotalChanged; + switch (I.getOpcode()) { + case Instruction::FNeg: + assert(Operands.size() == 1); + ++FNegChanged; + NewI = builder.CreateFNeg(Operands[0]); + break; + case Instruction::FAdd: + assert(Operands.size() == 2); + ++FAddChanged; + NewI = builder.CreateFAdd(Operands[0], Operands[1]); + break; + case Instruction::FSub: + assert(Operands.size() == 2); + ++FSubChanged; + NewI = builder.CreateFSub(Operands[0], Operands[1]); break; + case Instruction::FMul: + assert(Operands.size() == 2); + ++FMulChanged; + NewI = builder.CreateFMul(Operands[0], Operands[1]); + break; + case Instruction::FDiv: + assert(Operands.size() == 2); + ++FDivChanged; + NewI = builder.CreateFDiv(Operands[0], Operands[1]); + break; + case Instruction::FRem: + assert(Operands.size() == 2); + ++FRemChanged; + NewI = builder.CreateFRem(Operands[0], Operands[1]); + break; + case Instruction::FCmp: + assert(Operands.size() == 2); + ++FCmpChanged; + NewI = builder.CreateFCmp(cast(&I)->getPredicate(), + Operands[0], Operands[1]); + break; + default: + abort(); } - abort(); - } - cast(NewI)->copyMetadata(I); - cast(NewI)->copyFastMathFlags(&I); - if (NewI->getType() != I.getType()) { - ++TotalTrunc; - NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder); + cast(NewI)->copyMetadata(I); + cast(NewI)->copyFastMathFlags(&I); + if (NewI->getType() != I.getType()) { + ++TotalTrunc; + NewI = builder.CreateFPTrunc(NewI, I.getType()); + } + I.replaceAllUsesWith(NewI); + erase.push_back(&I); } - I.replaceAllUsesWith(NewI); - erase.push_back(&I); } } diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index ea912b61ac4c3..89c9449e55920 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -15,6 +15,9 @@ const unsigned int host_char_bit = 8; // float16 intrinsics +// TODO: use LLVM's compiler-rt on all platforms (Xcode already links compiler-rt) + +#if !defined(_OS_DARWIN_) static inline float half_to_float(uint16_t ival) JL_NOTSAFEPOINT { @@ -185,17 +188,22 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT return h; } -JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) +JL_DLLEXPORT float __gnu_h2f_ieee(uint16_t param) { return half_to_float(param); } -JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) +JL_DLLEXPORT float __extendhfsf2(uint16_t param) +{ + return half_to_float(param); +} + +JL_DLLEXPORT uint16_t __gnu_f2h_ieee(float param) { return float_to_half(param); } -JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) +JL_DLLEXPORT uint16_t __truncdfhf2(double param) { float res = (float)param; uint32_t resi; @@ -217,25 +225,7 @@ JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) return float_to_half(res); } -//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) { return (uint32_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) { return (uint64_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) { return julia__gnu_f2h_ieee((float)n); } -//HANDLE_LIBCALL(F16, F128, __extendhftf2) -//HANDLE_LIBCALL(F16, F80, __extendhfxf2) -//HANDLE_LIBCALL(F80, F16, __truncxfhf2) -//HANDLE_LIBCALL(F128, F16, __trunctfhf2) -//HANDLE_LIBCALL(PPCF128, F16, __trunctfhf2) -//HANDLE_LIBCALL(F16, I128, __fixhfti) -//HANDLE_LIBCALL(F16, I128, __fixunshfti) -//HANDLE_LIBCALL(I128, F16, __floattihf) -//HANDLE_LIBCALL(I128, F16, __floatuntihf) - +#endif // run time version of bitcast intrinsic JL_DLLEXPORT jl_value_t *jl_bitcast(jl_value_t *ty, jl_value_t *v) @@ -561,9 +551,9 @@ static inline unsigned select_by_size(unsigned sz) JL_NOTSAFEPOINT } #define fp_select(a, func) \ - sizeof(a) <= sizeof(float) ? func##f((float)a) : func(a) + sizeof(a) == sizeof(float) ? func##f((float)a) : func(a) #define fp_select2(a, b, func) \ - sizeof(a) <= sizeof(float) ? func##f(a, b) : func(a, b) + sizeof(a) == sizeof(float) ? func##f(a, b) : func(a, b) // fast-function generators // @@ -607,11 +597,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ { \ uint16_t a = *(uint16_t*)pa; \ - float A = julia__gnu_h2f_ieee(a); \ + float A = __gnu_h2f_ieee(a); \ if (osize == 16) { \ float R; \ OP(&R, A); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = __gnu_f2h_ieee(R); \ } else { \ OP((uint16_t*)pr, A); \ } \ @@ -635,11 +625,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pr) { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ + float A = __gnu_h2f_ieee(a); \ + float B = __gnu_h2f_ieee(b); \ runtime_nbits = 16; \ float R = OP(A, B); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = __gnu_f2h_ieee(R); \ } // float or integer inputs, bool output @@ -660,8 +650,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ + float A = __gnu_h2f_ieee(a); \ + float B = __gnu_h2f_ieee(b); \ runtime_nbits = 16; \ return OP(A, B); \ } @@ -701,12 +691,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc, uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ uint16_t c = *(uint16_t*)pc; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ - float C = julia__gnu_h2f_ieee(c); \ + float A = __gnu_h2f_ieee(a); \ + float B = __gnu_h2f_ieee(b); \ + float C = __gnu_h2f_ieee(c); \ runtime_nbits = 16; \ float R = OP(A, B, C); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = __gnu_f2h_ieee(R); \ } @@ -1328,7 +1318,7 @@ static inline int fpiseq##nbits(c_type a, c_type b) JL_NOTSAFEPOINT { \ fpiseq_n(float, 32) fpiseq_n(double, 64) #define fpiseq(a,b) \ - sizeof(a) <= sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b) + sizeof(a) == sizeof(float) ? fpiseq32(a, b) : fpiseq64(a, b) bool_fintrinsic(eq,eq_float) bool_fintrinsic(ne,ne_float) @@ -1377,7 +1367,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui) if (!(osize < 8 * sizeof(a))) \ jl_error("fptrunc: output bitsize must be < input bitsize"); \ else if (osize == 16) \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(a); \ + *(uint16_t*)pr = __gnu_f2h_ieee(a); \ else if (osize == 32) \ *(float*)pr = a; \ else if (osize == 64) \