diff --git a/Cargo.toml b/Cargo.toml index ba9bc59..c119af9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "logcall" -version = "0.1.0" +version = "0.1.1" edition = "2021" authors = ["andylokandy "] description = "An attribute macro that logs the return value from function call." diff --git a/src/lib.rs b/src/lib.rs index dc5eec3..08dbf6f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,13 +12,8 @@ extern crate proc_macro; extern crate proc_macro_error; use proc_macro2::Span; -use proc_macro2::TokenStream; -use proc_macro2::TokenTree; -use quote::format_ident; use quote::quote_spanned; -use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::visit_mut::VisitMut; use syn::Ident; use syn::*; @@ -69,8 +64,13 @@ pub fn logcall( AsyncTraitKind::Async(async_expr) => { // fallback if we couldn't find the '__async_trait' binding, might be // useful for crates exhibiting the same behaviors as async-trait - let instrumented_block = - gen_block(&async_expr.block, true, &input.sig.ident.to_string(), args); + let instrumented_block = gen_block( + &async_expr.block, + true, + false, + &input.sig.ident.to_string(), + args, + ); let async_attrs = &async_expr.attrs; quote! { Box::pin(#(#async_attrs) * #instrumented_block ) @@ -81,23 +81,16 @@ pub fn logcall( gen_block( &input.block, input.sig.asyncness.is_some(), + input.sig.asyncness.is_some(), &input.sig.ident.to_string(), args, ) }; let ItemFn { - attrs, - vis, - mut sig, - .. + attrs, vis, sig, .. } = input; - if sig.asyncness.is_some() { - let has_self = has_self_in_sig(&mut sig); - transform_sig(&mut sig, has_self, true); - } - let Signature { output: return_type, inputs: params, @@ -105,6 +98,7 @@ pub fn logcall( constness, abi, ident, + asyncness, generics: Generics { params: gen_params, @@ -116,7 +110,7 @@ pub fn logcall( quote::quote!( #(#attrs) * - #vis #constness #unsafety #abi fn #ident<#gen_params>(#params) #return_type + #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type #where_clause { #func_body @@ -129,6 +123,7 @@ pub fn logcall( fn gen_block( block: &Block, async_context: bool, + async_keyword: bool, fn_name: &str, args: Args, ) -> proc_macro2::TokenStream { @@ -136,13 +131,21 @@ fn gen_block( // If the function is an `async fn`, this will wrap it in an async block. if async_context { let log = gen_log(&args.level, fn_name, "__ret_value"); - quote_spanned!(block.span()=> + let block = quote_spanned!(block.span()=> async move { let __ret_value = #block; #log; __ret_value } - ) + ); + + if async_keyword { + quote_spanned!(block.span()=> + #block.await + ) + } else { + block + } } else { let log = gen_log(&args.level, fn_name, "__ret_value"); quote_spanned!(block.span()=> @@ -165,270 +168,6 @@ fn gen_log(level: &str, fn_name: &str, return_value: &str) -> proc_macro2::Token ) } -fn transform_sig(sig: &mut Signature, has_self: bool, is_local: bool) { - sig.fn_token.span = sig.asyncness.take().unwrap().span; - - let ret = match &sig.output { - ReturnType::Default => quote!(()), - ReturnType::Type(_, ret) => quote!(#ret), - }; - - let default_span = sig - .ident - .span() - .join(sig.paren_token.span) - .unwrap_or_else(|| sig.ident.span()); - - let mut lifetimes = CollectLifetimes::new("'life", default_span); - for arg in sig.inputs.iter_mut() { - match arg { - FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg), - FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty), - } - } - - for param in sig.generics.params.iter() { - match param { - GenericParam::Type(param) => { - let param = ¶m.ident; - let span = param.span(); - where_clause_or_default(&mut sig.generics.where_clause) - .predicates - .push(parse_quote_spanned!(span=> #param: 'logcall)); - } - GenericParam::Lifetime(param) => { - let param = ¶m.lifetime; - let span = param.span(); - where_clause_or_default(&mut sig.generics.where_clause) - .predicates - .push(parse_quote_spanned!(span=> #param: 'logcall)); - } - GenericParam::Const(_) => {} - } - } - - if sig.generics.lt_token.is_none() { - sig.generics.lt_token = Some(Token![<](sig.ident.span())); - } - if sig.generics.gt_token.is_none() { - sig.generics.gt_token = Some(Token![>](sig.paren_token.span)); - } - - for (idx, elided) in lifetimes.elided.iter().enumerate() { - sig.generics.params.insert(idx, parse_quote!(#elided)); - where_clause_or_default(&mut sig.generics.where_clause) - .predicates - .push(parse_quote_spanned!(elided.span()=> #elided: 'logcall)); - } - - sig.generics - .params - .insert(0, parse_quote_spanned!(default_span=> 'logcall)); - - if has_self { - let bound_span = sig.ident.span(); - let bound = match sig.inputs.iter().next() { - Some(FnArg::Receiver(Receiver { - reference: Some(_), - mutability: None, - .. - })) => Ident::new("Sync", bound_span), - Some(FnArg::Typed(arg)) - if match (arg.pat.as_ref(), arg.ty.as_ref()) { - (Pat::Ident(pat), Type::Reference(ty)) => { - pat.ident == "self" && ty.mutability.is_none() - } - _ => false, - } => - { - Ident::new("Sync", bound_span) - } - _ => Ident::new("Send", bound_span), - }; - - let where_clause = where_clause_or_default(&mut sig.generics.where_clause); - where_clause.predicates.push(if is_local { - parse_quote_spanned!(bound_span=> Self: 'logcall) - } else { - parse_quote_spanned!(bound_span=> Self: ::core::marker::#bound + 'logcall) - }); - } - - for (i, arg) in sig.inputs.iter_mut().enumerate() { - match arg { - FnArg::Receiver(Receiver { - reference: Some(_), .. - }) => {} - FnArg::Receiver(arg) => arg.mutability = None, - FnArg::Typed(arg) => { - if let Pat::Ident(ident) = &mut *arg.pat { - ident.by_ref = None; - ident.mutability = None; - } else { - let positional = positional_arg(i, &arg.pat); - let m = mut_pat(&mut arg.pat); - arg.pat = parse_quote!(#m #positional); - } - } - } - } - - let ret_span = sig.ident.span(); - let bounds = if is_local { - quote_spanned!(ret_span=> 'logcall) - } else { - quote_spanned!(ret_span=> ::core::marker::Send + 'logcall) - }; - sig.output = parse_quote_spanned! {ret_span=> - -> impl ::core::future::Future + #bounds - }; -} - -struct CollectLifetimes { - pub elided: Vec, - pub explicit: Vec, - pub name: &'static str, - pub default_span: Span, -} - -impl CollectLifetimes { - pub fn new(name: &'static str, default_span: Span) -> Self { - CollectLifetimes { - elided: Vec::new(), - explicit: Vec::new(), - name, - default_span, - } - } - - fn visit_opt_lifetime(&mut self, lifetime: &mut Option) { - match lifetime { - None => *lifetime = Some(self.next_lifetime(None)), - Some(lifetime) => self.visit_lifetime(lifetime), - } - } - - fn visit_lifetime(&mut self, lifetime: &mut Lifetime) { - if lifetime.ident == "_" { - *lifetime = self.next_lifetime(lifetime.span()); - } else { - self.explicit.push(lifetime.clone()); - } - } - - fn next_lifetime>>(&mut self, span: S) -> Lifetime { - let name = format!("{}{}", self.name, self.elided.len()); - let span = span.into().unwrap_or(self.default_span); - let life = Lifetime::new(&name, span); - self.elided.push(life.clone()); - life - } -} - -impl VisitMut for CollectLifetimes { - fn visit_receiver_mut(&mut self, arg: &mut Receiver) { - if let Some((_, lifetime)) = &mut arg.reference { - self.visit_opt_lifetime(lifetime); - } - } - - fn visit_type_reference_mut(&mut self, ty: &mut TypeReference) { - self.visit_opt_lifetime(&mut ty.lifetime); - visit_mut::visit_type_reference_mut(self, ty); - } - - fn visit_generic_argument_mut(&mut self, gen: &mut GenericArgument) { - if let GenericArgument::Lifetime(lifetime) = gen { - self.visit_lifetime(lifetime); - } - visit_mut::visit_generic_argument_mut(self, gen); - } -} - -fn positional_arg(i: usize, pat: &Pat) -> Ident { - format_ident!("__arg{}", i, span = pat.span()) -} - -fn mut_pat(pat: &mut Pat) -> Option { - let mut visitor = HasMutPat(None); - visitor.visit_pat_mut(pat); - visitor.0 -} - -fn has_self_in_sig(sig: &mut Signature) -> bool { - let mut visitor = HasSelf(false); - visitor.visit_signature_mut(sig); - visitor.0 -} - -fn has_self_in_token_stream(tokens: TokenStream) -> bool { - tokens.into_iter().any(|tt| match tt { - TokenTree::Ident(ident) => ident == "Self", - TokenTree::Group(group) => has_self_in_token_stream(group.stream()), - _ => false, - }) -} - -struct HasMutPat(Option); - -impl VisitMut for HasMutPat { - fn visit_pat_ident_mut(&mut self, i: &mut PatIdent) { - if let Some(m) = i.mutability { - self.0 = Some(m); - } else { - visit_mut::visit_pat_ident_mut(self, i); - } - } -} - -struct HasSelf(bool); - -impl VisitMut for HasSelf { - fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) { - self.0 |= expr.path.segments[0].ident == "Self"; - visit_mut::visit_expr_path_mut(self, expr); - } - - fn visit_pat_path_mut(&mut self, pat: &mut PatPath) { - self.0 |= pat.path.segments[0].ident == "Self"; - visit_mut::visit_pat_path_mut(self, pat); - } - - fn visit_type_path_mut(&mut self, ty: &mut TypePath) { - self.0 |= ty.path.segments[0].ident == "Self"; - visit_mut::visit_type_path_mut(self, ty); - } - - fn visit_receiver_mut(&mut self, _arg: &mut Receiver) { - self.0 = true; - } - - fn visit_item_mut(&mut self, _: &mut Item) { - // Do not recurse into nested items. - } - - fn visit_macro_mut(&mut self, mac: &mut Macro) { - if !contains_fn(mac.tokens.clone()) { - self.0 |= has_self_in_token_stream(mac.tokens.clone()); - } - } -} - -fn contains_fn(tokens: TokenStream) -> bool { - tokens.into_iter().any(|tt| match tt { - TokenTree::Ident(ident) => ident == "fn", - TokenTree::Group(group) => contains_fn(group.stream()), - _ => false, - }) -} - -fn where_clause_or_default(clause: &mut Option) -> &mut WhereClause { - clause.get_or_insert_with(|| WhereClause { - where_token: Default::default(), - predicates: Punctuated::new(), - }) -} - enum AsyncTraitKind<'a> { // old construction. Contains the function Function(&'a ItemFn), diff --git a/tests/ui/ok/async-in-trait.rs b/tests/ui/ok/async-in-trait.rs new file mode 100644 index 0000000..14cd414 --- /dev/null +++ b/tests/ui/ok/async-in-trait.rs @@ -0,0 +1,18 @@ +#![feature(async_fn_in_trait)] +#![allow(unused_mut)] + +trait MyTrait { + async fn work(&self) -> usize; +} + +struct MyStruct; + +impl MyTrait for MyStruct { + #[logcall::logcall("debug")] + #[logcall::logcall("debug")] + async fn work(&self) -> usize { + 1 + } +} + +fn main() {} diff --git a/tests/ui/ok/async-mut.rs b/tests/ui/ok/async-mut.rs index 91002fb..763ac4f 100644 --- a/tests/ui/ok/async-mut.rs +++ b/tests/ui/ok/async-mut.rs @@ -1,3 +1,5 @@ +#[allow(unused_mut)] + #[logcall::logcall("warn")] async fn f(mut a: u32) -> u32 { a