Skip to content

Commit

Permalink
HIR/Trans - Add support for trait upcasting (with a test!)
Browse files Browse the repository at this point in the history
  • Loading branch information
thepowersgang committed May 3, 2024
1 parent 754ac7e commit 3f2bca3
Show file tree
Hide file tree
Showing 15 changed files with 284 additions and 50 deletions.
27 changes: 27 additions & 0 deletions samples/test/trait_upcast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//

trait A {
fn bar(&self) -> i32;
}
trait B: A {
fn baz(&self) -> i32;
}

struct Foo;
impl A for Foo {
fn bar(&self) -> i32 {
12345
}
}
impl B for Foo {
fn baz(&self) -> i32 {
54321
}
}

fn main() {
let v = Foo;
let b: &dyn B = &v;
let a: &dyn A = b;
assert_eq!( a.bar(), 12345 );
}
3 changes: 2 additions & 1 deletion src/hir/deserialise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1366,11 +1366,12 @@ namespace {
rv.m_values = deserialise_istrumap< ::HIR::TraitValueItem>();
rv.m_value_indexes = deserialise_istrummap< ::std::pair<unsigned int, ::HIR::GenericPath> >();
rv.m_type_indexes = deserialise_istrumap< unsigned int>();
rv.m_vtable_parent_traits_start = m_in.read_count();
rv.m_all_parent_traits = deserialise_vec< ::HIR::TraitPath>();
rv.m_vtable_path = deserialise_simplepath();
return rv;
}

::HIR::ConstGeneric HirDeserialiser::deserialise_constgeneric()
{
switch( auto tag = m_in.read_tag() )
Expand Down
2 changes: 1 addition & 1 deletion src/hir/hir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ const ::HIR::TypeItem& ::HIR::Crate::get_typeitem_by_path(const Span& sp, const

auto it = mod.m_mod_items.find( ignore_last_node ? path.m_components[path.m_components.size()-2] : path.m_components.back() );
if( it == mod.m_mod_items.end() ) {
BUG(sp, "Could not find type name in " << path);
BUG(sp, "Could not find type " << path);
}

return it->second->ent;
Expand Down
4 changes: 4 additions & 0 deletions src/hir/hir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class Trait
::std::unordered_multimap< RcString, ::std::pair<unsigned int,::HIR::GenericPath> > m_value_indexes;
// Indexes in the vtable parameter list for each associated type
::std::unordered_map< RcString, unsigned int > m_type_indexes;
/// Index of the first vtable entry for parent traits
unsigned m_vtable_parent_traits_start;

// Flattend set of parent traits (monomorphised and associated types fixed)
::std::vector< ::HIR::TraitPath > m_all_parent_traits;
Expand All @@ -461,10 +463,12 @@ class Trait
m_lifetime( mv$(lifetime) ),
m_parent_traits( mv$(parents) ),
m_is_marker( false )
, m_vtable_parent_traits_start(0)
{}

::HIR::TypeRef get_vtable_type(const Span& sp, const ::HIR::Crate& crate, const ::HIR::TypeData::Data_TraitObject& te) const;
unsigned get_vtable_value_index(const HIR::GenericPath& trait_path, const RcString& name) const;
unsigned get_vtable_parent_index(const Span& sp, const HIR::PathParams& this_params, const HIR::GenericPath& trait_path) const;
};

class ProcMacro
Expand Down
13 changes: 13 additions & 0 deletions src/hir/hir_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,19 @@ unsigned HIR::Trait::get_vtable_value_index(const HIR::GenericPath& trait_path,
}
return 0;
}
unsigned HIR::Trait::get_vtable_parent_index(const Span& sp, const HIR::PathParams& this_params, const HIR::GenericPath& trait_path) const
{
for(const auto& pt : this->m_all_parent_traits) {
if( pt.m_path.m_path == trait_path.m_path )
{
auto p = MonomorphStatePtr(nullptr, &this_params, nullptr).monomorph_genericpath(sp, pt.m_path);
if( p == trait_path ) {
return m_vtable_parent_traits_start + (&pt - this->m_all_parent_traits.data());
}
}
}
return 0;
}

/// Helper for getting the struct associated with a pattern path
const ::HIR::Struct& HIR::pattern_get_struct(const Span& sp, const ::HIR::Path& path, const ::HIR::Pattern::PathBinding& binding, bool is_tuple)
Expand Down
1 change: 1 addition & 0 deletions src/hir/serialise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,6 +1335,7 @@
serialise_strmap( item.m_values );
serialise_strmap( item.m_value_indexes );
serialise_strmap( item.m_type_indexes );
m_out.write_count(item.m_vtable_parent_traits_start);
serialise_vec( item.m_all_parent_traits );
serialise( item.m_vtable_path );
}
Expand Down
125 changes: 117 additions & 8 deletions src/hir_expand/vtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ namespace {
::HIR::Trait* trait_ptr;
::HIR::t_struct_fields fields;

bool add_ents_from_trait(const ::HIR::Trait& tr, const ::HIR::GenericPath& trait_path)
bool add_ents_from_trait(const ::HIR::Trait& tr, const ::HIR::GenericPath& trait_path, std::vector<bool>* supertrait_flags)
{
TRACE_FUNCTION_F(trait_path);
struct M: public Monomorphiser {
Expand Down Expand Up @@ -230,11 +230,14 @@ namespace {
}
}
}
for(const auto& st : tr.m_all_parent_traits) {
::HIR::TypeRef self("Self", 0xFFFF);
auto st_gp = MonomorphStatePtr(&self, &trait_path.m_params, nullptr).monomorph_genericpath(sp, st.m_path, false);
// NOTE: Doesn't trigger non-object-safe
add_ents_from_trait(*st.m_trait_ptr, st_gp);
if( supertrait_flags ) {
supertrait_flags->reserve(tr.m_all_parent_traits.size());
for(const auto& st : tr.m_all_parent_traits) {
::HIR::TypeRef self("Self", 0xFFFF);
auto st_gp = MonomorphStatePtr(&self, &trait_path.m_params, nullptr).monomorph_genericpath(sp, st.m_path, false);
// NOTE: Doesn't trigger non-object-safe
supertrait_flags->push_back( add_ents_from_trait(*st.m_trait_ptr, st_gp, nullptr) );
}
}
return true;
}
Expand All @@ -253,12 +256,30 @@ namespace {
// - Alignment of data
vtc.fields.push_back(::std::make_pair( "#align", ::HIR::VisEnt<::HIR::TypeRef> { ::HIR::Publicity::new_none(), ::HIR::CoreType::Usize } ));
// - Add methods
if( ! vtc.add_ents_from_trait(tr, trait_path) || has_conflicting_aty_name )
::std::vector<bool> supertrait_flags;
if( ! vtc.add_ents_from_trait(tr, trait_path, &supertrait_flags) || has_conflicting_aty_name )
{
tr.m_value_indexes.clear();
tr.m_type_indexes.clear();
return ;
}
tr.m_vtable_parent_traits_start = vtc.fields.size();
// Add parent vtables too.
for(size_t i = 0; i < tr.m_all_parent_traits.size(); i ++ )
{
const auto& pt = tr.m_all_parent_traits[i];
auto parent_vtable_spath = pt.m_path.m_path;
parent_vtable_spath.m_components.back() = RcString::new_interned(FMT( parent_vtable_spath.m_components.back().c_str() << "#vtable" ));
auto parent_vtable_path = ::HIR::GenericPath(mv$(parent_vtable_spath), pt.m_path.m_params.clone());
auto ty = true || supertrait_flags[i]
? ::HIR::TypeRef::new_borrow( ::HIR::BorrowType::Shared, ::HIR::TypeRef::new_path(mv$(parent_vtable_path), {}) )
: ::HIR::TypeRef::new_unit()
;
vtc.fields.push_back(::std::make_pair(
RcString::new_interned(FMT("#parent_" << i)),
::HIR::VisEnt<::HIR::TypeRef> { ::HIR::Publicity::new_none(), mv$(ty) }
));
}
auto fields = mv$(vtc.fields);

::HIR::PathParams params;
Expand Down Expand Up @@ -335,11 +356,99 @@ namespace {
#endif
}
};
}

class FixupVisitor:
public ::HIR::Visitor
{
const ::HIR::Crate& m_crate;
public:
FixupVisitor(const ::HIR::Crate& crate):
m_crate(crate)
{
}

void visit_struct(HIR::ItemPath ip, HIR::Struct& str)
{
static Span sp;
auto p = std::strchr(ip.name, '#');
if( p && std::strcmp(p, "#vtable") == 0 )
{
auto trait_path = ip.parent->get_simple_path();
trait_path.m_components.push_back( RcString::new_interned(ip.name, p - ip.name) );
const auto& trait = m_crate.get_trait_by_path(sp, trait_path);

auto& fields = str.m_data.as_Named();
for(size_t i = 0; i < trait.m_all_parent_traits.size(); i ++)
{
const auto& pt = trait.m_all_parent_traits[i];
const auto& parent_trait = *pt.m_trait_ptr;
auto& fld_ty = fields[trait.m_vtable_parent_traits_start + i].second.ent;
DEBUG(pt << " " << fld_ty);

if( parent_trait.m_vtable_path == HIR::SimplePath() ) {
// Not object safe, so clear this entry
fld_ty = ::HIR::TypeRef::new_unit();
}
else {
auto& te = fld_ty.data_mut().as_Borrow().inner.data_mut().as_Path();
auto& vtable_gpath = te.path.m_data.as_Generic();
te.binding = &m_crate.get_struct_by_path(sp, vtable_gpath.m_path);

for(const auto& aty_idx : parent_trait.m_type_indexes)
{
if( vtable_gpath.m_params.m_types.size() <= aty_idx.second ) {
vtable_gpath.m_params.m_types.resize( aty_idx.second+1 );
}
auto& slot = vtable_gpath.m_params.m_types[aty_idx.second];
// If this associated type is in the trait path `pt`
auto it = pt.m_type_bounds.find( aty_idx.first );
if( it != pt.m_type_bounds.end() ) {
slot = it->second.type.clone();
}
// If this type is not in the trait path, then check if it has a defined generic
else if( trait.m_type_indexes.count(aty_idx.first) != 0 ) {
slot = HIR::TypeRef(RcString(), trait.m_type_indexes.at(aty_idx.first));
}
else {
// Otherwise, it has to have been defined in another parent trait
const HIR::GenericPath* gp = nullptr;
for( const auto& pptrait_path : parent_trait.m_all_parent_traits ) {
if( pptrait_path.m_trait_ptr->m_types.count(aty_idx.first) != 0 ) {
// Found the trait that defined this ATY
DEBUG("Found " << aty_idx.first << " in " << pptrait_path);
gp = &pptrait_path.m_path;
}
}
ASSERT_BUG(sp, gp, "Failed to a find trait that defined " << aty_idx.first << " in " << pt.m_path.m_path);

// Monomorph into the top trait
auto gp_mono = MonomorphStatePtr(nullptr, &pt.m_path.m_params, nullptr).monomorph_genericpath(sp, *gp);
// Search the parent list
const HIR::TraitPath* p = nullptr;
for(const auto& pt : trait.m_all_parent_traits) {
if( pt.m_path == gp_mono ) {
p = &pt;
}
}
ASSERT_BUG(sp, p, "Failed to find " << gp_mono << " in parent trait list for " << trait_path);
auto it = p->m_type_bounds.find( aty_idx.first );
ASSERT_BUG(sp, it != p->m_type_bounds.end(), "Failed to find " << aty_idx.first << " in " << *p);
slot = it->second.type.clone();
}
}
}
}
}
}
};
} // namespace

void HIR_Expand_VTables(::HIR::Crate& crate)
{
OuterVisitor ov(crate);
ov.visit_crate( crate );

FixupVisitor fv(crate);
fv.visit_crate(crate);
}

3 changes: 3 additions & 0 deletions src/hir_typeck/expr_cs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4286,6 +4286,9 @@ namespace {
if( dep->m_trait.m_path.m_path != sep->m_trait.m_path.m_path )
{
// Trait mismatch!
#if 1 // 1.74: `trait_upcasting` feature
return CoerceResult::Unsize;
#endif
return CoerceResult::Equality;
}
const auto& tys_d = dep->m_trait.m_path.m_params.m_types;
Expand Down
1 change: 1 addition & 0 deletions src/hir_typeck/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ bool HMTypeInferrence::type_contains_ivars(const ::HIR::TypeRef& ty, bool only_u
return type_contains_ivars(ee, only_unbound);
}
TU_ARMA(Alias, ee) {
return false;
}
}
),
Expand Down
6 changes: 3 additions & 3 deletions src/hir_typeck/outer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,12 @@ namespace {
{
auto _ = m_resolve.set_item_generics(e.second.data.m_params);

const auto& vi = trait.m_values.at(e.first);
if(!vi.is_Function()) {
const auto v_it = trait.m_values.find(e.first);
if( v_it == trait.m_values.end() || !v_it->second.is_Function() ) {
ERROR(sp, E0000, "Trait " << trait_path << " doesn't have a method named " << e.first);
}
auto& impl_fcn = e.second.data;
const auto& trait_fcn = vi.as_Function();
const auto& trait_fcn = v_it->second.as_Function();

auto fcn_params = impl_fcn.m_params.make_nop_params(1);
MonomorphStatePtr ms { &impl.m_type, &impl.m_trait_args, &fcn_params };
Expand Down
27 changes: 24 additions & 3 deletions src/hir_typeck/static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2391,11 +2391,32 @@ bool StaticTraitResolve::can_unsize(const Span& sp, const ::HIR::TypeRef& dst_ty
if( const auto* se = src_ty.data().opt_TraitObject() )
{
// 1. Data trait must be the same
if( de->m_trait != se->m_trait )
if( de->m_trait.m_path.m_path != se->m_trait.m_path.m_path )
{
return false;
// Ensure that `de->m_trait` is a parent of `se->m_trait`
const auto& trait = *se->m_trait.m_trait_ptr;
bool found = false;
for(const auto& pt : trait.m_all_parent_traits) {
if( pt.m_path.m_path == de->m_trait.m_path.m_path )
{
auto p = MonomorphStatePtr(nullptr, &se->m_trait.m_path.m_params, nullptr).monomorph_genericpath(sp, pt.m_path);
if( p == de->m_trait.m_path ) {
found = true;
break;
}
}
}
if( !found ) {
DEBUG("Not a parent trait");
return false;
}
}
else {
if( de->m_trait.m_path != se->m_trait.m_path ) {
DEBUG("Mismatched data trait params");
return false;
}
}

// 2. Destination markers must be a strict subset
for(const auto& mt : de->m_markers)
{
Expand Down
18 changes: 16 additions & 2 deletions src/mir/cleanup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -754,10 +754,24 @@ bool MIR_Cleanup_Unsize_GetMetadata(const ::MIR::TypeResolve& state, MirMutator&
out_meta_ty = ::HIR::TypeRef::new_pointer(::HIR::BorrowType::Shared, mv$(vtable_ty));

// If the data trait hasn't changed, return the vtable pointer
if( src_ty.data().is_TraitObject() )
if( const auto* se = src_ty.data().opt_TraitObject() )
{
out_src_is_dst = true;
out_meta_val = mutator.in_temporary( out_meta_ty.clone(), ::MIR::RValue::make_DstMeta({ ptr_value.clone() }) );
if( se->m_trait.m_trait_ptr != de.m_trait.m_trait_ptr )
{
const auto& trait = *se->m_trait.m_trait_ptr;
auto vtable_ty = trait.get_vtable_type(state.sp, state.m_crate, *se);
auto in_meta_ty = ::HIR::TypeRef::new_pointer(::HIR::BorrowType::Shared, mv$(vtable_ty));

auto parent_trait_field = trait.get_vtable_parent_index(state.sp, se->m_trait.m_path.m_params, de.m_trait.m_path);
MIR_ASSERT(state, parent_trait_field != 0, "Unable to find parent trait for trait object upcast - " << se->m_trait.m_path << " in " << de.m_trait.m_path);
auto in_meta_val = mutator.in_temporary( mv$(in_meta_ty), ::MIR::RValue::make_DstMeta({ ptr_value.clone() }) );
out_meta_val = MIR::LValue::new_Field( MIR::LValue::new_Deref( mv$(in_meta_val) ), parent_trait_field );
}
else
{
out_meta_val = mutator.in_temporary( out_meta_ty.clone(), ::MIR::RValue::make_DstMeta({ ptr_value.clone() }) );
}
}
else
{
Expand Down
1 change: 1 addition & 0 deletions src/mir/from_hir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,7 @@ namespace {
}
}
TU_ARMA(TraitObject, e) {
// NOTE: This pattern (an empty ItemAddr) is detected by cleanup, which populates the vtable properly
m_builder.set_result( node.span(), ::MIR::RValue::make_MakeDst({ mv$(ptr_lval), ::MIR::Constant::make_ItemAddr({}) }) );
}
}
Expand Down
Loading

0 comments on commit 3f2bca3

Please sign in to comment.