diff --git a/Cargo.lock b/Cargo.lock index 734bb21..865a578 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -449,6 +449,17 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +[[package]] +name = "glob-macro" +version = "0.1.0" +dependencies = [ + "glob", + "proc-macro2", + "quote", + "syn", + "syn-mid", +] + [[package]] name = "heck" version = "0.5.0" @@ -1406,6 +1417,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "syn-mid" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5dc35bb08dd1ca3dfb09dce91fd2d13294d6711c88897d9a9d60acf39bce049" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sync_wrapper" version = "1.0.1" @@ -1718,6 +1740,7 @@ dependencies = [ "clap", "fs-err", "glob", + "glob-macro", "insta", "miette", "pest", diff --git a/Cargo.toml b/Cargo.toml index 34fe282..6a3532a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ insta = { version = "1.30.0" } similar = { version = "2.2.1" } glob = "0.3.1" stacker = "0.1.15" +glob-macro = { path = "glob-macro" } # Spend more time on initial compilation in exchange for faster runs [profile.dev.package.insta] @@ -53,6 +54,9 @@ opt-level = 3 inherits = "release" lto = "thin" +[workspace] +members = ["glob-macro"] + # Config for 'cargo dist' [workspace.metadata.dist] # The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax) diff --git a/glob-macro/Cargo.toml b/glob-macro/Cargo.toml new file mode 100644 index 0000000..f9f6ce1 --- /dev/null +++ b/glob-macro/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "glob-macro" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +glob = "0.3.1" +proc-macro2 = "1.0.86" +quote = "1.0.37" +syn = "2.0.77" +syn-mid = { version = "0.6.0", features = ["clone-impls"] } diff --git a/glob-macro/README.md b/glob-macro/README.md new file mode 100644 index 0000000..9836473 --- /dev/null +++ b/glob-macro/README.md @@ -0,0 +1,15 @@ +# Glob Macro + +A small and simple crate that lets one apply a [`glob`](https://docs.rs/glob/latest/glob/) in a macro position. + +## Usage + +The main intended use case for this is to write tests that run over all files in some directory. For example: + +```rs +#[glob("./path/to/**/*.inp")] +#[test] +fn test(path: &Path) { + assert!(path.exists()); +} +``` diff --git a/glob-macro/src/lib.rs b/glob-macro/src/lib.rs new file mode 100644 index 0000000..7427509 --- /dev/null +++ b/glob-macro/src/lib.rs @@ -0,0 +1,120 @@ +use proc_macro::TokenStream; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::Error; +use syn_mid::ItemFn; + +#[proc_macro_attribute] +pub fn glob(args: TokenStream, function: TokenStream) -> TokenStream { + if args.is_empty() { + return Error::new(Span::call_site(), "#[glob] attribute requires an argument") + .to_compile_error() + .into(); + } + + let glob_path = match syn::parse(args) { + Ok(p) => p, + Err(err) => return err.to_compile_error().into(), + }; + + let function: ItemFn = syn::parse_macro_input!(function); + + if !function + .attrs + .iter() + .any(|attr| attr.path().is_ident("test")) + { + return Error::new( + Span::call_site(), + "#[glob] attribute currently only supports running on #[test] functions", + ) + .to_compile_error() + .into(); + } + + glob2(glob_path, function).into() +} + +fn glob2(glob_path: syn::LitStr, function: ItemFn) -> TokenStream2 { + let paths = match glob::glob(&glob_path.value()) { + Ok(paths) => paths, + Err(err) => { + return Error::new( + glob_path.span(), + format!("#[glob] called with invalid value: {err:?}"), + ) + .to_compile_error(); + } + }; + + let paths = match paths.collect::, _>>() { + Ok(p) => p, + Err(err) => { + return Error::new(glob_path.span(), format!("#[glob] error: {err:?}")) + .to_compile_error(); + } + }; + let counter_width = (paths.len() - 1).to_string().len(); + + let common_ancestor = paths + .iter() + .fold(paths[0].clone(), |ancestor, path| { + ancestor + .components() + .zip(path.components()) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a) + .collect() + }) + .components() + .count(); + + let mut functions = vec![]; + + for (i, path) in paths.iter().enumerate() { + let function_name = syn::Ident::new( + &format!( + "{}__{:0width$}__{}", + function.sig.ident, + i, + path.components() + .skip(common_ancestor) + .map(|p| p.as_os_str().to_string_lossy()) + .collect::>() + .join("/") + .replace("/", "__") + .replace(".", "_") + .replace("-", "_") + .replace(" ", "_"), + width = counter_width + ), + function.sig.ident.span(), + ); + + let path_buf = syn::LitStr::new(&path.to_string_lossy(), Span::call_site()); + + let mut inner_function = function.clone(); + inner_function.sig.ident = syn::Ident::new("test", function.sig.ident.span()); + inner_function.attrs = inner_function + .attrs + .into_iter() + .filter(|attr| !attr.path().is_ident("test")) + .collect(); + + let function = quote! { + #[test] + #[allow(non_snake_case)] + fn #function_name() { + #inner_function + let path: &str = #path_buf; + test(std::path::Path::new(&path)); + } + }; + + functions.push(function); + } + + quote! { + #(#functions)* + } +} diff --git a/tests/snapshot-examples.rs b/tests/snapshot-examples.rs index 41916cc..79932c7 100644 --- a/tests/snapshot-examples.rs +++ b/tests/snapshot-examples.rs @@ -61,22 +61,20 @@ fn pagetable_rs_unchanged() { check_snapshot(include_str!("../examples/pagetable.rs")); } +#[glob_macro::glob("./examples/verus-snapshot/**/*.rs")] #[test] -fn verus_snapshot_unchanged() { - let rustfmt_toml = - std::fs::read_to_string("./examples/verus-snapshot/source/rustfmt.toml").unwrap(); - for path in glob::glob("./examples/verus-snapshot/**/*.rs").unwrap() { - let path = path.unwrap(); - println!("Checking snapshot for {:?}", path); - check_snapshot_with_config( - &std::fs::read_to_string(path).unwrap(), - verusfmt::RunOptions { - file_name: None, - run_rustfmt: true, - rustfmt_config: verusfmt::RustFmtConfig { - rustfmt_toml: Some(rustfmt_toml.clone()), - }, +fn verus_snapshot_unchanged(path: &std::path::Path) { + check_snapshot_with_config( + &std::fs::read_to_string(path).unwrap(), + verusfmt::RunOptions { + file_name: None, + run_rustfmt: true, + rustfmt_config: verusfmt::RustFmtConfig { + rustfmt_toml: Some( + std::fs::read_to_string("./examples/verus-snapshot/source/rustfmt.toml") + .unwrap(), + ), }, - ); - } + }, + ); }