diff --git a/src/hir_conv/constant_evaluation.cpp b/src/hir_conv/constant_evaluation.cpp index fd2a1ce3..9c5674c9 100644 --- a/src/hir_conv/constant_evaluation.cpp +++ b/src/hir_conv/constant_evaluation.cpp @@ -1215,7 +1215,7 @@ namespace MIR { namespace eval { return val; } - const EncodedLiteral& get_const(const ::HIR::Path& in_p, ::HIR::TypeRef* out_ty) + const EncodedLiteral& get_const(const ::HIR::Path& in_p, ::HIR::TypeRef* out_ty) const { auto p = ms.monomorph_path(state.sp, in_p); state.m_resolve.expand_associated_types_path(state.sp, p); @@ -1373,11 +1373,11 @@ namespace MIR { namespace eval { } /// Read a floating point value from a MIR::Param - double read_param_float(unsigned bits, const ::MIR::Param& p) + double read_param_float(unsigned bits, const ::MIR::Param& p) const { - TU_MATCH_HDRA( (p), { ) - TU_ARMA(LValue, e) - return this->get_lval(e).read_float(state, bits); + TU_MATCH_HDRA( (p), {) + TU_ARMA(LValue, e) + return const_cast(this)->get_lval(e).read_float(state, bits); TU_ARMA(Borrow, e) MIR_BUG(state, "Expected a float, got a MIR::Param::Borrow"); TU_ARMA(Constant, e) { @@ -1388,16 +1388,16 @@ namespace MIR { namespace eval { } MIR_ASSERT(state, e.is_Float(), "Expected a float, got " << e); return e.as_Float().v; - } + } } abort(); } - U128 read_param_uint(unsigned bits, const ::MIR::Param& p) + U128 read_param_uint(unsigned bits, const ::MIR::Param& p) const { TU_MATCH_HDRA( (p), { ) - TU_ARMA(LValue, e) - return this->get_lval(e).read_uint(state, bits); + TU_ARMA(LValue, e) + return const_cast(this)->get_lval(e).read_uint(state, bits); TU_ARMA(Borrow, e) MIR_BUG(state, "Expected an integer, got a MIR::Param::Borrow"); TU_ARMA(Constant, e) { @@ -1414,15 +1414,15 @@ namespace MIR { namespace eval { return U128( e.as_Bool().v ? 1 : 0 ); MIR_ASSERT(state, e.is_Uint(), "Expected an integer, got " << e.tag_str() << " " << e); return U128( e.as_Uint().v ); - } + } } abort(); } - S128 read_param_sint(unsigned bits, const ::MIR::Param& p) + S128 read_param_sint(unsigned bits, const ::MIR::Param& p) const { TU_MATCH_HDRA( (p), { ) - TU_ARMA(LValue, e) - return this->get_lval(e).read_sint(state, bits); + TU_ARMA(LValue, e) + return const_cast(this)->get_lval(e).read_sint(state, bits); TU_ARMA(Borrow, e) MIR_BUG(state, "Expected an integer, got a MIR::Param::Borrow"); TU_ARMA(Constant, e) { @@ -1435,7 +1435,7 @@ namespace MIR { namespace eval { throw Defer(); MIR_ASSERT(state, e.is_Int(), "Expected an integer, got " << e.tag_str() << " " << e); return S128( e.as_Int().v ); - } + } } abort(); } @@ -1451,6 +1451,192 @@ namespace MIR { namespace eval { } } // namespace ::MIR::eval +namespace { + bool do_arith_checked( + const ::MIR::eval::CallStackEntry& local_state, + const HIR::TypeRef& ty, + ::MIR::eval::ValueRef& dst, + const ::MIR::Param& val_l, + ::MIR::eBinOp op, + const ::MIR::Param& val_r + ) + { + bool did_overflow = false; + const auto& state = local_state.state; + auto ti = TypeInfo::for_type(ty); + switch(ti.ty) + { + case TypeInfo::Float: { + auto l = local_state.read_param_float(ti.bits, val_l); + auto r = local_state.read_param_float(ti.bits, val_r); + switch(op) + { + case ::MIR::eBinOp::ADD: dst.write_float(state, ti.bits, l + r); break; + case ::MIR::eBinOp::SUB: dst.write_float(state, ti.bits, l - r); break; + case ::MIR::eBinOp::MUL: dst.write_float(state, ti.bits, l * r); break; + case ::MIR::eBinOp::DIV: dst.write_float(state, ti.bits, l / r); break; + case ::MIR::eBinOp::MOD: + case ::MIR::eBinOp::ADD_OV: + case ::MIR::eBinOp::SUB_OV: + case ::MIR::eBinOp::MUL_OV: + case ::MIR::eBinOp::DIV_OV: + MIR_TODO(state, "do_arith float unimplemented - val = " << l << " , " << r); + + case ::MIR::eBinOp::BIT_OR : + case ::MIR::eBinOp::BIT_AND: + case ::MIR::eBinOp::BIT_XOR: + MIR_BUG(state, "do_arith float with bitwise - val = " << l << " , " << r); + case ::MIR::eBinOp::BIT_SHL: + case ::MIR::eBinOp::BIT_SHR: + MIR_BUG(state, "Bitshifts should be handled in caller"); + case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; + case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; + case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; + case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; + case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; + case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; + } + break; }; + case TypeInfo::Unsigned: { + auto l = local_state.read_param_uint(ti.bits, val_l); + auto r = local_state.read_param_uint(ti.bits, val_r); + switch(op) + { + case ::MIR::eBinOp::ADD: { + auto res = ti.mask(l + r); + did_overflow = res < l; + dst.write_uint(state, ti.bits, res); + break; } + case ::MIR::eBinOp::SUB: { + auto res = ti.mask(l - r); + did_overflow = res > l; + dst.write_uint(state, ti.bits, res); + break; } + case ::MIR::eBinOp::MUL: { + auto res = ti.mask(l * r); + if( l != 0 && r != 0 ) { + did_overflow = res < l || res < r; + } + dst.write_uint(state, ti.bits, res); + break; } + case ::MIR::eBinOp::DIV: + // Early-prevent division by zero + if( r == 0 ) { + dst.write_uint(state, ti.bits, U128(0)); + return true; + } + dst.write_uint(state, ti.bits, ti.mask(l / r)); + break; + case ::MIR::eBinOp::MOD: + // Early-prevent division by zero + if( r == 0 ) { + dst.write_uint(state, ti.bits, U128(0)); + return true; + } + dst.write_uint(state, ti.bits, ti.mask(l % r)); + break; + case ::MIR::eBinOp::ADD_OV: + case ::MIR::eBinOp::SUB_OV: + case ::MIR::eBinOp::MUL_OV: + case ::MIR::eBinOp::DIV_OV: + MIR_TODO(state, "do_arith unsigned - val = " << l << " , " << r); + + case ::MIR::eBinOp::BIT_OR : dst.write_uint(state, ti.bits, l | r); break; + case ::MIR::eBinOp::BIT_AND: dst.write_uint(state, ti.bits, l & r); break; + case ::MIR::eBinOp::BIT_XOR: dst.write_uint(state, ti.bits, l ^ r); break; + case ::MIR::eBinOp::BIT_SHL: + case ::MIR::eBinOp::BIT_SHR: + MIR_BUG(state, "Bitshifts should be handled in caller"); + + case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; + case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; + case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; + case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; + case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; + case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; + } + break; } + case TypeInfo::Signed: { + auto l = local_state.read_param_sint(ti.bits, val_l); + auto r = local_state.read_param_sint(ti.bits, val_r); + switch(op) + { + case ::MIR::eBinOp::ADD: { + // Convert to raw/unsigned repr + auto v1u = l.get_inner(); + auto v2u = r.get_inner(); + // Then convert into a sign and absolute value + auto v1s = (l < 0); + auto v2s = (r < 0); + auto v1a = v1s ? ~v1u + 1 : v1u; + auto v2a = v2s ? ~v2u + 1 : v2u; + + // Determine the sign + // - Equal has the same sign + // - V2 negative is negative if |v2| > |v1| + // - V1 negative is negative if |v2| < |v1| + bool res_sign = (v1s == v2s) ? v1s : (v2s ? v1a < v2a : v1a > v2a); + auto res = S128(v1u + v2u); + did_overflow = ((res < 0) != res_sign); + dst.write_sint(state, ti.bits, res); + break; } + case ::MIR::eBinOp::SUB: { + auto res = l - r; + // If the masked value isn't equal to the non-masked, then it's an overflow. + did_overflow = res.get_inner() != ti.mask(res); + dst.write_uint( state, ti.bits, ti.mask(res) ); + break; } + case ::MIR::eBinOp::MUL: { + auto res = l * r; + if( l != 0 && r != 0 ) { + if( res.u_abs() < l.u_abs() || res.u_abs() < r.u_abs() ) { + did_overflow = true; + } + } + dst.write_uint( state, ti.bits, ti.mask(res) ); + break; } + case ::MIR::eBinOp::DIV: + if( r == 0 ) { + dst.write_uint(state, ti.bits, U128(0)); + return true; + } + dst.write_sint(state, ti.bits, ti.mask(l / r)); + break; + case ::MIR::eBinOp::MOD: + if( r == 0 ) { + dst.write_uint(state, ti.bits, U128(0)); + return true; + } + dst.write_sint(state, ti.bits, ti.mask(l % r)); + break; + case ::MIR::eBinOp::ADD_OV: + case ::MIR::eBinOp::SUB_OV: + case ::MIR::eBinOp::MUL_OV: + case ::MIR::eBinOp::DIV_OV: + MIR_TODO(state, "do_arith signed - val = " << l << " , " << r); + + case ::MIR::eBinOp::BIT_OR : dst.write_uint( state, ti.bits, (l | r).get_inner() ); break; + case ::MIR::eBinOp::BIT_AND: dst.write_uint( state, ti.bits, (l & r).get_inner() ); break; + case ::MIR::eBinOp::BIT_XOR: dst.write_uint( state, ti.bits, (l ^ r).get_inner() ); break; + case ::MIR::eBinOp::BIT_SHL: + case ::MIR::eBinOp::BIT_SHR: + MIR_BUG(state, "Bitshifts should be handled in caller"); + + case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; + case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; + case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; + case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; + case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; + case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; + } + break; } + case TypeInfo::Other: + MIR_BUG(state, "BinOp on " << ty); + } + return did_overflow; + } +} + namespace HIR { using namespace ::MIR::eval; @@ -1681,100 +1867,17 @@ namespace HIR { // Skip the rest of this arm (breaks both loops in `TU_ARMA`) break ; } - switch(ti.ty) + bool did_overflow = do_arith_checked(local_state, ty_l, dst, e.val_l, e.op, e.val_r); + switch(e.op) { - case TypeInfo::Float: { - auto l = local_state.read_param_float(ti.bits, e.val_l); - auto r = local_state.read_param_float(ti.bits, e.val_r); - switch(e.op) - { - case ::MIR::eBinOp::ADD: dst.write_float(state, ti.bits, l + r); break; - case ::MIR::eBinOp::SUB: dst.write_float(state, ti.bits, l - r); break; - case ::MIR::eBinOp::MUL: dst.write_float(state, ti.bits, l * r); break; - case ::MIR::eBinOp::DIV: dst.write_float(state, ti.bits, l / r); break; - case ::MIR::eBinOp::MOD: - case ::MIR::eBinOp::ADD_OV: - case ::MIR::eBinOp::SUB_OV: - case ::MIR::eBinOp::MUL_OV: - case ::MIR::eBinOp::DIV_OV: - MIR_TODO(state, "RValue::BinOp - " << sa.src << ", val = " << l << " , " << r); - - case ::MIR::eBinOp::BIT_OR : - case ::MIR::eBinOp::BIT_AND: - case ::MIR::eBinOp::BIT_XOR: - case ::MIR::eBinOp::BIT_SHL: - case ::MIR::eBinOp::BIT_SHR: - MIR_TODO(state, "RValue::BinOp - " << sa.src << ", val = " << l << " , " << r); - case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; - case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; - case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; - case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; - case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; - case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; - } - break; }; - case TypeInfo::Unsigned: { - auto l = local_state.read_param_uint(ti.bits, e.val_l); - auto r = local_state.read_param_uint(ti.bits, e.val_r); - switch(e.op) - { - case ::MIR::eBinOp::ADD: dst.write_uint(state, ti.bits, ti.mask(l + r)); break; - case ::MIR::eBinOp::SUB: dst.write_uint(state, ti.bits, ti.mask(l - r)); break; - case ::MIR::eBinOp::MUL: dst.write_uint(state, ti.bits, ti.mask(l * r)); break; - case ::MIR::eBinOp::DIV: dst.write_uint(state, ti.bits, ti.mask(l / r)); break; - case ::MIR::eBinOp::MOD: dst.write_uint(state, ti.bits, ti.mask(l % r)); break; - case ::MIR::eBinOp::ADD_OV: - case ::MIR::eBinOp::SUB_OV: - case ::MIR::eBinOp::MUL_OV: - case ::MIR::eBinOp::DIV_OV: - MIR_TODO(state, "RValue::BinOp - " << sa.src << ", val = " << l << " , " << r); - - case ::MIR::eBinOp::BIT_OR : dst.write_uint(state, ti.bits, l | r); break; - case ::MIR::eBinOp::BIT_AND: dst.write_uint(state, ti.bits, l & r); break; - case ::MIR::eBinOp::BIT_XOR: dst.write_uint(state, ti.bits, l ^ r); break; - case ::MIR::eBinOp::BIT_SHL: dst.write_uint(state, ti.bits, ti.mask(l << r.truncate_u64())); break; - case ::MIR::eBinOp::BIT_SHR: dst.write_uint(state, ti.bits, ti.mask(l >> r.truncate_u64())); break; - - case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; - case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; - case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; - case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; - case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; - case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; - } - break; } - case TypeInfo::Signed: { - auto l = local_state.read_param_sint(ti.bits, e.val_l); - auto r = local_state.read_param_sint(ti.bits, e.val_r); - switch(e.op) - { - case ::MIR::eBinOp::ADD: dst.write_uint( state, ti.bits, ti.mask(l + r) ); break; - case ::MIR::eBinOp::SUB: dst.write_uint( state, ti.bits, ti.mask(l - r) ); break; - case ::MIR::eBinOp::MUL: dst.write_uint( state, ti.bits, ti.mask(l * r) ); break; - case ::MIR::eBinOp::DIV: dst.write_uint( state, ti.bits, ti.mask(l / r) ); break; - case ::MIR::eBinOp::MOD: dst.write_uint( state, ti.bits, ti.mask(l % r) ); break; - case ::MIR::eBinOp::ADD_OV: - case ::MIR::eBinOp::SUB_OV: - case ::MIR::eBinOp::MUL_OV: - case ::MIR::eBinOp::DIV_OV: - MIR_TODO(state, "RValue::BinOp - " << sa.src << ", val = " << l << " , " << r); - - case ::MIR::eBinOp::BIT_OR : dst.write_uint( state, ti.bits, (l | r).get_inner() ); break; - case ::MIR::eBinOp::BIT_AND: dst.write_uint( state, ti.bits, (l & r).get_inner() ); break; - case ::MIR::eBinOp::BIT_XOR: dst.write_uint( state, ti.bits, (l ^ r).get_inner() ); break; - case ::MIR::eBinOp::BIT_SHL: dst.write_uint( state, ti.bits, ti.mask(l << static_cast(r.get_inner().truncate_u64())) ); break; - case ::MIR::eBinOp::BIT_SHR: dst.write_uint( state, ti.bits, ti.mask(l >> static_cast(r.get_inner().truncate_u64())) ); break; - - case ::MIR::eBinOp::EQ: dst.write_byte(state, l == r); break; - case ::MIR::eBinOp::NE: dst.write_byte(state, l != r); break; - case ::MIR::eBinOp::GT: dst.write_byte(state, l > r); break; - case ::MIR::eBinOp::GE: dst.write_byte(state, l >= r); break; - case ::MIR::eBinOp::LT: dst.write_byte(state, l < r); break; - case ::MIR::eBinOp::LE: dst.write_byte(state, l <= r); break; + case ::MIR::eBinOp::DIV: + case ::MIR::eBinOp::MOD: + if(did_overflow) { + MIR_BUG(state, "Division/modulo by zero!"); } - break; } - case TypeInfo::Other: - MIR_BUG(state, "BinOp on " << ty_l); + break; + default: + break; } } TU_ARMA(UniOp, e) { @@ -2202,44 +2305,48 @@ namespace HIR { // --- else if( te->name == "add_with_overflow" ) { auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); - MIR_ASSERT(state, ty.data().is_Primitive(), "`add_with_overflow` with non-primitive " << ty); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); auto ti = TypeInfo::for_type(ty); - switch(ti.ty) - { - case TypeInfo::Unsigned: { - auto v1 = local_state.read_param_uint(ti.bits, e.args.at(0)); - auto v2 = local_state.read_param_uint(ti.bits, e.args.at(1)); - auto res = ti.mask(v1 + v2); - bool overflowed = res < v1; - dst.write_uint(state, ti.bits, res); - dst.slice(ti.bits / 8).write_uint(state, 8, U128(overflowed ? 1 : 0)); - } break; - case TypeInfo::Signed: { - auto v1r = local_state.read_param_sint(ti.bits, e.args.at(0)); - auto v2r = local_state.read_param_sint(ti.bits, e.args.at(1)); - // Convert to raw/unsigned repr - auto v1u = v1r.get_inner(); - auto v2u = v2r.get_inner(); - // Then convert into a sign and absolute value - auto v1s = (v1r < 0); - auto v2s = (v2r < 0); - auto v1a = v1s ? ~v1u + 1 : v1u; - auto v2a = v2s ? ~v2u + 1 : v2u; - - // Determine the sign - // - Equal has the same sign - // - V2 negative is negative if |v2| > |v1| - // - V1 negative is negative if |v2| < |v1| - bool res_sign = (v1s == v2s) ? v1s : (v2s ? v1a < v2a : v1a > v2a); - auto res = S128(v1u + v2u); - bool overflowed = ((res < 0) != res_sign); - dst.write_sint(state, ti.bits, res); - dst.slice(ti.bits / 8).write_uint(state, 8, U128(overflowed ? 1 : 0)); - } break; - case TypeInfo::Float: - case TypeInfo::Other: - MIR_TODO(state, "add_with_overflow on unexpected type - " << ty); - } + bool overflowed = do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::ADD, e.args.at(1)); + dst.slice(ti.bits / 8).write_uint(state, 8, U128(overflowed ? 1 : 0)); + } + else if( te->name == "sub_with_overflow" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + auto ti = TypeInfo::for_type(ty); + bool overflowed = do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::SUB, e.args.at(1)); + dst.slice(ti.bits / 8).write_uint(state, 8, U128(overflowed ? 1 : 0)); + } + // Unchecked and wrapping are the same + else if( te->name == "wrapping_add" || te->name == "unchecked_add" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::ADD, e.args.at(1)); + } + else if( te->name == "wrapping_sub" || te->name == "unchecked_sub" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::SUB, e.args.at(1)); + } + // - Except for div/rem, which add checking just in case + else if( te->name == "unchecked_rem" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + bool was_overflow = do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::MOD, e.args.at(1)); + MIR_ASSERT(state, !was_overflow, "`" << te->name << "` overflowed"); + } + else if( te->name == "unchecked_div" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + bool was_overflow = do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::DIV, e.args.at(1)); + MIR_ASSERT(state, !was_overflow, "`" << te->name << "` overflowed"); + } + // `exact_div` is UB if the division results in a non-zero remainder (or if the division overflows) + else if( te->name == "exact_div" ) { + auto ty = ms.monomorph_type(state.sp, te->params.m_types.at(0)); + MIR_ASSERT(state, ty.data().is_Primitive(), "`" << te->name << "` with non-primitive " << ty); + bool was_overflow = do_arith_checked(local_state, ty, dst, e.args.at(0), ::MIR::eBinOp::DIV, e.args.at(1)); + MIR_ASSERT(state, !was_overflow, "`" << te->name << "` overflowed"); } // --- else if( te->name == "transmute" ) { @@ -2248,6 +2355,10 @@ namespace HIR { else if( te->name == "unlikely" ) { local_state.write_param(dst, e.args.at(0)); } + else if( te->name == "assume" ) { + auto val = local_state.read_param_uint(8, e.args.at(0)); + MIR_ASSERT(state, val != 0, "`assume` failed"); + } else { MIR_TODO(state, "Call intrinsic \"" << te->name << "\" - " << terminator); }