diff --git a/proc_macros/proc_macros.rs b/proc_macros/proc_macros.rs index 4ce8ae7..e0c64ff 100644 --- a/proc_macros/proc_macros.rs +++ b/proc_macros/proc_macros.rs @@ -261,6 +261,7 @@ struct FilterHttpMethods { pub cors_methods: Option>, pub expose_headers: Option>, pub cors_allow_headers: Option>, + pub typ_def: Option, } impl Parse for FilterHttpMethods { @@ -285,6 +286,7 @@ impl Parse for FilterHttpMethods { let mut cors_methods = None; let mut expose_headers = None; let mut cors_allow_headers = None; + let mut typ_def = None; loop { if input.cursor().eof() { break; @@ -330,6 +332,9 @@ impl Parse for FilterHttpMethods { let m: HTTPHeader = content.parse()?; cors_allow_headers.as_mut().unwrap().push(m.header); } + } else if method.to_string() == "typ_def" { + token::Eq::parse(input)?; + typ_def.replace(input.parse()?); } else { methods.push(method); } @@ -343,6 +348,7 @@ impl Parse for FilterHttpMethods { cors_methods, expose_headers, cors_allow_headers, + typ_def, }) } } @@ -361,7 +367,12 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { cors_methods, expose_headers, cors_allow_headers, + typ_def, } = parse_macro_input!(item as FilterHttpMethods); + let typ_def = match typ_def { + Some(t) => Some(quote!(::<#t>)), + None => None, + }; let mut header_value = Vec::new(); let mut streams = Vec::new(); let mut enable_options = false; @@ -429,7 +440,7 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { .header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header); #expose_headers #cors_allow_headers - return Ok(builder.status(200).header("Allow", #allow_header).body(#typ)?); + return Ok(builder.status(200).header("Allow", #allow_header).body #typ_def(#typ)?); } crate::server::cors::CorsResult::AllowedAll => { let builder = builder @@ -437,15 +448,15 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { .header(hyper::header::ACCESS_CONTROL_ALLOW_METHODS, #cors_methods_header); #expose_headers #cors_allow_headers - return Ok(builder.status(200).header("Allow", #allow_header).body(#typ)?); + return Ok(builder.status(200).header("Allow", #allow_header).body #typ_def(#typ)?); } _ => { - return Ok(builder.status(400).header("Allow", #allow_header).body(#typ)?); + return Ok(builder.status(400).header("Allow", #allow_header).body #typ_def(#typ)?); } } } None => { - return Ok(builder.status(200).header("Allow", #allow_header).body(#typ)?); + return Ok(builder.status(200).header("Allow", #allow_header).body #typ_def(#typ)?); } } })); @@ -509,7 +520,7 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { #cors_allow_headers } _ => { - return Ok(builder.status(403).body(#typ)?); + return Ok(builder.status(403).body #typ_def(#typ)?); } } } @@ -523,7 +534,7 @@ pub fn filter_http_methods(item: TokenStream) -> TokenStream { match #req.method() { #(#streams)* _ => { - return Ok(hyper::Response::builder().status(405).header("Allow", #allow_header).body(#typ)?) + return Ok(hyper::Response::builder().status(405).header("Allow", #allow_header).body #typ_def(#typ)?) } } #post_stream @@ -889,3 +900,47 @@ pub fn call_parent_data_source_fun(item: TokenStream) -> TokenStream { ); stream.into() } + +struct HttpError { + pub code: Option, + pub expr: Expr, +} + +impl Parse for HttpError { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut code = None; + match input.parse::() { + Ok(c) => { + code.replace(c); + match input.parse::() { + Ok(_) => {} + Err(_) => { + panic!("Expr not found"); + } + }; + } + Err(_) => {} + } + let expr = input.parse()?; + Ok(Self { code, expr }) + } +} + +#[proc_macro] +pub fn http_error(item: TokenStream) -> TokenStream { + let HttpError { code, expr } = parse_macro_input!(item as HttpError); + let code = match code { + Some(code) => code, + None => LitInt::new("400", Span::call_site()), + }; + let stream = quote!( + match (#expr) { + Ok(re) => re, + Err(e) => { + builder = builder.status(#code); + return Ok(builder.body::>>(Box::pin(Body::from(format!("{}", e))))?); + } + } + ); + stream.into() +} diff --git a/src/server/auth/pubkey.rs b/src/server/auth/pubkey.rs index 72521b7..4dbb1d1 100644 --- a/src/server/auth/pubkey.rs +++ b/src/server/auth/pubkey.rs @@ -107,7 +107,7 @@ impl AuthPubkeyRoute { } } -impl MatchRoute for AuthPubkeyRoute { +impl MatchRoute>> for AuthPubkeyRoute { fn match_route( &self, ctx: &Arc, diff --git a/src/server/auth/status.rs b/src/server/auth/status.rs index c50e386..f592f0b 100644 --- a/src/server/auth/status.rs +++ b/src/server/auth/status.rs @@ -43,7 +43,7 @@ impl AuthStatusRoute { } } -impl MatchRoute for AuthStatusRoute { +impl MatchRoute>> for AuthStatusRoute { fn match_route( &self, ctx: &Arc, diff --git a/src/server/auth/token.rs b/src/server/auth/token.rs index 0a72730..46781ef 100644 --- a/src/server/auth/token.rs +++ b/src/server/auth/token.rs @@ -216,7 +216,7 @@ impl AuthTokenRoute { } } -impl MatchRoute for AuthTokenRoute { +impl MatchRoute>> for AuthTokenRoute { fn match_route( &self, ctx: &Arc, diff --git a/src/server/auth/user.rs b/src/server/auth/user.rs index 4ff311b..b43df4e 100644 --- a/src/server/auth/user.rs +++ b/src/server/auth/user.rs @@ -507,7 +507,7 @@ impl AuthUserRoute { } } -impl MatchRoute for AuthUserRoute { +impl MatchRoute>> for AuthUserRoute { fn match_route( &self, ctx: &Arc, diff --git a/src/server/context.rs b/src/server/context.rs index d8addf6..9a4afed 100644 --- a/src/server/context.rs +++ b/src/server/context.rs @@ -136,4 +136,17 @@ impl ServerContext { .await? .ok_or(gettext("No corresponding user was found."))?) } + + pub async fn verify( + &self, + req: &Request, + params: &RequestParams, + ) -> Result, PixivDownloaderError> { + let root_user = self.db.get_user(0).await?; + if root_user.is_some() { + Ok(Some(self.verify_token(req, params).await?)) + } else { + Ok(None) + } + } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 46ecebc..af4974b 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,6 +7,8 @@ pub mod cors; pub mod params; /// Predefined includes pub mod preclude; +/// Routes about proxy +pub mod proxy; /// Base result type for JSON response pub mod result; /// Routes diff --git a/src/server/preclude.rs b/src/server/preclude.rs index f8a3c11..36568fa 100644 --- a/src/server/preclude.rs +++ b/src/server/preclude.rs @@ -4,11 +4,15 @@ pub use super::result::JSONResult; pub use super::route::ResponseForType; pub use super::traits::{GetRequestParams, MatchRoute, ResponseFor, ResponseJsonFor}; pub use crate::error::PixivDownloaderError; +pub use hyper::body::HttpBody; pub use hyper::Body; pub use hyper::Method; pub use hyper::Request; pub use hyper::Response; pub use json::JsonValue; -pub use proc_macros::filter_http_methods; +pub use proc_macros::{filter_http_methods, http_error}; pub use regex::Regex; +pub use std::pin::Pin; pub use std::sync::Arc; + +pub type HttpBodyType = dyn HttpBody + Send; diff --git a/src/server/proxy/mod.rs b/src/server/proxy/mod.rs new file mode 100644 index 0000000..0457c91 --- /dev/null +++ b/src/server/proxy/mod.rs @@ -0,0 +1,3 @@ +pub mod pixiv; + +pub use pixiv::ProxyPixivRoute; diff --git a/src/server/proxy/pixiv.rs b/src/server/proxy/pixiv.rs new file mode 100644 index 0000000..073fea0 --- /dev/null +++ b/src/server/proxy/pixiv.rs @@ -0,0 +1,66 @@ +use super::super::preclude::*; +use http::Uri; + +pub struct ProxyPixivContext { + ctx: Arc, +} + +impl ProxyPixivContext { + pub fn new(ctx: Arc) -> Self { + Self { ctx } + } +} + +#[async_trait] +impl ResponseFor>> for ProxyPixivContext { + async fn response( + &self, + mut req: Request, + ) -> Result>>, PixivDownloaderError> { + filter_http_methods!( + req, + Box::pin(Body::empty()), + true, + self.ctx, + allow_headers = [X_SIGN, X_TOKEN_ID], + typ_def=Pin>, + GET, + OPTIONS + ); + let params = req.get_params().await?; + let _ = http_error!(401, self.ctx.verify(&req, ¶ms).await); + let url = http_error!(params.get("url").ok_or("Url is required.")); + let uri = http_error!(Uri::try_from(url)); + let host = uri.host().ok_or("Host is needed.")?; + if !host.ends_with(".pximg.net") { + http_error!(403, Err("Host is not allowed.")); + } + return Ok(builder.body::>>(Box::pin(Body::empty()))?); + } +} + +pub struct ProxyPixivRoute { + regex: Regex, +} + +impl ProxyPixivRoute { + pub fn new() -> Self { + Self { + regex: Regex::new(r"^(/+api)?/+proxy/+pixiv(/.*)?$").unwrap(), + } + } +} + +impl MatchRoute>> for ProxyPixivRoute { + fn match_route( + &self, + ctx: &Arc, + req: &Request, + ) -> Option> { + if self.regex.is_match(req.uri().path()) { + Some(Box::new(ProxyPixivContext::new(Arc::clone(ctx)))) + } else { + None + } + } +} diff --git a/src/server/route.rs b/src/server/route.rs index 897ab7d..5ae2bbb 100644 --- a/src/server/route.rs +++ b/src/server/route.rs @@ -1,14 +1,17 @@ use super::auth::*; use super::context::ServerContext; +use super::preclude::HttpBodyType; +use super::proxy::*; use super::traits::MatchRoute; use super::traits::ResponseFor; use super::version::VersionRoute; use hyper::Body; use hyper::Request; +use std::pin::Pin; use std::sync::Arc; -pub type RouteType = dyn MatchRoute + Send + Sync; -pub type ResponseForType = dyn ResponseFor + Send + Sync; +pub type RouteType = dyn MatchRoute>> + Send + Sync; +pub type ResponseForType = dyn ResponseFor>> + Send + Sync; pub struct ServerRoutes { routes: Vec>, @@ -22,6 +25,7 @@ impl ServerRoutes { routes.push(Box::new(AuthUserRoute::new())); routes.push(Box::new(AuthPubkeyRoute::new())); routes.push(Box::new(AuthTokenRoute::new())); + routes.push(Box::new(ProxyPixivRoute::new())); Self { routes } } diff --git a/src/server/service.rs b/src/server/service.rs index adec2ed..9eafe00 100644 --- a/src/server/service.rs +++ b/src/server/service.rs @@ -1,9 +1,9 @@ use super::context::ServerContext; +use super::preclude::*; use super::route::ServerRoutes; use hyper::server::conn::AddrIncoming; use hyper::server::Server; use hyper::service::Service; -use hyper::Body; use hyper::Request; use hyper::Response; use std::future::Future; @@ -25,7 +25,7 @@ impl PixivDownloaderSvc { } impl Service> for PixivDownloaderSvc { - type Response = Response; + type Response = Response>>; type Error = hyper::Error; type Future = Pin> + Send>>; @@ -43,7 +43,9 @@ impl Service> for PixivDownloaderSvc { println!("{}", e); Ok(Response::builder() .status(500) - .body(Body::from("Internal server error")) + .body::>>(Box::pin(Body::from( + "Internal server error", + ))) .unwrap()) } } @@ -51,7 +53,7 @@ impl Service> for PixivDownloaderSvc { None => Box::pin(async { Ok(Response::builder() .status(404) - .body(Body::from("404 Not Found")) + .body::>>(Box::pin(Body::from("404 Not Found"))) .unwrap()) }), } diff --git a/src/server/traits.rs b/src/server/traits.rs index e1f8dae..2a1d81d 100644 --- a/src/server/traits.rs +++ b/src/server/traits.rs @@ -1,10 +1,12 @@ use super::context::ServerContext; use super::params::RequestParams; +use super::preclude::HttpBodyType; use crate::error::PixivDownloaderError; use hyper::Body; use hyper::Request; use hyper::Response; use json::JsonValue; +use std::pin::Pin; use std::sync::Arc; pub trait MatchRoute { @@ -36,18 +38,24 @@ pub trait GetRequestParams { } #[async_trait] -impl ResponseFor for U +impl ResponseFor>> for U where U: ResponseJsonFor + Sync + Send, T: Sync + Send + 'static, { - async fn response(&self, req: Request) -> Result, PixivDownloaderError> { + async fn response( + &self, + req: Request, + ) -> Result>>, PixivDownloaderError> { let re = self.response_json(req).await?; let (mut parts, body) = re.into_parts(); parts.headers.insert( hyper::header::CONTENT_TYPE, "application/json; charset=utf-8".parse()?, ); - Ok(Response::from_parts(parts, Body::from(body.to_string()))) + Ok(Response::from_parts( + parts, + Box::pin(Body::from(body.to_string())), + )) } } diff --git a/src/server/unittest/mod.rs b/src/server/unittest/mod.rs index d5f248d..4298986 100644 --- a/src/server/unittest/mod.rs +++ b/src/server/unittest/mod.rs @@ -3,6 +3,7 @@ mod version; use super::context::ServerContext; use super::cors::CorsContext; +use super::preclude::HttpBodyType; use super::route::ServerRoutes; use crate::db::{open_and_init_database, PixivDownloaderDbConfig}; use crate::error::PixivDownloaderError; @@ -14,6 +15,7 @@ use std::collections::BTreeMap; use std::fs::{create_dir, remove_file}; #[cfg(test)] use std::path::Path; +use std::pin::Pin; use std::sync::Arc; pub struct UnitTestContext { @@ -44,7 +46,7 @@ impl UnitTestContext { pub async fn request( &self, req: Request, - ) -> Result>, PixivDownloaderError> { + ) -> Result>>>, PixivDownloaderError> { Ok(match self.routes.match_route(&req, &self.ctx) { Some(r) => Some(r.response(req).await?), None => None, diff --git a/src/server/version.rs b/src/server/version.rs index b591adf..76340b2 100644 --- a/src/server/version.rs +++ b/src/server/version.rs @@ -44,7 +44,7 @@ impl VersionRoute { } } -impl MatchRoute for VersionRoute { +impl MatchRoute>> for VersionRoute { fn match_route( &self, ctx: &Arc,