Skip to content

Commit

Permalink
HIR - Avoid potential bad types in closure type expansions
Browse files Browse the repository at this point in the history
  • Loading branch information
thepowersgang committed Sep 7, 2023
1 parent 0dfab83 commit 33dc047
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 26 deletions.
5 changes: 5 additions & 0 deletions src/hir/expr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,9 @@ struct ExprNode_Closure:

// - Path to the generated closure type
const ::HIR::Struct* m_obj_ptr = nullptr;
/// Path to the created object, using types from the original context (for type expansion)
::HIR::GenericPath m_obj_path_base;
/// Path to the created object, using types from the current node location (for MIR)
::HIR::GenericPath m_obj_path;
::std::vector< ::HIR::ExprNodeP> m_captures;

Expand Down Expand Up @@ -868,6 +870,9 @@ struct ExprNode_Generator:

// Generated type information
const ::HIR::Struct* m_obj_ptr = nullptr;
/// Path to the created object, using types from the original context (for type expansion)
::HIR::GenericPath m_obj_path_base;
/// Path to the created object, using types from the current node location (for MIR)
::HIR::GenericPath m_obj_path;
// Captured variables (used for emitting the constructor)
::std::vector< ::HIR::ExprNodeP> m_captures;
Expand Down
20 changes: 17 additions & 3 deletions src/hir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <span.hpp>
#include "expr.hpp" // Hack for cloning array types
#include <cstdint>
#include "../hir_typeck/monomorph.hpp"

namespace HIR {

Expand Down Expand Up @@ -965,6 +966,18 @@ const ::HIR::TraitMarkings* HIR::TypePathBinding::get_trait_markings() const
return markings_ptr;
}

HIR::TypeMapping HIR::TypeMapping::clone() const
{
HIR::TypeMapping rv;
rv.self = self.clone();
rv.impl = impl.clone();
rv.item = item.clone();
return rv;
}
MonomorphStatePtr HIR::TypeMapping::get_ms() const {
return MonomorphStatePtr(self == HIR::TypeRef() ? nullptr : &self, &impl, &item);
}

::HIR::TypeRef HIR::TypeRef::clone() const
{
return HIR::TypeRef(*this);
Expand Down Expand Up @@ -1041,16 +1054,17 @@ ::HIR::TypeRef HIR::TypeRef::clone_shallow() const
return ::HIR::TypeRef(TypeData::make_Function( mv$(ft) ));
}
TU_ARMA(Closure, e) {
assert(e.params);
TypeData::Data_Closure oe;
oe.node = e.node;
//oe.m_closure_rettype = e.m_closure_rettype.clone();
//for(const auto& a : e.m_closure_arg_types)
// oe.m_closure_arg_types.push_back( a.clone() );
oe.params = box$(e.params->clone());
return ::HIR::TypeRef(TypeData::make_Closure( mv$(oe) ));
}
TU_ARMA(Generator, e) {
assert(e.params);
TypeData::Data_Generator oe;
oe.node = e.node;
oe.params = box$(e.params->clone());
return ::HIR::TypeRef(TypeData::make_Generator( mv$(oe) ));
}
}
Expand Down
23 changes: 19 additions & 4 deletions src/hir/type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
constexpr const char* CLOSURE_PATH_PREFIX = "closure#";
constexpr const char* GENERATOR_PATH_PREFIX = "generator#";

struct MonomorphStatePtr;

namespace HIR {

struct TraitMarkings;
Expand Down Expand Up @@ -158,6 +160,15 @@ struct TypeData_FunctionPointer
TypeRef m_rettype;
::std::vector<TypeRef> m_arg_types;
};
struct TypeMapping
{
::HIR::TypeRef self;
::HIR::PathParams impl;
::HIR::PathParams item;

TypeMapping clone() const;
MonomorphStatePtr get_ms() const;
};

TAGGED_UNION(TypeData, Diverge,
(Infer, struct {
Expand Down Expand Up @@ -203,9 +214,13 @@ TAGGED_UNION(TypeData, Diverge,
(Function, TypeData_FunctionPointer), // TODO: Pointer wrap, this is quite large
(Closure, struct {
const ::HIR::ExprNode_Closure* node;
// Mapping from node to current
::std::unique_ptr<TypeMapping> params;
}),
(Generator, struct {
const ::HIR::ExprNode_Generator* node;
// Mapping from node to current
::std::unique_ptr<TypeMapping> params;
})
);

Expand Down Expand Up @@ -302,11 +317,11 @@ inline TypeRef TypeRef::new_array(TypeRef inner, ::HIR::ConstGeneric size_gen) {
inline TypeRef TypeRef::new_path(::HIR::Path path, TypePathBinding binding) {
return TypeRef(TypeData::make_Path({ mv$(path), mv$(binding) }));
}
inline TypeRef TypeRef::new_closure(::HIR::ExprNode_Closure* node_ptr) {
return TypeRef(TypeData::make_Closure({ node_ptr }));
inline TypeRef TypeRef::new_closure(::HIR::ExprNode_Closure* node_ptr, TypeMapping pp) {
return TypeRef(TypeData::make_Closure({ node_ptr, box$(pp) }));
}
inline TypeRef TypeRef::new_generator(::HIR::ExprNode_Generator* node_ptr) {
return TypeRef(TypeData::make_Generator({ node_ptr }));
inline TypeRef TypeRef::new_generator(::HIR::ExprNode_Generator* node_ptr, TypeMapping pp) {
return TypeRef(TypeData::make_Generator({ node_ptr, box$(pp) }));
}

inline const ::HIR::SimplePath* TypeRef::get_sort_path() const {
Expand Down
7 changes: 4 additions & 3 deletions src/hir/type_ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class TypeData;
class TypeInner;
struct TypeData_FunctionPointer;
class TypePathBinding;
struct TypeMapping;

class TypeRef
{
Expand Down Expand Up @@ -96,8 +97,8 @@ class TypeRef
static TypeRef new_array(TypeRef inner, uint64_t size);
static TypeRef new_array(TypeRef inner, ::HIR::ConstGeneric size_expr);
static TypeRef new_path(::HIR::Path path, TypePathBinding binding);
static TypeRef new_closure(::HIR::ExprNode_Closure* node_ptr);
static TypeRef new_generator(::HIR::ExprNode_Generator* node_ptr);
static TypeRef new_closure(::HIR::ExprNode_Closure* node_ptr, TypeMapping pp);
static TypeRef new_generator(::HIR::ExprNode_Generator* node_ptr, TypeMapping pp);

// Duplicate refcount
TypeRef clone() const;
Expand All @@ -124,4 +125,4 @@ class TypeRef
const ::HIR::SimplePath* get_sort_path() const;
};

}
}
12 changes: 8 additions & 4 deletions src/hir_expand/closures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,17 +332,20 @@ namespace {
static void fix_type(const ::HIR::Crate& crate, const Span& sp, const Monomorphiser& monomorphiser, ::HIR::TypeRef& ty) {
if( const auto* e = ty.data().opt_Closure() )
{
DEBUG("Closure: " << e->node->m_obj_path_base); // TODO: Why does this use the `_base`
// NOTE: This uses `m_obj_path_base` as that has the original types (not changed as the node is moved around)
DEBUG("Closure: " << e->node->m_obj_path_base);
ASSERT_BUG(sp, e->node->m_obj_path_base != HIR::GenericPath(), ty);
auto path = monomorphiser.monomorph_genericpath(sp, e->node->m_obj_path_base, false);
auto ms = e->params->get_ms();
auto path = monomorphiser.monomorph_genericpath(sp, ms.monomorph_genericpath(sp, e->node->m_obj_path_base, false), false);
const auto& str = *e->node->m_obj_ptr;
DEBUG(ty << " -> " << path);
ty = ::HIR::TypeRef::new_path( mv$(path), ::HIR::TypePathBinding::make_Struct(&str) );
}
if(const auto* e = ty.data().opt_Generator() )
{
DEBUG("Generator: " << e->node->m_obj_path);
auto path = monomorphiser.monomorph_genericpath(sp, e->node->m_obj_path, false);
DEBUG("Generator: " << e->node->m_obj_path_base);
auto ms = e->params->get_ms();
auto path = monomorphiser.monomorph_genericpath(sp, ms.monomorph_genericpath(sp, e->node->m_obj_path_base, false), false);
const auto& str = *e->node->m_obj_ptr;
DEBUG(ty << " -> " << path);
ty = ::HIR::TypeRef::new_path( mv$(path), ::HIR::TypePathBinding::make_Struct(&str) );
Expand Down Expand Up @@ -1540,6 +1543,7 @@ namespace {
// Mark the object pathname in the closure.
node.m_obj_ptr = &gen_struct_ref;
node.m_obj_path = ::HIR::GenericPath( gen_struct_path, mv$(constructor_path_params) );
node.m_obj_path_base = node.m_obj_path.clone();
node.m_captures = mv$(capture_nodes);

::HIR::TypeRef& self_arg_ty = new_locals[0];
Expand Down
23 changes: 23 additions & 0 deletions src/hir_typeck/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,17 @@ bool monomorphise_type_needed(const ::HIR::TypeRef& tpl, bool ignore_lifetimes/*
return v.visit_type(tpl);
}

namespace {
HIR::TypeMapping monomorph_type_mapping(const Monomorphiser& m, const Span& sp, const HIR::TypeMapping& tpl, bool allow_infer) {
HIR::TypeMapping rv;
if( tpl.self != HIR::TypeRef() ) {
rv.self = m.monomorph_type(sp, tpl.self, allow_infer);
}
rv.impl = m.monomorph_path_params(sp, tpl.impl, allow_infer);
rv.item = m.monomorph_path_params(sp, tpl.item, allow_infer);
return rv;
}
}

::HIR::TypeRef Monomorphiser::monomorph_type(const Span& sp, const ::HIR::TypeRef& tpl, bool allow_infer/*=true*/) const
{
Expand Down Expand Up @@ -376,11 +387,13 @@ ::HIR::TypeRef Monomorphiser::monomorph_type(const Span& sp, const ::HIR::TypeRe
TU_ARMA(Closure, e) {
::HIR::TypeData::Data_Closure oe;
oe.node = e.node;
oe.params = box$(monomorph_type_mapping(*this, sp, *e.params, allow_infer));
return ::HIR::TypeRef( mv$(oe) );
}
TU_ARMA(Generator, e) {
::HIR::TypeData::Data_Generator oe;
oe.node = e.node;
oe.params = box$(monomorph_type_mapping(*this, sp, *e.params, allow_infer));
return ::HIR::TypeRef( mv$(oe) );
}
}
Expand Down Expand Up @@ -625,6 +638,15 @@ ::HIR::Path clone_ty_with__path(const Span& sp, const ::HIR::Path& tpl, t_cb_clo
}
throw "";
}
::HIR::TypeMapping clone_type_with__tm(const Span& sp, const ::HIR::TypeMapping& tpl, t_cb_clone_ty callback) {
::HIR::TypeMapping rv;
if( rv.self != HIR::TypeRef() ) {
rv.self = clone_ty_with(sp, rv.self, callback);
}
rv.impl = clone_path_params_with(sp, rv.impl, callback);
rv.item = clone_path_params_with(sp, rv.item, callback);
return rv;
}
::HIR::TypeRef clone_ty_with(const Span& sp, const ::HIR::TypeRef& tpl, t_cb_clone_ty callback)
{
::HIR::TypeRef rv;
Expand Down Expand Up @@ -715,6 +737,7 @@ ::HIR::TypeRef clone_ty_with(const Span& sp, const ::HIR::TypeRef& tpl, t_cb_clo
TU_ARMA(Closure, e) {
::HIR::TypeData::Data_Closure oe;
oe.node = e.node;
oe.params = box$(clone_type_with__tm(sp, *oe.params, callback));
rv = ::HIR::TypeRef( mv$(oe) );
}
TU_ARMA(Generator, e) {
Expand Down
2 changes: 1 addition & 1 deletion src/hir_typeck/expr_cs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7165,7 +7165,7 @@ void Typecheck_Code_CS(const typeck::ModuleState& ms, t_args& args, const ::HIR:

auto root_ptr = expr.into_unique();
assert(!ms.m_mod_paths.empty());
Context context { ms.m_crate, ms.m_impl_generics, ms.m_item_generics, ms.m_mod_paths.back(), ms.m_current_trait };
Context context { ms.m_crate, ms.m_current_trait != nullptr && ms.m_is_trait_def, ms.m_impl_generics, ms.m_item_generics, ms.m_mod_paths.back(), ms.m_current_trait };

// - Build up ruleset from node tree
Typecheck_Code_CS__EnumerateRules(context, ms, args, result_type, expr, root_ptr);
Expand Down
3 changes: 2 additions & 1 deletion src/hir_typeck/expr_cs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,14 @@ struct Context

Context(
const ::HIR::Crate& crate,
bool has_self,
const ::HIR::GenericParams* impl_params,
const ::HIR::GenericParams* item_params,
const ::HIR::SimplePath& mod_path,
const ::HIR::GenericPath* current_trait
)
:m_crate(crate)
,m_resolve(m_ivars, crate, impl_params, item_params, mod_path, current_trait)
,m_resolve(m_ivars, crate, has_self, impl_params, item_params, mod_path, current_trait)
,next_rule_idx( 0 )
,m_lang_Box( crate.get_lang_item_path_opt("owned_box") )
{
Expand Down
16 changes: 10 additions & 6 deletions src/hir_typeck/expr_cs__enum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1969,11 +1969,11 @@ namespace typecheck
this->context.add_ivars( node.m_code->m_res_type );

// Closure result type
::std::vector< ::HIR::TypeRef> arg_types;
for(auto& arg : node.m_args) {
arg_types.push_back( arg.second.clone() );
}
this->context.equate_types( node.span(), node.m_res_type, ::HIR::TypeRef::new_closure(&node/*, mv$(arg_types), node.m_return.clone()*/) );
HIR::TypeMapping tm;
tm.self = this->context.m_resolve.has_self() ? HIR::TypeRef("Self", GENERIC_Self) : HIR::TypeRef();
tm.impl = this->context.m_resolve.impl_generics().make_nop_params(0);
tm.item = this->context.m_resolve.item_generics().make_nop_params(1);
this->context.equate_types( node.span(), node.m_res_type, ::HIR::TypeRef::new_closure(&node, ::std::move(tm)) );

this->context.equate_types_coerce( node.span(), node.m_return, node.m_code );

Expand All @@ -2000,7 +2000,11 @@ namespace typecheck
this->context.add_ivars( node.m_code->m_res_type );

// Generator result type
this->context.equate_types( node.span(), node.m_res_type, ::HIR::TypeRef::new_generator(&node) );
HIR::TypeMapping tm;
tm.self = this->context.m_resolve.has_self() ? HIR::TypeRef("Self", GENERIC_Self) : HIR::TypeRef();
tm.impl = this->context.m_resolve.impl_generics().make_nop_params(0);
tm.item = this->context.m_resolve.item_generics().make_nop_params(1);
this->context.equate_types( node.span(), node.m_res_type, ::HIR::TypeRef::new_generator(&node, ::std::move(tm)) );

this->context.equate_types_coerce( node.span(), node.m_return, node.m_code );

Expand Down
4 changes: 2 additions & 2 deletions src/hir_typeck/expr_visit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ namespace {
for(size_t i = 0; i < item.m_params.m_values.size(); i ++) {
trait_gpath.m_params.m_values.push_back(HIR::GenericRef(item.m_params.m_values[i].m_name, i));
}
auto _1 = this->m_ms.set_current_trait(trait_gpath);
auto _1 = this->m_ms.set_current_trait(trait_gpath, true);
auto _ = this->m_ms.set_impl_generics(item.m_params);
::HIR::Visitor::visit_trait(p, item);
}
Expand All @@ -159,7 +159,7 @@ namespace {
{
TRACE_FUNCTION_F("impl " << trait_path << impl.m_trait_args << " for " << impl.m_type);
auto trait_gpath = ::HIR::GenericPath(trait_path, impl.m_trait_args.clone());
auto _1 = this->m_ms.set_current_trait(trait_gpath);
auto _1 = this->m_ms.set_current_trait(trait_gpath, false);
auto _ = this->m_ms.set_impl_generics(impl.m_params);

const auto& mod = this->m_ms.m_crate.get_mod_by_path(Span(), impl.m_src_module);
Expand Down
5 changes: 4 additions & 1 deletion src/hir_typeck/expr_visit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace typeck {
{
const ::HIR::Crate& m_crate;

bool m_is_trait_def;
const ::HIR::GenericPath* m_current_trait;
const ::HIR::GenericParams* m_impl_generics;
const ::HIR::GenericParams* m_item_generics;
Expand All @@ -22,6 +23,7 @@ namespace typeck {

ModuleState(const ::HIR::Crate& crate):
m_crate(crate),
m_is_trait_def(false),
m_current_trait(nullptr),
m_impl_generics(nullptr),
m_item_generics(nullptr)
Expand All @@ -38,9 +40,10 @@ namespace typeck {
ptr = nullptr;
}
};
NullOnDrop<const ::HIR::GenericPath> set_current_trait(const ::HIR::GenericPath& p) {
NullOnDrop<const ::HIR::GenericPath> set_current_trait(const ::HIR::GenericPath& p, bool is_trait_def) {
assert( !m_current_trait );
m_current_trait = &p;
m_is_trait_def = is_trait_def;
return NullOnDrop<const ::HIR::GenericPath>(m_current_trait);
}
NullOnDrop<const ::HIR::GenericParams> set_impl_generics(const ::HIR::GenericParams& gps) {
Expand Down
6 changes: 5 additions & 1 deletion src/hir_typeck/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,17 +175,19 @@ class TraitResolution:
const HIR::SimplePath& m_lang_Deref;
const HMTypeInferrence& m_ivars;

bool m_has_self;
const ::HIR::SimplePath& m_vis_path;
const ::HIR::GenericPath* m_current_trait_path;
const ::HIR::Trait* m_current_trait_ptr;

mutable ::std::vector<std::unique_ptr<::HIR::TypeRef>> m_eat_active_stack;
public:
TraitResolution(const HMTypeInferrence& ivars, const ::HIR::Crate& crate, const ::HIR::GenericParams* impl_params, const ::HIR::GenericParams* item_params, const ::HIR::SimplePath& vis_path, const ::HIR::GenericPath* current_trait):
TraitResolution(const HMTypeInferrence& ivars, const ::HIR::Crate& crate, bool has_self, const ::HIR::GenericParams* impl_params, const ::HIR::GenericParams* item_params, const ::HIR::SimplePath& vis_path, const ::HIR::GenericPath* current_trait):
TraitResolveCommon(crate)
,m_lang_Deref(crate.get_lang_item_path_opt("deref"))
,m_ivars(ivars)
,m_vis_path(vis_path)
,m_has_self(has_self)
,m_current_trait_path(current_trait)
,m_current_trait_ptr(current_trait ? &crate.get_trait_by_path(Span(), current_trait->m_path) : nullptr)
{
Expand All @@ -194,6 +196,8 @@ class TraitResolution:
prep_indexes(Span());
}

bool has_self() const { return m_has_self; }

::HIR::Compare compare_pp(const Span& sp, const ::HIR::PathParams& left, const ::HIR::PathParams& right) const;

void compact_ivars(HMTypeInferrence& m_ivars);
Expand Down

0 comments on commit 33dc047

Please sign in to comment.