diff --git a/Cargo.lock b/Cargo.lock index 6de10ad..ca9d1c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1169,6 +1169,7 @@ dependencies = [ "memchr", "mozjpeg", "msg_tool_macro", + "num_cpus", "overf", "parse-size", "pelite", diff --git a/Cargo.toml b/Cargo.toml index 46ce3c8..3f622d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ markup5ever_rcdom = { version = "0.35", optional = true } memchr = { version = "2.7", optional = true } mozjpeg = { version = "0.10", optional = true } msg_tool_macro = { version = "0.2.1" } +num_cpus = { version = "1.17", optional = true } overf = "0.1" pelite = { version = "0.10", optional = true } png = { version = "0.18", optional = true } @@ -58,7 +59,7 @@ artemis-arc = ["artemis", "msg_tool_macro/artemis-arc", "sha1"] bgi = ["fancy-regex"] bgi-arc = ["bgi", "rand", "utils-bit-stream"] bgi-audio = ["bgi"] -bgi-img = ["bgi", "image", "rand", "utils-bit-stream"] +bgi-img = ["bgi", "image", "rand", "utils-threadpool", "utils-bit-stream"] cat-system = ["fancy-regex", "flate2", "int-enum"] cat-system-arc = ["cat-system", "pelite", "utils-blowfish", "utils-crc32"] cat-system-img = ["cat-system", "flate2", "image", "mozjpeg", "utils-bit-stream"] @@ -94,6 +95,7 @@ utils-crc32 = [] utils-escape = ["fancy-regex"] utils-pcm = [] utils-str = [] +utils-threadpool = ["num_cpus"] [target.'cfg(windows)'.dependencies] windows-sys = { version = "0", features = ["Win32_Globalization", "Win32_System_Diagnostics_Debug"] } diff --git a/src/args.rs b/src/args.rs index a1dc240..47c2075 100644 --- a/src/args.rs +++ b/src/args.rs @@ -172,6 +172,11 @@ pub struct Arg { /// Whether to create scrambled SysGrp images. When in import mode, the default value depends on the original image. /// When in creation mode, it is not enabled by default. pub bgi_img_scramble: Option, + #[cfg(feature = "bgi-img")] + #[arg(long, global = true, default_value_t = crate::types::get_default_threads())] + /// Workers count for decode BGI compressed images v2 in parallel. Default is half of CPU cores. + /// Set this to 1 to disable parallel decoding. 0 means same as 1. + pub bgi_img_workers: usize, #[cfg(feature = "cat-system-arc")] #[arg(long, global = true, group = "cat_system_int_encrypt_passwordg")] /// CatSystem2 engine int archive password diff --git a/src/ext/mod.rs b/src/ext/mod.rs index 0a318d1..fbe0026 100644 --- a/src/ext/mod.rs +++ b/src/ext/mod.rs @@ -3,6 +3,7 @@ pub mod atomic; #[cfg(feature = "fancy-regex")] pub mod fancy_regex; pub mod io; +pub mod mutex; pub mod path; #[cfg(feature = "emote-psb")] pub mod psb; diff --git a/src/ext/mutex.rs b/src/ext/mutex.rs new file mode 100644 index 0000000..f6aecb8 --- /dev/null +++ b/src/ext/mutex.rs @@ -0,0 +1,19 @@ +//! Extension for [std::sync::Mutex]. +pub trait MutexExt { + /// Lock the mutex, blocking the current thread until it can be acquired. + fn lock_blocking(&self) -> std::sync::MutexGuard<'_, T>; +} + +impl MutexExt for std::sync::Mutex { + fn lock_blocking(&self) -> std::sync::MutexGuard<'_, T> { + loop { + match self.try_lock() { + Ok(guard) => return guard, + Err(std::sync::TryLockError::WouldBlock) => { + std::thread::yield_now(); + } + Err(std::sync::TryLockError::Poisoned(err)) => return err.into_inner(), + } + } + } +} diff --git a/src/main.rs b/src/main.rs index c1edf4f..2231b1d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1925,6 +1925,8 @@ fn main() { kirikiri_title: arg.kirikiri_title, #[cfg(feature = "favorite")] favorite_hcb_filter_ascii: !arg.favorite_hcb_no_filter_ascii, + #[cfg(feature = "bgi-img")] + bgi_img_workers: arg.bgi_img_workers, }; match &arg.command { args::Command::Export { input, output } => { diff --git a/src/scripts/bgi/image/cbg.rs b/src/scripts/bgi/image/cbg.rs index 5072d12..ee31fac 100644 --- a/src/scripts/bgi/image/cbg.rs +++ b/src/scripts/bgi/image/cbg.rs @@ -6,6 +6,7 @@ use crate::types::*; use crate::utils::bit_stream::*; use crate::utils::img::*; use crate::utils::struct_pack::*; +use crate::utils::threadpool::*; use anyhow::Result; use msg_tool_macro::*; use std::io::{Read, Seek, Write}; @@ -133,6 +134,7 @@ pub struct BgiCBG { header: BgiCBGHeader, data: MemReader, color_type: CbgColorType, + decode_workers: usize, } impl BgiCBG { @@ -140,7 +142,7 @@ impl BgiCBG { /// /// * `data` - The buffer containing the script data. /// * `config` - Extra configuration options. - pub fn new(data: Vec, _config: &ExtraConfig) -> Result { + pub fn new(data: Vec, config: &ExtraConfig) -> Result { let mut reader = MemReader::new(data); let mut magic = [0u8; 16]; reader.read_exact(&mut magic)?; @@ -167,6 +169,7 @@ impl BgiCBG { header, data: reader, color_type, + decode_workers: config.bgi_img_workers.max(1), }) } } @@ -185,7 +188,12 @@ impl Script for BgiCBG { } fn export_image(&self) -> Result { - let decoder = CbgDecoder::new(self.data.to_ref(), &self.header, self.color_type)?; + let decoder = CbgDecoder::new( + self.data.to_ref(), + &self.header, + self.color_type, + self.decode_workers, + )?; Ok(decoder.unpack()?) } @@ -209,6 +217,7 @@ struct CbgDecoder<'a> { magic: u32, pixel_size: u8, stride: usize, + workers: usize, } impl<'a> CbgDecoder<'a> { @@ -216,6 +225,7 @@ impl<'a> CbgDecoder<'a> { reader: MemReaderRef<'a>, info: &'a BgiCBGHeader, color_type: CbgColorType, + workers: usize, ) -> Result { let magic = 0; let key = info.key; @@ -230,6 +240,7 @@ impl<'a> CbgDecoder<'a> { color_type, pixel_size, stride, + workers, }) } @@ -334,7 +345,7 @@ impl<'a> CbgDecoder<'a> { has_alpha: AtomicBool::new(false), }); - let mut tasks = Vec::new(); + let thread_pool = ThreadPool::new(self.workers, Some("cbg-decoder-worker-")); let mut dst = 0i32; for i in 0..y_blocks { @@ -347,23 +358,27 @@ impl<'a> CbgDecoder<'a> { let closure_dst = dst; let decoder_ref = Arc::clone(&decoder); - let task = std::thread::spawn(move || { - decoder_ref.unpack_block(block_offset, next_offset - block_offset, closure_dst) - }); - tasks.push(task); + thread_pool.execute( + move || { + decoder_ref.unpack_block(block_offset, next_offset - block_offset, closure_dst) + }, + true, + )?; dst += width * 32; } if self.info.bpp == 32 { let decoder_ref = Arc::clone(&decoder); - let task = - std::thread::spawn(move || decoder_ref.unpack_alpha(offsets[y_blocks as usize])); - tasks.push(task); + thread_pool.execute( + move || decoder_ref.unpack_alpha(offsets[y_blocks as usize]), + true, + )?; } + let tasks = thread_pool.into_results(); + for task in tasks { - task.join() - .map_err(|e| anyhow::anyhow!("Thread join failed: {:?}", e))??; + task?; } let has_alpha = decoder.has_alpha.qload(); diff --git a/src/types.rs b/src/types.rs index 49a4f7c..bd1574f 100644 --- a/src/types.rs +++ b/src/types.rs @@ -426,6 +426,11 @@ pub struct ExtraConfig { #[default(true)] /// Whether to filter ascii strings in Favorite HCB script. pub favorite_hcb_filter_ascii: bool, + #[cfg(feature = "bgi-img")] + #[default(get_default_threads())] + /// Workers count for decode BGI compressed images v2 in parallel. Default is half of CPU cores. + /// Set this to 1 to disable parallel decoding. 0 means same as 1. + pub bgi_img_workers: usize, } #[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq, PartialOrd, Ord)] @@ -915,3 +920,9 @@ impl AsRef for LosslessAudioFormat { } } } + +#[cfg(feature = "utils-threadpool")] +#[allow(unused)] +pub(crate) fn get_default_threads() -> usize { + num_cpus::get().max(2) / 2 +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index d9d82c2..809d30d 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -25,6 +25,8 @@ pub mod pcm; #[cfg(feature = "utils-str")] pub mod str; pub mod struct_pack; +#[cfg(feature = "utils-threadpool")] +pub mod threadpool; #[cfg(windows)] pub use encoding_win::WinError; diff --git a/src/utils/threadpool.rs b/src/utils/threadpool.rs new file mode 100644 index 0000000..05fabdf --- /dev/null +++ b/src/utils/threadpool.rs @@ -0,0 +1,209 @@ +//! Thread pool utilities +use crate::ext::mutex::*; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{ + Arc, Condvar, Mutex, + mpsc::{Receiver, SyncSender, TrySendError, sync_channel}, +}; +use std::thread::{self, JoinHandle}; + +type Job = Box T + Send + 'static>; + +/// A simple generic thread pool. +/// +/// - T: the return type of tasks. Completed task results are stored in `results: Arc>>`. +/// - execute accepts a task and a `block_if_full` flag: +/// * if true, submission will block when the pool is saturated until a worker becomes available; +/// * if false, submission will return an error when the pool is saturated. +/// - join waits until all submitted tasks have completed (it does not shut down the pool). +pub struct ThreadPool { + sender: Option>>, + #[allow(unused)] + receiver: Arc>>>, + workers: Vec>, + /// Completed task results + pub results: Arc>>, + /// Number of pending tasks (queued + running) + pending: Arc, + /// Pair for wait/notify in join + pending_pair: Arc<(Mutex<()>, Condvar)>, + size: usize, +} + +#[derive(Debug)] +/// Error type for [ThreadPool::execute] +pub enum ExecuteError { + /// Pool is full + Full, + /// Pool is closed + Closed, +} + +impl std::error::Error for ExecuteError {} + +impl std::fmt::Display for ExecuteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ExecuteError::Full => write!(f, "ThreadPool is full"), + ExecuteError::Closed => write!(f, "ThreadPool is closed"), + } + } +} + +impl ThreadPool { + pub fn size(&self) -> usize { + self.size + } + + /// Create a new thread pool with `size` workers. + /// The internal submission channel is bounded to `size`, so when all workers are busy and + /// the channel is full, further submissions will block or return error depending on the flag. + /// + /// * `name` - Optional base name for worker threads. If None, "threadpool-worker-" is used. + pub fn new<'a>(size: usize, name: Option<&'a str>) -> Self { + assert!(size > 0, "size must be > 0"); + + let (tx, rx) = sync_channel::>(size); + let receiver = Arc::new(Mutex::new(rx)); + let results = Arc::new(Mutex::new(Vec::new())); + let pending = Arc::new(AtomicUsize::new(0)); + let pending_pair = Arc::new((Mutex::new(()), Condvar::new())); + let thread_name = name.unwrap_or("threadpool-worker-"); + + let mut workers = Vec::with_capacity(size); + for id in 0..size { + let rx_clone = Arc::clone(&receiver); + let results_clone = Arc::clone(&results); + let pending_clone = Arc::clone(&pending); + let pending_pair_clone = Arc::clone(&pending_pair); + + let handle = thread::Builder::new() + .name(format!("{}{}", thread_name, id)) + .spawn(move || { + loop { + // Lock receiver to call recv. Using a Mutex around Receiver serializes + // the recv calls but is fine for this simple implementation. + let job = { + let guard = rx_clone.lock_blocking(); + // If recv returns Err, sender was dropped -> exit thread + guard.recv() + }; + + match job { + Ok(job) => { + // Execute the job and store result + let res = job(); + { + let mut r = results_clone.lock_blocking(); + r.push(res); + } + + // Decrement pending count and notify join waiters + pending_clone.fetch_sub(1, Ordering::SeqCst); + let (lock, cvar) = &*pending_pair_clone; + let _g = lock.lock_blocking(); + cvar.notify_all(); + } + Err(_) => { + // Channel closed -> shutdown worker + break; + } + } + } + }) + .expect("failed to spawn worker thread"); + + workers.push(handle); + } + + ThreadPool { + sender: Some(tx), + receiver, + workers, + results, + pending, + pending_pair, + size, + } + } + + /// Execute a task. If `block_if_full` is true, this call will block when the internal + /// submission channel is full (i.e. all workers busy and buffer full) until space becomes available. + /// If `block_if_full` is false, this returns Err(ExecuteError::Full) when the channel is full. + pub fn execute(&self, job: F, block_if_full: bool) -> Result<(), ExecuteError> + where + F: FnOnce() -> T + Send + 'static, + { + let sender = match &self.sender { + Some(s) => s, + None => return Err(ExecuteError::Closed), + }; + + // Increase pending count for this submission. If submission fails we will decrement. + self.pending.fetch_add(1, Ordering::SeqCst); + + let boxed: Job = Box::new(job); + + if block_if_full { + // This will block until there is space in the bounded channel or the channel is closed. + if sender.send(boxed).is_err() { + // Channel closed + self.pending.fetch_sub(1, Ordering::SeqCst); + return Err(ExecuteError::Closed); + } + Ok(()) + } else { + match sender.try_send(boxed) { + Ok(()) => Ok(()), + Err(TrySendError::Full(_)) => { + // revert pending increment + self.pending.fetch_sub(1, Ordering::SeqCst); + Err(ExecuteError::Full) + } + Err(TrySendError::Disconnected(_)) => { + self.pending.fetch_sub(1, Ordering::SeqCst); + Err(ExecuteError::Closed) + } + } + } + } + + /// Wait until all submitted tasks have completed. This does not shut down the pool; new tasks + /// can still be submitted after join returns. + pub fn join(&self) { + // Fast path + if self.pending.load(Ordering::SeqCst) == 0 { + return; + } + + let (lock, cvar) = &*self.pending_pair; + let mut guard = lock.lock_blocking(); + while self.pending.load(Ordering::SeqCst) != 0 { + guard = match cvar.wait(guard) { + Ok(g) => g, + Err(poisoned) => poisoned.into_inner(), + }; + } + } + + /// Wait until all submitted tasks have completed, then return the results. + pub fn into_results(self) -> Vec { + self.join(); + let mut results = self.results.lock_blocking(); + results.split_off(0) + } +} + +impl Drop for ThreadPool { + fn drop(&mut self) { + // Close sender so worker threads exit recv loop + self.sender.take(); + // Dropping the sender (SyncSender) happens above; but to ensure we close the channel we + // explicitly drop any remaining clones by letting sender go out of scope. + + // Join worker threads + while let Some(handle) = self.workers.pop() { + let _ = handle.join(); + } + } +}