mirror of
https://github.com/lifegpc/pixiv_downloader.git
synced 2026-06-06 05:49:01 +08:00
Push API now support update
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -107,9 +107,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0"
|
||||
version = "0.1.74"
|
||||
source = "git+https://github.com/lifegpc/async-trait#6fb1f56b170289e1681e864c1d6e51783783b3c7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@@ -70,6 +70,7 @@ server = ["async-trait", "base64", "db", "hex", "hyper", "multipart", "openssl",
|
||||
ugoira = ["avdict", "bindgen", "cmake", "link-cplusplus"]
|
||||
|
||||
[patch.crates-io]
|
||||
async-trait = { git = "https://github.com/lifegpc/async-trait" }
|
||||
buf_redux = { git = "https://github.com/lifegpc/buf_redux" }
|
||||
openssl = { git = "https://github.com/lifegpc/rust-openssl" }
|
||||
openssl-sys = { git = "https://github.com/lifegpc/rust-openssl" }
|
||||
|
||||
@@ -207,4 +207,11 @@ impl PushTask {
|
||||
ttl: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_need_update(&self) -> bool {
|
||||
let now = Utc::now();
|
||||
let last_updated = self.last_updated;
|
||||
let ttl = self.ttl;
|
||||
now.timestamp() - last_updated.timestamp() > ttl as i64
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,6 +62,11 @@ push_configs TEXT,
|
||||
last_updated DATETIME,
|
||||
ttl INT
|
||||
);";
|
||||
const PUSH_TASK_DATA_TABLE: &'static str = "CREATE TABLE push_task_data (
|
||||
id INT,
|
||||
data TEXT,
|
||||
PRIMARY KEY (id)
|
||||
);";
|
||||
const TAGS_TABLE: &'static str = "CREATE TABLE tags (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT
|
||||
@@ -93,7 +98,7 @@ v3 INT,
|
||||
v4 INT,
|
||||
PRIMARY KEY (id)
|
||||
);";
|
||||
const VERSION: [u8; 4] = [1, 0, 0, 7];
|
||||
const VERSION: [u8; 4] = [1, 0, 0, 8];
|
||||
|
||||
pub struct PixivDownloaderSqlite {
|
||||
db: Mutex<Connection>,
|
||||
@@ -252,6 +257,9 @@ impl PixivDownloaderSqlite {
|
||||
if db_version < [1, 0, 0, 7] {
|
||||
tx.execute(PUSH_TASK_TABLE, [])?;
|
||||
}
|
||||
if db_version < [1, 0, 0, 8] {
|
||||
tx.execute(PUSH_TASK_DATA_TABLE, [])?;
|
||||
}
|
||||
self._write_version(&tx)?;
|
||||
tx.commit()?;
|
||||
}
|
||||
@@ -302,6 +310,9 @@ impl PixivDownloaderSqlite {
|
||||
if !tables.contains_key("config") {
|
||||
t.execute(CONFIG_TABLE, [])?;
|
||||
}
|
||||
if !tables.contains_key("push_task_data") {
|
||||
t.execute(PUSH_TASK_DATA_TABLE, [])?;
|
||||
}
|
||||
t.commit()?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -344,6 +355,28 @@ impl PixivDownloaderSqlite {
|
||||
Ok(tables)
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_all_push_tasks(&self) -> Result<Vec<PushTask>, PixivDownloaderDbError> {
|
||||
let con = self.db.lock().await;
|
||||
let mut stmt = con.prepare("SELECT * FROM push_task;")?;
|
||||
let mut rows = stmt.query([])?;
|
||||
let mut tasks = Vec::new();
|
||||
while let Some(row) = rows.next()? {
|
||||
let config: String = row.get(1)?;
|
||||
let config: PushTaskConfig = serde_json::from_str(&config)?;
|
||||
let push_configs: String = row.get(2)?;
|
||||
let push_configs: Vec<PushConfig> = serde_json::from_str(&push_configs)?;
|
||||
tasks.push(PushTask {
|
||||
id: row.get(0)?,
|
||||
config,
|
||||
push_configs,
|
||||
last_updated: row.get(3)?,
|
||||
ttl: row.get(4)?,
|
||||
});
|
||||
}
|
||||
Ok(tasks)
|
||||
}
|
||||
|
||||
async fn get_config(&self, key: &str) -> Result<Option<String>, SqliteError> {
|
||||
let con = self.db.lock().await;
|
||||
Ok(con
|
||||
@@ -405,6 +438,18 @@ impl PixivDownloaderSqlite {
|
||||
.optional2()
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_push_task_data(&self, id: u64) -> Result<Option<String>, PixivDownloaderDbError> {
|
||||
let con = self.db.lock().await;
|
||||
Ok(con
|
||||
.query_row(
|
||||
"SELECT data FROM push_task_data WHERE id = ?;",
|
||||
[id],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.optional()?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_token(&self, id: u64) -> Result<Option<Token>, SqliteError> {
|
||||
let con = self.db.lock().await;
|
||||
@@ -542,6 +587,19 @@ impl PixivDownloaderSqlite {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
fn _set_push_task_data(
|
||||
ts: &Transaction,
|
||||
id: u64,
|
||||
data: &str,
|
||||
) -> Result<(), PixivDownloaderDbError> {
|
||||
ts.execute(
|
||||
"INSERT OR REPLACE INTO push_task_data (id, data) VALUES (?, ?);",
|
||||
(id, data),
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn _set_user(
|
||||
&self,
|
||||
@@ -746,7 +804,7 @@ impl PixivDownloaderSqlite {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(+Sync)]
|
||||
impl PixivDownloaderDb for PixivDownloaderSqlite {
|
||||
fn new<R: AsRef<PixivDownloaderDbConfig> + ?Sized>(
|
||||
cfg: &R,
|
||||
@@ -896,6 +954,11 @@ impl PixivDownloaderDb for PixivDownloaderSqlite {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_all_push_tasks(&self) -> Result<Vec<PushTask>, PixivDownloaderDbError> {
|
||||
Ok(self.get_all_push_tasks().await?)
|
||||
}
|
||||
|
||||
async fn get_config(&self, key: &str) -> Result<Option<String>, PixivDownloaderDbError> {
|
||||
Ok(self.get_config(key).await?)
|
||||
}
|
||||
@@ -938,6 +1001,11 @@ impl PixivDownloaderDb for PixivDownloaderSqlite {
|
||||
Ok(self.get_push_task(id).await?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_push_task_data(&self, id: u64) -> Result<Option<String>, PixivDownloaderDbError> {
|
||||
Ok(self.get_push_task_data(id).await?)
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn get_token(&self, id: u64) -> Result<Option<Token>, PixivDownloaderDbError> {
|
||||
Ok(self.get_token(id).await?)
|
||||
@@ -1029,6 +1097,17 @@ impl PixivDownloaderDb for PixivDownloaderSqlite {
|
||||
Ok(self.get_push_task(id).await?.expect("Task not found:"))
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn set_push_task_data(&self, id: u64, data: &str) -> Result<(), PixivDownloaderDbError> {
|
||||
{
|
||||
let mut db = self.db.lock().await;
|
||||
let tx = db.transaction()?;
|
||||
Self::_set_push_task_data(&tx, id, data)?;
|
||||
tx.commit()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "server")]
|
||||
async fn update_push_task_last_updated(
|
||||
&self,
|
||||
|
||||
@@ -8,7 +8,7 @@ use super::{Token, User};
|
||||
use chrono::{DateTime, Utc};
|
||||
use flagset::FlagSet;
|
||||
|
||||
#[async_trait]
|
||||
#[async_trait(+Sync)]
|
||||
pub trait PixivDownloaderDb {
|
||||
/// Create a new instance of database
|
||||
/// * `cfg` - The database configuration
|
||||
@@ -108,6 +108,9 @@ pub trait PixivDownloaderDb {
|
||||
id: u64,
|
||||
expired_at: &DateTime<Utc>,
|
||||
) -> Result<(), PixivDownloaderDbError>;
|
||||
#[cfg(feature = "server")]
|
||||
/// Get all push tasks
|
||||
async fn get_all_push_tasks(&self) -> Result<Vec<PushTask>, PixivDownloaderDbError>;
|
||||
/// Get a config from database
|
||||
/// * `key` - The config key
|
||||
async fn get_config(&self, key: &str) -> Result<Option<String>, PixivDownloaderDbError>;
|
||||
@@ -143,6 +146,10 @@ pub trait PixivDownloaderDb {
|
||||
/// * `id` - The task's ID
|
||||
async fn get_push_task(&self, id: u64) -> Result<Option<PushTask>, PixivDownloaderDbError>;
|
||||
#[cfg(feature = "server")]
|
||||
/// Get a push task's data
|
||||
/// * `id` - The task's ID
|
||||
async fn get_push_task_data(&self, id: u64) -> Result<Option<String>, PixivDownloaderDbError>;
|
||||
#[cfg(feature = "server")]
|
||||
/// Get token by ID
|
||||
/// * `id` - The token ID
|
||||
async fn get_token(&self, id: u64) -> Result<Option<Token>, PixivDownloaderDbError>;
|
||||
@@ -204,6 +211,11 @@ pub trait PixivDownloaderDb {
|
||||
is_admin: bool,
|
||||
) -> Result<User, PixivDownloaderDbError>;
|
||||
#[cfg(feature = "server")]
|
||||
/// Set push task's data
|
||||
/// * `id`: The task's ID
|
||||
/// * `data`: The task's data
|
||||
async fn set_push_task_data(&self, id: u64, data: &str) -> Result<(), PixivDownloaderDbError>;
|
||||
#[cfg(feature = "server")]
|
||||
/// Update a push task
|
||||
/// * `id`: The task's ID
|
||||
/// * `config`: The task's config
|
||||
|
||||
@@ -127,7 +127,8 @@ impl PushContext {
|
||||
.ok_or((400, "Missing test_send_mode."))?;
|
||||
let test_send_mode: TestSendMode = serde_json::from_str(test_send_mode)
|
||||
.try_err3(400, "Failed to parse test_send_mode:")?;
|
||||
run_push_task(self.ctx.clone(), &task, Some(&test_send_mode))
|
||||
let task = Arc::new(task);
|
||||
run_push_task(self.ctx.clone(), task, Some(&test_send_mode))
|
||||
.await
|
||||
.try_err3(1, "Failed to test push task:")?;
|
||||
Ok(serde_json::to_value(true).try_err3(500, "Failed to serialize result:")?)
|
||||
|
||||
@@ -4,7 +4,11 @@ pub mod pixiv_send_message;
|
||||
use super::super::preclude::*;
|
||||
use crate::db::push_task::PushTaskPixivAction;
|
||||
use crate::db::{PushTask, PushTaskConfig};
|
||||
use crate::task_manager::{MaxCount, TaskManagerWithId};
|
||||
use futures_util::lock::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{interval_at, Duration, Instant};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -52,15 +56,51 @@ impl TestSendMode {
|
||||
|
||||
pub async fn run_push_task(
|
||||
ctx: Arc<ServerContext>,
|
||||
task: &PushTask,
|
||||
task: Arc<PushTask>,
|
||||
send_mode: Option<&TestSendMode>,
|
||||
) -> Result<(), PixivDownloaderError> {
|
||||
match &task.config {
|
||||
PushTaskConfig::Pixiv(config) => match &config.act {
|
||||
PushTaskPixivAction::Follow { restrict, mode } => {
|
||||
pixiv_follow::run_push_task(ctx, task, config, restrict, mode, send_mode).await
|
||||
pixiv_follow::run_push_task(ctx, task.clone(), config, restrict, mode, send_mode)
|
||||
.await
|
||||
}
|
||||
_ => Ok(()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_checking(ctx: Arc<ServerContext>) {
|
||||
let mut interval = interval_at(Instant::now(), Duration::from_secs(1));
|
||||
let manager = TaskManagerWithId::new(Arc::new(Mutex::new(0)), MaxCount::new(4));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
manager.check_task().await;
|
||||
let tasks = manager.take_finished_tasks();
|
||||
for (id, task) in tasks {
|
||||
let re = task.await;
|
||||
if let Ok(Err(e)) = re {
|
||||
log::warn!("Push task error (task id: {}): {}", id, e);
|
||||
} else if let Err(e) = re {
|
||||
log::error!("Join error: {}", e);
|
||||
} else if let Ok(Ok(())) = re {
|
||||
log::debug!("Push task finished: {}", id);
|
||||
}
|
||||
}
|
||||
let all_tasks = match ctx.db.get_all_push_tasks().await {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
log::error!("Get all push tasks error: {}", e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
for task in all_tasks {
|
||||
if task.is_need_update() && !manager.is_pending_or_running(&task.id) {
|
||||
let task = Arc::new(task);
|
||||
manager
|
||||
.add_pending_task(task.id, run_push_task(ctx.clone(), task, None))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use super::pixiv_send_message::send_message;
|
||||
use super::TestSendMode;
|
||||
use crate::db::push_task::{PixivMode, PushTaskPixivConfig};
|
||||
use crate::db::PushTask;
|
||||
use crate::ext::atomic::AtomicQuick;
|
||||
use crate::ext::replace::ReplaceWith2;
|
||||
use crate::ext::rw_lock::GetRwLock;
|
||||
use crate::get_helper;
|
||||
use crate::pixiv_app::PixivRestrictType;
|
||||
@@ -10,6 +12,7 @@ use crate::pixivapp::illust::PixivAppIllust;
|
||||
use crate::utils::parse_pixiv_id;
|
||||
use json::JsonValue;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::RwLock;
|
||||
|
||||
struct PixivFollowData {
|
||||
@@ -34,7 +37,7 @@ impl PixivFollowData {
|
||||
|
||||
struct RunContext<'a> {
|
||||
ctx: Arc<ServerContext>,
|
||||
task: &'a PushTask,
|
||||
task: Arc<PushTask>,
|
||||
config: &'a PushTaskPixivConfig,
|
||||
restrict: &'a PixivRestrictType,
|
||||
mode: &'a PixivMode,
|
||||
@@ -43,12 +46,14 @@ struct RunContext<'a> {
|
||||
use_app_api: bool,
|
||||
use_web_description: bool,
|
||||
use_webpage: bool,
|
||||
pushed: RwLock<Vec<u64>>,
|
||||
first_run: AtomicBool,
|
||||
}
|
||||
|
||||
impl<'a> RunContext<'a> {
|
||||
pub fn new(
|
||||
ctx: Arc<ServerContext>,
|
||||
task: &'a PushTask,
|
||||
task: Arc<PushTask>,
|
||||
config: &'a PushTaskPixivConfig,
|
||||
restrict: &'a PixivRestrictType,
|
||||
mode: &'a PixivMode,
|
||||
@@ -68,15 +73,42 @@ impl<'a> RunContext<'a> {
|
||||
.use_web_description
|
||||
.unwrap_or(helper.use_web_description()),
|
||||
use_webpage: config.use_webpage.unwrap_or(helper.use_webpage()),
|
||||
pushed: RwLock::new(Vec::new()),
|
||||
first_run: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(&self) -> Result<(), PixivDownloaderError> {
|
||||
let now = chrono::Utc::now();
|
||||
if self.send_mode.is_none() {
|
||||
match self.ctx.db.get_push_task_data(self.task.id).await? {
|
||||
Some(data) => match serde_json::from_str(&data) {
|
||||
Ok(data) => {
|
||||
self.pushed.replace_with2(data);
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!(target: "pixiv_follow", "Failed to parse push task data: {}", e);
|
||||
log::debug!(target: "pixiv_follow", "Push task data: {}", data);
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.first_run.qstore(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
if self.use_app_api {
|
||||
self.app_run().await?;
|
||||
} else {
|
||||
self.web_run().await?;
|
||||
}
|
||||
if self.send_mode.is_none() {
|
||||
let data = serde_json::to_string(self.pushed.get_ref().as_slice())?;
|
||||
self.ctx.db.set_push_task_data(self.task.id, &data).await?;
|
||||
self.ctx
|
||||
.db
|
||||
.update_push_task_last_updated(self.task.id, &now)
|
||||
.await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -100,7 +132,19 @@ impl<'a> RunContext<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
None => {
|
||||
if self.first_run.qload() {
|
||||
for i in illusts.members() {
|
||||
if let Some(id) = parse_pixiv_id(&i["id"]) {
|
||||
self.pushed.get_mut().push(id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in illusts.members() {
|
||||
self.web_illust(i, &data).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -111,6 +155,9 @@ impl<'a> RunContext<'a> {
|
||||
data: &JsonValue,
|
||||
) -> Result<(), PixivDownloaderError> {
|
||||
let id = parse_pixiv_id(&illust["id"]).ok_or("illust id is none")?;
|
||||
if self.send_mode.is_none() && self.pushed.get_ref().contains(&id) {
|
||||
return Ok(());
|
||||
}
|
||||
let wdata = match self.data.get_web_data(id) {
|
||||
Some(d) => d,
|
||||
None => {
|
||||
@@ -134,6 +181,9 @@ impl<'a> RunContext<'a> {
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if self.send_mode.is_none() {
|
||||
self.pushed.get_mut().push(id);
|
||||
}
|
||||
for i in self.task.push_configs.iter() {
|
||||
send_message(
|
||||
self.ctx.clone(),
|
||||
@@ -165,13 +215,28 @@ impl<'a> RunContext<'a> {
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {}
|
||||
None => {
|
||||
if self.first_run.qload() {
|
||||
for i in app_data.illusts.iter() {
|
||||
if let Some(id) = i.id() {
|
||||
self.pushed.get_mut().push(id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in app_data.illusts.iter() {
|
||||
self.app_illust(i).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn app_illust(&self, illust: &PixivAppIllust) -> Result<(), PixivDownloaderError> {
|
||||
let id = illust.id().ok_or("illust id is none")?;
|
||||
if self.send_mode.is_none() && self.pushed.get_ref().contains(&id) {
|
||||
return Ok(());
|
||||
}
|
||||
let data = match self.data.get_web_data(id) {
|
||||
Some(d) => Some(d),
|
||||
None => {
|
||||
@@ -189,6 +254,9 @@ impl<'a> RunContext<'a> {
|
||||
}
|
||||
}
|
||||
};
|
||||
if self.send_mode.is_none() {
|
||||
self.pushed.get_mut().push(id);
|
||||
}
|
||||
for i in self.task.push_configs.iter() {
|
||||
send_message(
|
||||
self.ctx.clone(),
|
||||
@@ -207,7 +275,7 @@ impl<'a> RunContext<'a> {
|
||||
|
||||
pub async fn run_push_task(
|
||||
ctx: Arc<ServerContext>,
|
||||
task: &PushTask,
|
||||
task: Arc<PushTask>,
|
||||
config: &PushTaskPixivConfig,
|
||||
restrict: &PixivRestrictType,
|
||||
mode: &PixivMode,
|
||||
|
||||
@@ -97,6 +97,7 @@ pub async fn start_server(
|
||||
) -> Result<Server<AddrIncoming, PixivDownloaderMakeSvc>, hyper::Error> {
|
||||
let ctx = Arc::new(ServerContext::default().await);
|
||||
let ser = Server::try_bind(addr)?.serve(PixivDownloaderMakeSvc::new(&ctx));
|
||||
tokio::spawn(super::timer::start_timer(ctx));
|
||||
tokio::spawn(super::timer::start_timer(ctx.clone()));
|
||||
tokio::task::spawn(super::push::task::run_checking(ctx));
|
||||
Ok(ser)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ use crate::ext::rw_lock::GetRwLock;
|
||||
use crate::opthelper::get_helper;
|
||||
use futures_util::lock::Mutex;
|
||||
use indicatif::MultiProgress;
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
@@ -215,3 +217,98 @@ impl<O> Default for TaskManager<O> {
|
||||
Self::new(get_total_download_task_count(), MaxDownloadTasks::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Task manager with ID
|
||||
pub struct TaskManagerWithId<K, T> {
|
||||
/// Current running task
|
||||
tasks: RwLock<HashMap<K, JoinHandle<T>>>,
|
||||
/// Finished task
|
||||
finished_tasks: RwLock<HashMap<K, JoinHandle<T>>>,
|
||||
/// Pending task
|
||||
pedding_tasks: RwLock<Vec<(K, Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>)>>,
|
||||
/// Total task count
|
||||
task_count: Arc<Mutex<usize>>,
|
||||
max_count: Box<dyn GetMaxCount + Send + Sync>,
|
||||
}
|
||||
|
||||
impl<K, O> TaskManagerWithId<K, O>
|
||||
where
|
||||
K: Eq + std::hash::Hash,
|
||||
O: Send + 'static,
|
||||
{
|
||||
/// Create a new instance
|
||||
pub fn new<T: GetMaxCount + Send + Sync + 'static>(
|
||||
task_count: Arc<Mutex<usize>>,
|
||||
max_count: T,
|
||||
) -> Self {
|
||||
Self {
|
||||
tasks: RwLock::new(HashMap::new()),
|
||||
finished_tasks: RwLock::new(HashMap::new()),
|
||||
pedding_tasks: RwLock::new(Vec::new()),
|
||||
task_count,
|
||||
max_count: Box::new(max_count),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add pending task
|
||||
pub async fn add_pending_task<F>(&self, id: K, future: F)
|
||||
where
|
||||
F: Future<Output = O> + Send + Sync + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let total_count = self.max_count.get_max_count();
|
||||
{
|
||||
let mut count = self.task_count.lock().await;
|
||||
if *count < total_count {
|
||||
self.tasks.get_mut().insert(id, tokio::task::spawn(future));
|
||||
count.replace_with(*count + 1);
|
||||
return;
|
||||
}
|
||||
}
|
||||
self.pedding_tasks.get_mut().push((id, Box::pin(future)));
|
||||
}
|
||||
|
||||
/// Check running tasks and run pending tasks
|
||||
pub async fn check_task(&self) {
|
||||
let total_count = self.max_count.get_max_count();
|
||||
let mut count = self.task_count.lock().await;
|
||||
let tasks = self.tasks.replace_with2(HashMap::new());
|
||||
let mut new_tasks = HashMap::new();
|
||||
let mut new_count = *count;
|
||||
for (k, v) in tasks {
|
||||
if v.is_finished() {
|
||||
self.finished_tasks.get_mut().insert(k, v);
|
||||
new_count -= 1;
|
||||
} else {
|
||||
new_tasks.insert(k, v);
|
||||
}
|
||||
}
|
||||
while new_count < total_count {
|
||||
if let Some((k, v)) = self.pedding_tasks.get_mut().pop() {
|
||||
new_tasks.insert(k, tokio::task::spawn(v));
|
||||
new_count += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
self.tasks.replace_with2(new_tasks);
|
||||
count.replace_with(new_count);
|
||||
}
|
||||
|
||||
pub fn is_pending(&self, id: &K) -> bool {
|
||||
self.pedding_tasks.get_ref().iter().any(|(k, _)| k == id)
|
||||
}
|
||||
|
||||
pub fn is_pending_or_running(&self, id: &K) -> bool {
|
||||
self.is_running(id) || self.is_pending(id)
|
||||
}
|
||||
|
||||
pub fn is_running(&self, id: &K) -> bool {
|
||||
self.tasks.get_ref().contains_key(id)
|
||||
}
|
||||
|
||||
/// Take all finished tasks
|
||||
pub fn take_finished_tasks(&self) -> HashMap<K, JoinHandle<O>> {
|
||||
self.finished_tasks.replace_with2(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user