From db8f30c9782998868caa8959492ea2c201a924bc Mon Sep 17 00:00:00 2001 From: lifegpc Date: Sat, 9 Jul 2022 08:50:16 +0000 Subject: [PATCH] Update --- proc_macros/proc_macros.rs | 102 +++++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/proc_macros/proc_macros.rs b/proc_macros/proc_macros.rs index c928243..0892817 100644 --- a/proc_macros/proc_macros.rs +++ b/proc_macros/proc_macros.rs @@ -1,5 +1,6 @@ use proc_macro::TokenStream; use quote::quote; +use syn::bracketed; use syn::parse::Parse; use syn::parse_macro_input; use syn::token; @@ -211,12 +212,36 @@ pub fn fanbox_api_quick_test(item: TokenStream) -> TokenStream { stream.into() } +struct HTTPHeader { + pub header: String, +} + +impl Parse for HTTPHeader { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut header = String::new(); + let ident = Ident::parse(input)?; + header += ident.to_string().as_str(); + loop { + if input.cursor().eof() { + break; + } + token::Sub::parse(input)?; + let ident = Ident::parse(input)?; + header += "-"; + header += ident.to_string().as_str(); + } + return Ok(Self { header }); + } +} + struct FilterHttpMethods { pub req: Ident, pub typ: Expr, pub handle_options: LitBool, pub ctx: Option, pub methods: Vec, + pub cors_methods: Option>, + pub expose_headers: Option>, } impl Parse for FilterHttpMethods { @@ -238,13 +263,41 @@ impl Parse for FilterHttpMethods { } else { None }; + let mut cors_methods = None; + let mut expose_headers = None; loop { if input.cursor().eof() { break; } token::Comma::parse(input)?; let method = Ident::parse(input)?; - methods.push(method); + if method.to_string() == "cors_methods" { + cors_methods.replace(Vec::new()); + token::Eq::parse(input)?; + let content; + bracketed!(content in input); + let first: Ident = content.parse()?; + cors_methods.as_mut().unwrap().push(first); + while !content.is_empty() { + let _: token::Comma = content.parse()?; + let m: Ident = content.parse()?; + cors_methods.as_mut().unwrap().push(m); + } + } else if method.to_string() == "expose_headers" { + expose_headers.replace(Vec::new()); + token::Eq::parse(input)?; + let content; + bracketed!(content in input); + let first: HTTPHeader = content.parse()?; + expose_headers.as_mut().unwrap().push(first.header); + while !content.is_empty() { + let _: token::Comma = content.parse()?; + let m: HTTPHeader = content.parse()?; + expose_headers.as_mut().unwrap().push(m.header); + } + } else { + methods.push(method); + } } Ok(Self { req, @@ -252,6 +305,8 @@ impl Parse for FilterHttpMethods { handle_options, ctx, methods, + cors_methods, + expose_headers, }) } } @@ -267,6 +322,8 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { handle_options, ctx, methods, + cors_methods, + expose_headers, } = parse_macro_input!(item as FilterHttpMethods); let mut header_value = Vec::new(); let mut streams = Vec::new(); @@ -280,8 +337,29 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { } } let allow_header = header_value.join(", "); + let cors_methods_header = match cors_methods { + Some(methods) => { + let mut v = Vec::new(); + for method in methods { + v.push(method.to_string()); + } + v.join(", ") + } + None => allow_header.clone(), + }; let allow_header = LitStr::new(allow_header.as_str(), req.span()); + let cors_methods_header = LitStr::new(cors_methods_header.as_str(), req.span()); if enable_options { + let expose_headers = match &expose_headers { + Some(h) => { + let headers = h.join(", "); + let headers = LitStr::new(headers.as_str(), req.span()); + quote!( + let builder = builder.header(hyper::header::ACCESS_CONTROL_EXPOSE_HEADERS, #headers); + ) + } + None => quote!(), + }; streams.push(quote!(&hyper::Method::OPTIONS => { let builder = hyper::Response::builder(); let headers = #req.headers(); @@ -296,11 +374,17 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { Some(origin) => { match #ctx.cors.matches(origin.as_str()) { crate::server::cors::CorsResult::Allowed => { - let builder = builder.header("Access-Control-Allow-Origin", origin.as_str()); + let builder = builder + .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.as_str()) + .header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header); + #expose_headers return Ok(builder.status(200).header("Allow", #allow_header).body(#typ).unwrap()); } crate::server::cors::CorsResult::AllowedAll => { - let builder = builder.header("Access-Control-Allow-Origin", "*"); + let builder = builder + .header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*") + .header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header); + #expose_headers return Ok(builder.status(200).header("Allow", #allow_header).body(#typ).unwrap()); } _ => { @@ -315,6 +399,16 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { })); } let post_stream = if enable_options { + let expose_headers = match expose_headers { + Some(h) => { + let headers = h.join(", "); + let headers = LitStr::new(headers.as_str(), req.span()); + quote!( + builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_EXPOSE_HEADERS, #headers.parse().unwrap()); + ) + } + None => quote!(), + }; quote!( let mut builder = hyper::Response::builder(); let headers = #req.headers(); @@ -330,9 +424,11 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { match #ctx.cors.matches(origin.as_str()) { crate::server::cors::CorsResult::Allowed => { builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap()); + #expose_headers } crate::server::cors::CorsResult::AllowedAll => { builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap()); + #expose_headers } _ => { return Ok(builder.status(403).body(#typ).unwrap());