This commit is contained in:
2022-07-09 08:50:16 +00:00
committed by GitHub
parent c0d61e8a5b
commit db8f30c978

View File

@@ -1,5 +1,6 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::bracketed;
use syn::parse::Parse;
use syn::parse_macro_input;
use syn::token;
@@ -211,12 +212,36 @@ pub fn fanbox_api_quick_test(item: TokenStream) -> TokenStream {
stream.into()
}
struct HTTPHeader {
pub header: String,
}
impl Parse for HTTPHeader {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut header = String::new();
let ident = Ident::parse(input)?;
header += ident.to_string().as_str();
loop {
if input.cursor().eof() {
break;
}
token::Sub::parse(input)?;
let ident = Ident::parse(input)?;
header += "-";
header += ident.to_string().as_str();
}
return Ok(Self { header });
}
}
struct FilterHttpMethods {
pub req: Ident,
pub typ: Expr,
pub handle_options: LitBool,
pub ctx: Option<Expr>,
pub methods: Vec<Ident>,
pub cors_methods: Option<Vec<Ident>>,
pub expose_headers: Option<Vec<String>>,
}
impl Parse for FilterHttpMethods {
@@ -238,13 +263,41 @@ impl Parse for FilterHttpMethods {
} else {
None
};
let mut cors_methods = None;
let mut expose_headers = None;
loop {
if input.cursor().eof() {
break;
}
token::Comma::parse(input)?;
let method = Ident::parse(input)?;
methods.push(method);
if method.to_string() == "cors_methods" {
cors_methods.replace(Vec::new());
token::Eq::parse(input)?;
let content;
bracketed!(content in input);
let first: Ident = content.parse()?;
cors_methods.as_mut().unwrap().push(first);
while !content.is_empty() {
let _: token::Comma = content.parse()?;
let m: Ident = content.parse()?;
cors_methods.as_mut().unwrap().push(m);
}
} else if method.to_string() == "expose_headers" {
expose_headers.replace(Vec::new());
token::Eq::parse(input)?;
let content;
bracketed!(content in input);
let first: HTTPHeader = content.parse()?;
expose_headers.as_mut().unwrap().push(first.header);
while !content.is_empty() {
let _: token::Comma = content.parse()?;
let m: HTTPHeader = content.parse()?;
expose_headers.as_mut().unwrap().push(m.header);
}
} else {
methods.push(method);
}
}
Ok(Self {
req,
@@ -252,6 +305,8 @@ impl Parse for FilterHttpMethods {
handle_options,
ctx,
methods,
cors_methods,
expose_headers,
})
}
}
@@ -267,6 +322,8 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream {
handle_options,
ctx,
methods,
cors_methods,
expose_headers,
} = parse_macro_input!(item as FilterHttpMethods);
let mut header_value = Vec::new();
let mut streams = Vec::new();
@@ -280,8 +337,29 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream {
}
}
let allow_header = header_value.join(", ");
let cors_methods_header = match cors_methods {
Some(methods) => {
let mut v = Vec::new();
for method in methods {
v.push(method.to_string());
}
v.join(", ")
}
None => allow_header.clone(),
};
let allow_header = LitStr::new(allow_header.as_str(), req.span());
let cors_methods_header = LitStr::new(cors_methods_header.as_str(), req.span());
if enable_options {
let expose_headers = match &expose_headers {
Some(h) => {
let headers = h.join(", ");
let headers = LitStr::new(headers.as_str(), req.span());
quote!(
let builder = builder.header(hyper::header::ACCESS_CONTROL_EXPOSE_HEADERS, #headers);
)
}
None => quote!(),
};
streams.push(quote!(&hyper::Method::OPTIONS => {
let builder = hyper::Response::builder();
let headers = #req.headers();
@@ -296,11 +374,17 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream {
Some(origin) => {
match #ctx.cors.matches(origin.as_str()) {
crate::server::cors::CorsResult::Allowed => {
let builder = builder.header("Access-Control-Allow-Origin", origin.as_str());
let builder = builder
.header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.as_str())
.header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header);
#expose_headers
return Ok(builder.status(200).header("Allow", #allow_header).body(#typ).unwrap());
}
crate::server::cors::CorsResult::AllowedAll => {
let builder = builder.header("Access-Control-Allow-Origin", "*");
let builder = builder
.header(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*")
.header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header);
#expose_headers
return Ok(builder.status(200).header("Allow", #allow_header).body(#typ).unwrap());
}
_ => {
@@ -315,6 +399,16 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream {
}));
}
let post_stream = if enable_options {
let expose_headers = match expose_headers {
Some(h) => {
let headers = h.join(", ");
let headers = LitStr::new(headers.as_str(), req.span());
quote!(
builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_EXPOSE_HEADERS, #headers.parse().unwrap());
)
}
None => quote!(),
};
quote!(
let mut builder = hyper::Response::builder();
let headers = #req.headers();
@@ -330,9 +424,11 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream {
match #ctx.cors.matches(origin.as_str()) {
crate::server::cors::CorsResult::Allowed => {
builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.parse().unwrap());
#expose_headers
}
crate::server::cors::CorsResult::AllowedAll => {
builder.headers_mut().unwrap().insert(hyper::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*".parse().unwrap());
#expose_headers
}
_ => {
return Ok(builder.status(403).body(#typ).unwrap());