use convert_case::{Case, Casing}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::parse::{Parse, ParseStream}; use syn::{bracketed, parse2, Attribute, Ident, ItemFn, Token}; fn extract_docs(a: &[Attribute]) -> impl Iterator { a.iter().filter(|a| a.path().is_ident("doc")) } #[derive(Debug)] pub enum CommandMacroError { InvalidCliArgumentError(InvalidCliArgumentError), InvalidSubcommandFunctionError(InvalidSubcommandFunctionError), } /// don't use this directly, use apps/macros instead pub fn command( _attr: TokenStream, item: TokenStream, ) -> Result { let input = parse2::(item).map_err(|err| CommandMacroError::InvalidSubcommandFunctionError(InvalidSubcommandFunctionError { message: "Failed to parse function", err, }), )?; let fn_name = &input.sig.ident; let args = &input.sig.inputs; let (fields, arg_names): (Vec<_>, Vec<_>) = args .iter() .map(|arg| match arg { syn::FnArg::Typed(syn::PatType { pat, ty, .. }) => { Ok((quote! { #[arg(long)] #pat: #ty }, quote! { #pat })) } _ => Err(CommandMacroError::InvalidCliArgumentError(InvalidCliArgumentError(format!( "Invalid cli argument: {}", arg.to_token_stream() )))), }) .collect::, _>>()? .into_iter() .unzip(); let docs = extract_docs(&input.attrs); let expanded = quote! { mod #fn_name { use super::#fn_name; use clap::{command, Parser}; #[derive(Parser)] #[command(version, about, long_about = None)] #(#docs)* pub struct Args { #(#fields),* } pub fn run(Args { #(#arg_names),* }: Args) { let result = #fn_name(#(#arg_names),*); println!("{}", result); } } #input }; Ok(expanded) } #[derive(Debug)] pub enum SubcommandsMacroError { InvalidIdentifierError(InvalidIdentifierError), InvalidIdentifierListError(InvalidIdentifierListError), } /// don't use this directly, use apps/macros instead pub fn subcommands( item: TokenStream, ) -> Result { let MergeSubcommandsInput { sub_doc, cli_ident, subcommands, } = parse2::(item).map_err(|err| SubcommandsMacroError::InvalidIdentifierListError(InvalidIdentifierListError { message: "subcommands only accepts lists of identifiers", err, }))?; let match_arms = subcommands.iter().map(|sc| { let ident = &sc.ident; let cmd_name = get_command_name(ident); quote! { Subcommands::#cmd_name(args) => #ident::run(args) } }); let command_enum_fields = subcommands.iter().map(|sc| { let docs = extract_docs(&sc.attrs); let ident = &sc.ident; let cmd_name = get_command_name(ident); quote! { #(#docs)* #cmd_name(#ident::Args) } }); let idents_tokens = subcommands.iter().map(|sc| sc.ident.to_token_stream()); let sub_doc_mod = sub_doc.clone(); let expanded = quote! { #(#sub_doc_mod)* mod #cli_ident { use super::{#(#idents_tokens),*}; use clap::{command, Parser, Subcommand}; #[derive(Subcommand)] #(#sub_doc)* pub enum Subcommands { #(#command_enum_fields),* } #[derive(Parser)] #[command(version, about, long_about = None)] pub struct Args { #[command(subcommand)] command: Subcommands, } pub fn run(Args { command }: Args) { match command { #(#match_arms),* }; } } }; Ok(expanded) } #[allow(dead_code)] struct Subcommand { attrs: Vec, ident: Ident, } struct MergeSubcommandsInput { sub_doc: Vec, cli_ident: Ident, subcommands: Vec, } impl Parse for Subcommand { fn parse(input: ParseStream) -> syn::Result { let attrs = input.call(Attribute::parse_outer)?; let ident: Ident = input.parse()?; Ok(Subcommand { attrs, ident }) } } impl Parse for MergeSubcommandsInput { fn parse(input: ParseStream) -> syn::Result { // parse any doc comments attached to the cli_ident let sub_doc = Attribute::parse_outer(input) .unwrap_or_default() .into_iter() .filter(|attr| attr.path().is_ident("doc")) .collect::>(); let cli_ident: Ident = input.parse()?; input.parse::()?; let content; bracketed!(content in input); let subcommands: syn::punctuated::Punctuated = content.parse_terminated(Subcommand::parse, Token![,])?; Ok(MergeSubcommandsInput { sub_doc, cli_ident, subcommands: subcommands.into_iter().collect(), }) } } fn get_command_name(func_name: &Ident) -> Ident { Ident::new( &func_name.to_string().to_case(Case::UpperCamel), func_name.span(), ) } #[allow(dead_code)] #[derive(Debug, Clone, Copy)] pub struct InvalidIdentifierError(&'static str); #[allow(dead_code)] #[derive(Debug, Clone)] pub struct InvalidIdentifierListError { message: &'static str, err: syn::Error, } #[allow(dead_code)] #[derive(Debug, Clone)] pub struct InvalidSubcommandFunctionError { message: &'static str, err: syn::Error, } #[allow(dead_code)] #[derive(Debug, Clone)] pub struct InvalidCliArgumentError(String);