diff --git a/compiler/noirc_frontend/src/ast/traits.rs b/compiler/noirc_frontend/src/ast/traits.rs index 723df775b1e..475e3ff1be9 100644 --- a/compiler/noirc_frontend/src/ast/traits.rs +++ b/compiler/noirc_frontend/src/ast/traits.rs @@ -24,6 +24,7 @@ pub struct NoirTrait { pub items: Vec>, pub attributes: Vec, pub visibility: ItemVisibility, + pub is_alias: bool, } /// Any declaration inside the body of a trait that a user is required to @@ -77,6 +78,9 @@ pub struct NoirTraitImpl { pub where_clause: Vec, pub items: Vec>, + + /// true if generated at compile-time, e.g. from a trait alias + pub is_synthetic: bool, } /// Represents a simple trait constraint such as `where Foo: TraitY` @@ -130,12 +134,19 @@ impl Display for TypeImpl { } } +// TODO: display where clauses (follow-up issue) impl Display for NoirTrait { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let generics = vecmap(&self.generics, |generic| generic.to_string()); let generics = if generics.is_empty() { "".into() } else { generics.join(", ") }; write!(f, "trait {}{}", self.name, generics)?; + + if self.is_alias { + let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); + return write!(f, " = {};", bounds); + } + if !self.bounds.is_empty() { let bounds = vecmap(&self.bounds, |bound| bound.to_string()).join(" + "); write!(f, ": {}", bounds)?; @@ -222,6 +233,11 @@ impl Display for TraitBound { impl Display for NoirTraitImpl { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Synthetic NoirTraitImpl's don't get printed + if self.is_synthetic { + return Ok(()); + } + write!(f, "impl")?; if !self.impl_generics.is_empty() { write!( diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index a63601a4280..a6b6120986e 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -208,7 +208,7 @@ pub enum TypeCheckError { #[derive(Debug, Clone, PartialEq, Eq)] pub struct NoMatchingImplFoundError { - constraints: Vec<(Type, String)>, + pub(crate) constraints: Vec<(Type, String)>, pub span: Span, } diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index 63471efac43..bcb4ce1c616 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -95,6 +95,8 @@ pub enum ParserErrorReason { AssociatedTypesNotAllowedInPaths, #[error("Associated types are not allowed on a method call")] AssociatedTypesNotAllowedInMethodCalls, + #[error("Empty trait alias")] + EmptyTraitAlias, #[error( "Wrong number of arguments for attribute `{}`. Expected {}, found {}", name, diff --git a/compiler/noirc_frontend/src/parser/parser/impls.rs b/compiler/noirc_frontend/src/parser/parser/impls.rs index 9215aec2742..8e6b3bae0e9 100644 --- a/compiler/noirc_frontend/src/parser/parser/impls.rs +++ b/compiler/noirc_frontend/src/parser/parser/impls.rs @@ -113,6 +113,7 @@ impl<'a> Parser<'a> { let object_type = self.parse_type_or_error(); let where_clause = self.parse_where_clause(); let items = self.parse_trait_impl_body(); + let is_synthetic = false; NoirTraitImpl { impl_generics, @@ -121,6 +122,7 @@ impl<'a> Parser<'a> { object_type, where_clause, items, + is_synthetic, } } diff --git a/compiler/noirc_frontend/src/parser/parser/item.rs b/compiler/noirc_frontend/src/parser/parser/item.rs index 4fbcd7abac5..34f97753855 100644 --- a/compiler/noirc_frontend/src/parser/parser/item.rs +++ b/compiler/noirc_frontend/src/parser/parser/item.rs @@ -1,3 +1,5 @@ +use iter_extended::vecmap; + use crate::{ parser::{labels::ParsingRuleLabel, Item, ItemKind}, token::{Keyword, Token}, @@ -13,31 +15,32 @@ impl<'a> Parser<'a> { } pub(crate) fn parse_module_items(&mut self, nested: bool) -> Vec { - self.parse_many("items", without_separator(), |parser| { + self.parse_many_to_many("items", without_separator(), |parser| { parser.parse_module_item_in_list(nested) }) } - fn parse_module_item_in_list(&mut self, nested: bool) -> Option { + fn parse_module_item_in_list(&mut self, nested: bool) -> Vec { loop { // We only break out of the loop on `}` if we are inside a `mod { ..` if nested && self.at(Token::RightBrace) { - return None; + return vec![]; } // We always break on EOF (we don't error because if we are inside `mod { ..` // the outer parsing logic will error instead) if self.at_eof() { - return None; + return vec![]; } - let Some(item) = self.parse_item() else { + let parsed_items = self.parse_item(); + if parsed_items.is_empty() { // If we couldn't parse an item we check which token we got match self.token.token() { Token::RightBrace if nested => { - return None; + return vec![]; } - Token::EOF => return None, + Token::EOF => return vec![], _ => (), } @@ -47,7 +50,7 @@ impl<'a> Parser<'a> { continue; }; - return Some(item); + return parsed_items; } } @@ -85,13 +88,13 @@ impl<'a> Parser<'a> { } /// Item = OuterDocComments ItemKind - fn parse_item(&mut self) -> Option { + fn parse_item(&mut self) -> Vec { let start_span = self.current_token_span; let doc_comments = self.parse_outer_doc_comments(); - let kind = self.parse_item_kind()?; + let kinds = self.parse_item_kind(); let span = self.span_since(start_span); - Some(Item { kind, span, doc_comments }) + vecmap(kinds, |kind| Item { kind, span, doc_comments: doc_comments.clone() }) } /// ItemKind @@ -106,9 +109,9 @@ impl<'a> Parser<'a> { /// | TypeAlias /// | Function /// ) - fn parse_item_kind(&mut self) -> Option { + fn parse_item_kind(&mut self) -> Vec { if let Some(kind) = self.parse_inner_attribute() { - return Some(ItemKind::InnerAttribute(kind)); + return vec![ItemKind::InnerAttribute(kind)]; } let start_span = self.current_token_span; @@ -122,78 +125,81 @@ impl<'a> Parser<'a> { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); let use_tree = self.parse_use_tree(); - return Some(ItemKind::Import(use_tree, modifiers.visibility)); + return vec![ItemKind::Import(use_tree, modifiers.visibility)]; } if let Some(is_contract) = self.eat_mod_or_contract() { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)); + return vec![self.parse_mod_or_contract(attributes, is_contract, modifiers.visibility)]; } if self.eat_keyword(Keyword::Struct) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Struct(self.parse_struct( + return vec![ItemKind::Struct(self.parse_struct( attributes, modifiers.visibility, start_span, - ))); + ))]; } if self.eat_keyword(Keyword::Impl) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(match self.parse_impl() { + return vec![match self.parse_impl() { Impl::Impl(type_impl) => ItemKind::Impl(type_impl), Impl::TraitImpl(noir_trait_impl) => ItemKind::TraitImpl(noir_trait_impl), - }); + }]; } if self.eat_keyword(Keyword::Trait) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::Trait(self.parse_trait( - attributes, - modifiers.visibility, - start_span, - ))); + let (noir_trait, noir_impl) = + self.parse_trait(attributes, modifiers.visibility, start_span); + let mut output = vec![ItemKind::Trait(noir_trait)]; + if let Some(noir_impl) = noir_impl { + output.push(ItemKind::TraitImpl(noir_impl)); + } + + return output; } if self.eat_keyword(Keyword::Global) { self.unconstrained_not_applicable(modifiers); - return Some(ItemKind::Global( + return vec![ItemKind::Global( self.parse_global( attributes, modifiers.comptime.is_some(), modifiers.mutable.is_some(), ), modifiers.visibility, - )); + )]; } if self.eat_keyword(Keyword::Type) { self.comptime_mutable_and_unconstrained_not_applicable(modifiers); - return Some(ItemKind::TypeAlias( + return vec![ItemKind::TypeAlias( self.parse_type_alias(modifiers.visibility, start_span), - )); + )]; } if self.eat_keyword(Keyword::Fn) { self.mutable_not_applicable(modifiers); - return Some(ItemKind::Function(self.parse_function( + return vec![ItemKind::Function(self.parse_function( attributes, modifiers.visibility, modifiers.comptime.is_some(), modifiers.unconstrained.is_some(), false, // allow_self - ))); + ))]; } - None + vec![] } fn eat_mod_or_contract(&mut self) -> Option { diff --git a/compiler/noirc_frontend/src/parser/parser/parse_many.rs b/compiler/noirc_frontend/src/parser/parser/parse_many.rs index ea4dfe97122..452be2d9a8e 100644 --- a/compiler/noirc_frontend/src/parser/parser/parse_many.rs +++ b/compiler/noirc_frontend/src/parser/parser/parse_many.rs @@ -18,6 +18,19 @@ impl<'a> Parser<'a> { self.parse_many_return_trailing_separator_if_any(items, separated_by, f).0 } + /// parse_many, where the given function `f` may return multiple results + pub(super) fn parse_many_to_many( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + f: F, + ) -> Vec + where + F: FnMut(&mut Parser<'a>) -> Vec, + { + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f).0 + } + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. pub(super) fn parse_many_return_trailing_separator_if_any( &mut self, @@ -27,6 +40,26 @@ impl<'a> Parser<'a> { ) -> (Vec, bool) where F: FnMut(&mut Parser<'a>) -> Option, + { + let f = |x: &mut Parser<'a>| { + if let Some(result) = f(x) { + vec![result] + } else { + vec![] + } + }; + self.parse_many_to_many_return_trailing_separator_if_any(items, separated_by, f) + } + + /// Same as parse_many, but returns a bool indicating whether a trailing separator was found. + pub(super) fn parse_many_to_many_return_trailing_separator_if_any( + &mut self, + items: &'static str, + separated_by: SeparatedBy, + mut f: F, + ) -> (Vec, bool) + where + F: FnMut(&mut Parser<'a>) -> Vec, { let mut elements: Vec = Vec::new(); let mut trailing_separator = false; @@ -38,12 +71,13 @@ impl<'a> Parser<'a> { } let start_span = self.current_token_span; - let Some(element) = f(self) else { + let mut new_elements = f(self); + if new_elements.is_empty() { if let Some(end) = &separated_by.until { self.eat(end.clone()); } break; - }; + } if let Some(separator) = &separated_by.token { if !trailing_separator && !elements.is_empty() { @@ -51,7 +85,7 @@ impl<'a> Parser<'a> { } } - elements.push(element); + elements.append(&mut new_elements); trailing_separator = if let Some(separator) = &separated_by.token { self.eat(separator.clone()) diff --git a/compiler/noirc_frontend/src/parser/parser/traits.rs b/compiler/noirc_frontend/src/parser/parser/traits.rs index fead6a34c82..e03b629e9ea 100644 --- a/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -1,9 +1,14 @@ +use iter_extended::vecmap; + use noirc_errors::Span; -use crate::ast::{Documented, ItemVisibility, NoirTrait, Pattern, TraitItem, UnresolvedType}; +use crate::ast::{ + Documented, GenericTypeArg, GenericTypeArgs, ItemVisibility, NoirTrait, Path, Pattern, + TraitItem, UnresolvedGeneric, UnresolvedTraitConstraint, UnresolvedType, +}; use crate::{ ast::{Ident, UnresolvedTypeData}, - parser::{labels::ParsingRuleLabel, ParserErrorReason}, + parser::{labels::ParsingRuleLabel, NoirTraitImpl, ParserErrorReason}, token::{Attribute, Keyword, SecondaryAttribute, Token}, }; @@ -12,34 +17,117 @@ use super::Parser; impl<'a> Parser<'a> { /// Trait = 'trait' identifier Generics ( ':' TraitBounds )? WhereClause TraitBody + /// | 'trait' identifier Generics '=' TraitBounds WhereClause ';' pub(crate) fn parse_trait( &mut self, attributes: Vec<(Attribute, Span)>, visibility: ItemVisibility, start_span: Span, - ) -> NoirTrait { + ) -> (NoirTrait, Option) { let attributes = self.validate_secondary_attributes(attributes); let Some(name) = self.eat_ident() else { self.expected_identifier(); - return empty_trait(attributes, visibility, self.span_since(start_span)); + let noir_trait = empty_trait(attributes, visibility, self.span_since(start_span)); + let no_implicit_impl = None; + return (noir_trait, no_implicit_impl); }; let generics = self.parse_generics(); - let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; - let where_clause = self.parse_where_clause(); - let items = self.parse_trait_body(); - NoirTrait { + // Trait aliases: + // trait Foo<..> = A + B + E where ..; + let (bounds, where_clause, items, is_alias) = if self.eat_assign() { + let bounds = self.parse_trait_bounds(); + + if bounds.is_empty() { + self.push_error(ParserErrorReason::EmptyTraitAlias, self.previous_token_span); + } + + let where_clause = self.parse_where_clause(); + let items = Vec::new(); + if !self.eat_semicolon() { + self.expected_token(Token::Semicolon); + } + + let is_alias = true; + (bounds, where_clause, items, is_alias) + } else { + let bounds = if self.eat_colon() { self.parse_trait_bounds() } else { Vec::new() }; + let where_clause = self.parse_where_clause(); + let items = self.parse_trait_body(); + let is_alias = false; + (bounds, where_clause, items, is_alias) + }; + + let span = self.span_since(start_span); + + let noir_impl = is_alias.then(|| { + let object_type_ident = Ident::new("#T".to_string(), span); + let object_type_path = Path::from_ident(object_type_ident.clone()); + let object_type_generic = UnresolvedGeneric::Variable(object_type_ident); + + let is_synthesized = true; + let object_type = UnresolvedType { + typ: UnresolvedTypeData::Named(object_type_path, vec![].into(), is_synthesized), + span, + }; + + let mut impl_generics = generics.clone(); + impl_generics.push(object_type_generic); + + let trait_name = Path::from_ident(name.clone()); + let trait_generics: GenericTypeArgs = vecmap(generics.clone(), |generic| { + let is_synthesized = true; + let generic_type = UnresolvedType { + typ: UnresolvedTypeData::Named( + Path::from_ident(generic.ident().clone()), + vec![].into(), + is_synthesized, + ), + span, + }; + + GenericTypeArg::Ordered(generic_type) + }) + .into(); + + // bounds from trait + let mut where_clause = where_clause.clone(); + for bound in bounds.clone() { + where_clause.push(UnresolvedTraitConstraint { + typ: object_type.clone(), + trait_bound: bound, + }); + } + + let items = vec![]; + let is_synthetic = true; + + NoirTraitImpl { + impl_generics, + trait_name, + trait_generics, + object_type, + where_clause, + items, + is_synthetic, + } + }); + + let noir_trait = NoirTrait { name, generics, bounds, where_clause, - span: self.span_since(start_span), + span, items, attributes, visibility, - } + is_alias, + }; + + (noir_trait, noir_impl) } /// TraitBody = '{' ( OuterDocComments TraitItem )* '}' @@ -188,28 +276,51 @@ fn empty_trait( items: Vec::new(), attributes, visibility, + is_alias: false, } } #[cfg(test)] mod tests { use crate::{ - ast::{NoirTrait, TraitItem}, + ast::{NoirTrait, NoirTraitImpl, TraitItem}, parser::{ - parser::{parse_program, tests::expect_no_errors}, + parser::{parse_program, tests::expect_no_errors, ParserErrorReason}, ItemKind, }, }; - fn parse_trait_no_errors(src: &str) -> NoirTrait { + fn parse_trait_opt_impl_no_errors(src: &str) -> (NoirTrait, Option) { let (mut module, errors) = parse_program(src); expect_no_errors(&errors); - assert_eq!(module.items.len(), 1); - let item = module.items.remove(0); + let (item, impl_item) = if module.items.len() == 2 { + let item = module.items.remove(0); + let impl_item = module.items.remove(0); + (item, Some(impl_item)) + } else { + assert_eq!(module.items.len(), 1); + let item = module.items.remove(0); + (item, None) + }; let ItemKind::Trait(noir_trait) = item.kind else { panic!("Expected trait"); }; - noir_trait + let noir_trait_impl = impl_item.map(|impl_item| { + let ItemKind::TraitImpl(noir_trait_impl) = impl_item.kind else { + panic!("Expected impl"); + }; + noir_trait_impl + }); + (noir_trait, noir_trait_impl) + } + + fn parse_trait_with_impl_no_errors(src: &str) -> (NoirTrait, NoirTraitImpl) { + let (noir_trait, noir_trait_impl) = parse_trait_opt_impl_no_errors(src); + (noir_trait, noir_trait_impl.expect("expected a NoirTraitImpl")) + } + + fn parse_trait_no_errors(src: &str) -> NoirTrait { + parse_trait_opt_impl_no_errors(src).0 } #[test] @@ -220,6 +331,15 @@ mod tests { assert!(noir_trait.generics.is_empty()); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -230,6 +350,50 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert!(noir_trait.where_clause.is_empty()); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_generics() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_alias.where_clause.is_empty()); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_generics() { + let src = "trait Foo = ;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -240,6 +404,54 @@ mod tests { assert_eq!(noir_trait.generics.len(), 2); assert_eq!(noir_trait.where_clause.len(), 1); assert!(noir_trait.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias_with_where_clause() { + let src = "trait Foo = Bar + Baz where A: Z;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.name.to_string(), "Foo"); + assert_eq!(noir_trait_alias.generics.len(), 2); + assert_eq!(noir_trait_alias.bounds.len(), 2); + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + assert_eq!(noir_trait_alias.where_clause.len(), 1); + assert!(noir_trait_alias.items.is_empty()); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 3); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 2); + assert_eq!(noir_trait_impl.where_clause.len(), 3); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "A: Z"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[2].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz where A: Z {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.len(), noir_trait_alias.where_clause.len()); + assert_eq!( + noir_trait.where_clause[0].to_string(), + noir_trait_alias.where_clause[0].to_string() + ); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_empty_trait_alias_with_where_clause() { + let src = "trait Foo = where A: Z;"; + let (_module, errors) = parse_program(src); + assert_eq!(errors.len(), 2); + assert_eq!(errors[1].reason(), Some(ParserErrorReason::EmptyTraitAlias).as_ref()); } #[test] @@ -253,6 +465,7 @@ mod tests { panic!("Expected type"); }; assert_eq!(name.to_string(), "Elem"); + assert!(!noir_trait.is_alias); } #[test] @@ -268,6 +481,7 @@ mod tests { assert_eq!(name.to_string(), "x"); assert_eq!(typ.to_string(), "Field"); assert_eq!(default_value.unwrap().to_string(), "1"); + assert!(!noir_trait.is_alias); } #[test] @@ -281,6 +495,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_none()); + assert!(!noir_trait.is_alias); } #[test] @@ -294,6 +509,7 @@ mod tests { panic!("Expected function"); }; assert!(body.is_some()); + assert!(!noir_trait.is_alias); } #[test] @@ -306,5 +522,39 @@ mod tests { assert_eq!(noir_trait.bounds[1].to_string(), "Baz"); assert_eq!(noir_trait.to_string(), "trait Foo: Bar + Baz {\n}"); + assert!(!noir_trait.is_alias); + } + + #[test] + fn parse_trait_alias() { + let src = "trait Foo = Bar + Baz;"; + let (noir_trait_alias, noir_trait_impl) = parse_trait_with_impl_no_errors(src); + assert_eq!(noir_trait_alias.bounds.len(), 2); + + assert_eq!(noir_trait_alias.bounds[0].to_string(), "Bar"); + assert_eq!(noir_trait_alias.bounds[1].to_string(), "Baz"); + + assert_eq!(noir_trait_alias.to_string(), "trait Foo = Bar + Baz;"); + assert!(noir_trait_alias.is_alias); + + assert_eq!(noir_trait_impl.trait_name.to_string(), "Foo"); + assert_eq!(noir_trait_impl.impl_generics.len(), 1); + assert_eq!(noir_trait_impl.trait_generics.ordered_args.len(), 0); + assert_eq!(noir_trait_impl.where_clause.len(), 2); + assert_eq!(noir_trait_impl.where_clause[0].to_string(), "#T: Bar"); + assert_eq!(noir_trait_impl.where_clause[1].to_string(), "#T: Baz"); + assert!(noir_trait_impl.items.is_empty()); + assert!(noir_trait_impl.is_synthetic); + + // Equivalent to + let src = "trait Foo: Bar + Baz {}"; + let noir_trait = parse_trait_no_errors(src); + assert_eq!(noir_trait.name.to_string(), noir_trait_alias.name.to_string()); + assert_eq!(noir_trait.generics.len(), noir_trait_alias.generics.len()); + assert_eq!(noir_trait.bounds.len(), noir_trait_alias.bounds.len()); + assert_eq!(noir_trait.bounds[0].to_string(), noir_trait_alias.bounds[0].to_string()); + assert_eq!(noir_trait.where_clause.is_empty(), noir_trait_alias.where_clause.is_empty()); + assert_eq!(noir_trait.items.is_empty(), noir_trait_alias.items.is_empty()); + assert!(!noir_trait.is_alias); } } diff --git a/compiler/noirc_frontend/src/tests/traits.rs b/compiler/noirc_frontend/src/tests/traits.rs index 0adf5c90bea..811a32bab86 100644 --- a/compiler/noirc_frontend/src/tests/traits.rs +++ b/compiler/noirc_frontend/src/tests/traits.rs @@ -1,5 +1,6 @@ use crate::hir::def_collector::dc_crate::CompilationError; use crate::hir::resolution::errors::ResolverError; +use crate::hir::type_check::TypeCheckError; use crate::tests::{get_program_errors, get_program_with_maybe_parser_errors}; use super::assert_no_errors; @@ -320,6 +321,251 @@ fn regression_6314_double_inheritance() { assert_no_errors(src); } +#[test] +fn trait_alias_single_member() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Baz = Foo; + + impl Foo for Field { + fn foo(self) -> Self { self } + } + + fn baz(x: T) -> T where T: Baz { + x.foo() + } + + fn main() { + let x: Field = 0; + let _ = baz(x); + } + "#; + assert_no_errors(src); +} + +#[test] +fn trait_alias_two_members() { + let src = r#" + pub trait Foo { + fn foo(self) -> Self; + } + + pub trait Bar { + fn bar(self) -> Self; + } + + pub trait Baz = Foo + Bar; + + fn baz(x: T) -> T where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> Self { + self + 2 + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +#[test] +fn trait_alias_polymorphic_inheritance() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz = Foo + Bar; + + fn baz(x: T) -> U where T: Baz { + x.foo().bar() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + fn main() { + assert(0.foo().bar() == baz(0)); + }"#; + + assert_no_errors(src); +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently fails with the +// same errors as the desugared version +#[test] +fn trait_alias_polymorphic_where_clause() { + let src = r#" + trait Foo { + fn foo(self) -> Self; + } + + trait Bar { + fn bar(self) -> T; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Foo + Bar where T: Baz; + + fn qux(x: T) -> bool where T: Qux { + x.foo().bar().baz() + } + + impl Foo for Field { + fn foo(self) -> Self { + self + 1 + } + } + + impl Bar for Field { + fn bar(self) -> bool { + true + } + } + + impl Baz for bool { + fn baz(self) -> bool { + self + } + } + + fn main() { + assert(0.foo().bar().baz() == qux(0)); + } + "#; + + // TODO(https://github.com/noir-lang/noir/issues/6467) + // assert_no_errors(src); + let errors = get_program_errors(src); + assert_eq!(errors.len(), 2); + + match &errors[0].0 { + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, .. + }) => { + assert_eq!(method_name, "baz"); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } + + match &errors[1].0 { + CompilationError::TypeError(TypeCheckError::NoMatchingImplFound(err)) => { + assert_eq!(err.constraints.len(), 2); + assert_eq!(err.constraints[0].1, "Baz"); + assert_eq!(err.constraints[1].1, "Qux<_>"); + } + other => { + panic!("expected NoMatchingImplFound, but found {:?}", other); + } + } +} + +// TODO(https://github.com/noir-lang/noir/issues/6467): currently failing, so +// this just tests that the trait alias has an equivalent error to the expected +// desugared version +#[test] +fn trait_alias_with_where_clause_has_equivalent_errors() { + let src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux: Bar where T: Baz {} + + impl Qux for U where + U: Bar, + T: Baz, + {} + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let alias_src = r#" + trait Bar { + fn bar(self) -> Self; + } + + trait Baz { + fn baz(self) -> bool; + } + + trait Qux = Bar where T: Baz; + + pub fn qux(x: T, _: U) -> bool where U: Qux { + x.baz() + } + + fn main() {} + "#; + + let errors = get_program_errors(src); + let alias_errors = get_program_errors(alias_src); + + assert_eq!(errors.len(), 1); + assert_eq!(alias_errors.len(), 1); + + match (&errors[0].0, &alias_errors[0].0) { + ( + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name, + object_type, + .. + }), + CompilationError::TypeError(TypeCheckError::UnresolvedMethodCall { + method_name: alias_method_name, + object_type: alias_object_type, + .. + }), + ) => { + assert_eq!(method_name, alias_method_name); + assert_eq!(object_type, alias_object_type); + } + other => { + panic!("expected UnresolvedMethodCall, but found {:?}", other); + } + } +} + #[test] fn removes_assumed_parent_traits_after_function_ends() { let src = r#" diff --git a/docs/docs/noir/concepts/traits.md b/docs/docs/noir/concepts/traits.md index 9da00a77587..b6c0a886eb0 100644 --- a/docs/docs/noir/concepts/traits.md +++ b/docs/docs/noir/concepts/traits.md @@ -490,12 +490,95 @@ trait CompSciStudent: Programmer + Student { } ``` +### Trait Aliases + +Similar to the proposed Rust feature for [trait aliases](https://github.com/rust-lang/rust/blob/4d215e2426d52ca8d1af166d5f6b5e172afbff67/src/doc/unstable-book/src/language-features/trait-alias.md), +Noir supports aliasing one or more traits and using those aliases wherever +traits would normally be used. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> Self; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for T where T: Foo + Bar {} +trait Baz = Foo + Bar; + +// We can use `Baz` to refer to `Foo + Bar` +fn baz(x: T) -> T where T: Baz { + x.foo().bar() +} +``` + +#### Generic Trait Aliases + +Trait aliases can also be generic by placing the generic arguments after the +trait name. These generics are in scope of every item within the trait alias. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +// Equivalent to: +// trait Baz: Foo + Bar {} +// +// impl Baz for U where U: Foo + Bar {} +trait Baz = Foo + Bar; +``` + +#### Trait Alias Where Clauses + +Trait aliases support where clauses to add trait constraints to any of their +generic arguments, e.g. ensuring `T: Baz` for a trait alias `Qux`. + +```rust +trait Foo { + fn foo(self) -> Self; +} + +trait Bar { + fn bar(self) -> T; +} + +trait Baz { + fn baz(self) -> bool; +} + +// Equivalent to: +// trait Qux: Foo + Bar where T: Baz {} +// +// impl Qux for U where +// U: Foo + Bar, +// T: Baz, +// {} +trait Qux = Foo + Bar where T: Baz; +``` + +Note that while trait aliases support where clauses, +the equivalent traits can fail due to [#6467](https://github.com/noir-lang/noir/issues/6467) + ### Visibility -By default, like functions, traits are private to the module they exist in. You can use `pub` -to make the trait public or `pub(crate)` to make it public to just its crate: +By default, like functions, traits and trait aliases are private to the module +they exist in. You can use `pub` to make the trait public or `pub(crate)` to make +it public to just its crate: ```rust // This trait is now public pub trait Trait {} -``` \ No newline at end of file + +// This trait alias is now public +pub trait Baz = Foo + Bar; +``` diff --git a/tooling/nargo_fmt/build.rs b/tooling/nargo_fmt/build.rs index 47bb375f7d1..bd2db5f5b18 100644 --- a/tooling/nargo_fmt/build.rs +++ b/tooling/nargo_fmt/build.rs @@ -47,7 +47,10 @@ fn generate_formatter_tests(test_file: &mut File, test_data_dir: &Path) { .join("\n"); let output_source_path = outputs_dir.join(file_name).display().to_string(); - let output_source = std::fs::read_to_string(output_source_path.clone()).unwrap(); + let output_source = + std::fs::read_to_string(output_source_path.clone()).unwrap_or_else(|_| { + panic!("expected output source at {:?} was not found", &output_source_path) + }); write!( test_file, diff --git a/tooling/nargo_fmt/src/formatter/trait_impl.rs b/tooling/nargo_fmt/src/formatter/trait_impl.rs index 73d9a61b3d4..b31da8a4101 100644 --- a/tooling/nargo_fmt/src/formatter/trait_impl.rs +++ b/tooling/nargo_fmt/src/formatter/trait_impl.rs @@ -10,6 +10,11 @@ use super::Formatter; impl<'a> Formatter<'a> { pub(super) fn format_trait_impl(&mut self, trait_impl: NoirTraitImpl) { + // skip synthetic trait impl's, e.g. generated from trait aliases + if trait_impl.is_synthetic { + return; + } + let has_where_clause = !trait_impl.where_clause.is_empty(); self.write_indentation(); diff --git a/tooling/nargo_fmt/src/formatter/traits.rs b/tooling/nargo_fmt/src/formatter/traits.rs index 9a6b84c6537..1f192be471e 100644 --- a/tooling/nargo_fmt/src/formatter/traits.rs +++ b/tooling/nargo_fmt/src/formatter/traits.rs @@ -15,9 +15,18 @@ impl<'a> Formatter<'a> { self.write_identifier(noir_trait.name); self.format_generics(noir_trait.generics); + if noir_trait.is_alias { + self.write_space(); + self.write_token(Token::Assign); + } + if !noir_trait.bounds.is_empty() { self.skip_comments_and_whitespace(); - self.write_token(Token::Colon); + + if !noir_trait.is_alias { + self.write_token(Token::Colon); + } + self.write_space(); for (index, trait_bound) in noir_trait.bounds.into_iter().enumerate() { @@ -34,6 +43,12 @@ impl<'a> Formatter<'a> { self.format_where_clause(noir_trait.where_clause, true); } + // aliases have ';' in lieu of '{ items }' + if noir_trait.is_alias { + self.write_semicolon(); + return; + } + self.write_space(); self.write_left_brace(); if noir_trait.items.is_empty() { diff --git a/tooling/nargo_fmt/tests/expected/trait_alias.nr b/tooling/nargo_fmt/tests/expected/trait_alias.nr new file mode 100644 index 00000000000..926f3160279 --- /dev/null +++ b/tooling/nargo_fmt/tests/expected/trait_alias.nr @@ -0,0 +1,86 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { + self + } +} + +fn baz(x: T) -> T +where + T: Baz, +{ + x.foo() +} + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T +where + T: Baz_2, +{ + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U +where + T: Baz_3, +{ + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} + diff --git a/tooling/nargo_fmt/tests/input/trait_alias.nr b/tooling/nargo_fmt/tests/input/trait_alias.nr new file mode 100644 index 00000000000..53ae756795b --- /dev/null +++ b/tooling/nargo_fmt/tests/input/trait_alias.nr @@ -0,0 +1,78 @@ +trait Foo { + fn foo(self) -> Self; +} + +trait Baz = Foo; + +impl Foo for Field { + fn foo(self) -> Self { self } +} + +fn baz(x: T) -> T where T: Baz { + x.foo() +} + + +pub trait Foo_2 { + fn foo_2(self) -> Self; +} + +pub trait Bar_2 { + fn bar_2(self) -> Self; +} + +pub trait Baz_2 = Foo_2 + Bar_2; + +fn baz_2(x: T) -> T where T: Baz_2 { + x.foo_2().bar_2() +} + +impl Foo_2 for Field { + fn foo_2(self) -> Self { + self + 1 + } +} + +impl Bar_2 for Field { + fn bar_2(self) -> Self { + self + 2 + } +} + + +trait Foo_3 { + fn foo_3(self) -> Self; +} + +trait Bar_3 { + fn bar_3(self) -> T; +} + +trait Baz_3 = Foo_3 + Bar_3; + +fn baz_3(x: T) -> U where T: Baz_3 { + x.foo_3().bar_3() +} + +impl Foo_3 for Field { + fn foo_3(self) -> Self { + self + 1 + } +} + +impl Bar_3 for Field { + fn bar_3(self) -> bool { + true + } +} + + +fn main() { + let x: Field = 0; + let _ = baz(x); + + assert(0.foo_2().bar_2() == baz_2(0)); + + assert(0.foo_3().bar_3() == baz_3(0)); +} +