Fix Huffman encoding problem

This commit is contained in:
2025-08-05 15:28:05 +08:00
parent 9e26c01421
commit d9cf2f9202

View File

@@ -241,26 +241,70 @@ impl Ord for FreqNode {
}
// NOTE(review): this span comes from a rendered commit diff, so lines of the
// pre-commit body (the local `heap` and the `heap.len() <= 1` shortcut) appear
// interleaved with the post-commit body and the braces do not balance as
// shown. The documentation below describes the post-commit flow — confirm
// against the real file.
/// Computes a code length (tree depth) for every symbol in `freqs`.
///
/// `freqs[i]` is the frequency of symbol `i`. The result has 512 entries
/// indexed by symbol; symbols with frequency 0 keep depth 0, and a lone used
/// symbol gets depth 1. When a plain Huffman tree is deeper than `MAX_DEPTH`,
/// the frequencies are adjusted and the tree is rebuilt until it fits.
///
/// Indexing the 512-entry result with a symbol >= 512 would panic, so callers
/// presumably never pass more than 512 frequencies — TODO confirm.
fn calculate_huffman_depths(freqs: &[u32]) -> Vec<u8> {
let mut heap = BinaryHeap::new();
for (symbol, &freq) in freqs.iter().enumerate() {
if freq > 0 {
heap.push(FreqNode {
freq,
symbol: Some(symbol as u16),
left: None,
right: None,
});
}
}
// Longest code length the loop below accepts.
const MAX_DEPTH: u8 = 9;
if heap.len() <= 1 {
let mut depths = vec![0; 512];
if let Some(node) = heap.pop() {
depths[node.symbol.unwrap() as usize] = 1;
}
// Collect every symbol that has a non-zero frequency.
let mut symbols_with_freq: Vec<(u16, u32)> = freqs
.iter()
.enumerate()
.filter_map(|(symbol, &freq)| {
if freq > 0 {
Some((symbol as u16, freq))
} else {
None
}
})
.collect();
let mut depths = vec![0u8; 512];
// No symbol is used: every depth stays 0.
if symbols_with_freq.is_empty() {
return depths;
}
// A single used symbol still needs a one-bit code.
if symbols_with_freq.len() == 1 {
depths[symbols_with_freq[0].0 as usize] = 1;
return depths;
}
// Use a length-limited Huffman algorithm: rebuild until the tree fits.
loop {
// `current_depths[i]` is the depth of `symbols_with_freq[i]`.
let current_depths = build_huffman_tree(&symbols_with_freq);
let max_depth = current_depths.iter().max().copied().unwrap_or(0);
if max_depth <= MAX_DEPTH {
// Map the per-entry depths back into the symbol-indexed array.
for &(symbol, _) in &symbols_with_freq {
// Symbols are unique (they come from `enumerate`), so this finds the
// index of the entry currently being visited.
let symbol_index = symbols_with_freq
.iter()
.position(|(s, _)| *s == symbol)
.unwrap();
depths[symbol as usize] = current_depths[symbol_index];
}
break;
}
// The depth limit was exceeded: adjust the frequencies and try again.
adjust_frequencies_for_depth_limit(&mut symbols_with_freq);
}
depths
}
// NOTE(review): this span comes from a rendered commit diff. The hunk header
// in the middle hides the lines that build `new_node`, and the 512-entry
// `depths` vector plus the nested `traverse` helper look like removed
// pre-commit lines shown next to their replacements — confirm against the
// real file.
/// Builds a Huffman tree over `symbols_with_freq` and returns one depth per
/// entry.
///
/// The returned vector has `symbols_with_freq.len()` elements; element `i` is
/// filled in by `calculate_depths` for the symbol stored at
/// `symbols_with_freq[i]`.
fn build_huffman_tree(symbols_with_freq: &[(u16, u32)]) -> Vec<u8> {
let mut heap = BinaryHeap::new();
// Add every symbol as a leaf node.
for &(symbol, freq) in symbols_with_freq {
heap.push(FreqNode {
freq,
symbol: Some(symbol),
left: None,
right: None,
});
}
// Build the Huffman tree: pop two nodes at a time and push a combined node
// until a single root remains.
while heap.len() > 1 {
let node1 = heap.pop().unwrap();
let node2 = heap.pop().unwrap();
@@ -273,29 +317,53 @@ fn calculate_huffman_depths(freqs: &[u32]) -> Vec<u8> {
heap.push(new_node);
}
let mut depths = vec![0; 512];
// Compute the depth of every leaf.
let mut depths = vec![0u8; symbols_with_freq.len()];
if let Some(root) = heap.pop() {
// Symbol-indexed depth walk; a leaf found at depth 0 is reported as depth 1.
fn traverse(node: &FreqNode, depth: u8, depths: &mut [u8]) {
if let Some(symbol) = node.symbol {
if depth == 0 {
depths[symbol as usize] = 1;
} else {
depths[symbol as usize] = depth;
}
} else {
if let Some(ref left) = node.left {
traverse(left, depth + 1, depths);
}
if let Some(ref right) = node.right {
traverse(right, depth + 1, depths);
}
}
}
traverse(&root, 0, &mut depths);
calculate_depths(&root, 0, symbols_with_freq, &mut depths);
}
depths
}
/// Walks the tree rooted at `node` and records the depth of every leaf.
///
/// `depth` is the depth of `node` itself (the root is called with 0). For each
/// leaf, the slot written in `depths` is the position of the leaf's symbol
/// inside `symbols_with_freq`, so `depths` must be at least as long as that
/// slice. A leaf sitting at depth 0 (a tree made of a single node) is recorded
/// with depth 1.
///
/// # Panics
///
/// Panics if a leaf carries a symbol that is absent from `symbols_with_freq`.
fn calculate_depths(
    node: &FreqNode,
    depth: u8,
    symbols_with_freq: &[(u16, u32)],
    depths: &mut [u8],
) {
    match node.symbol {
        // Leaf: translate the symbol into its slot and store the depth.
        Some(symbol) => {
            let slot = symbols_with_freq
                .iter()
                .position(|&(candidate, _)| candidate == symbol)
                .unwrap();
            // A lone root leaf still gets a depth of 1.
            depths[slot] = depth.max(1);
        }
        // Internal node: descend into whichever children exist, left first.
        None => {
            let child_depth = depth + 1;
            if let Some(left) = &node.left {
                calculate_depths(left, child_depth, symbols_with_freq, depths);
            }
            if let Some(right) = &node.right {
                calculate_depths(right, child_depth, symbols_with_freq, depths);
            }
        }
    }
}
/// Flattens the frequency distribution slightly so that a rebuilt Huffman
/// tree comes out shallower.
///
/// This is a heuristic stand-in for Package-Merge: the slice is sorted in
/// place by ascending frequency, then the lowest quarter of the entries (at
/// least one) is raised by 10% of the smallest frequency (at least 1).
///
/// An empty slice is left untouched, and frequencies saturate at `u32::MAX`
/// instead of overflowing.
fn adjust_frequencies_for_depth_limit(symbols_with_freq: &mut [(u16, u32)]) {
    // Nothing to adjust; this also guards the `[0]` access below.
    if symbols_with_freq.is_empty() {
        return;
    }

    // Stable sort: entries with equal frequency keep their relative order, so
    // the choice of which ties get boosted stays deterministic.
    symbols_with_freq.sort_by_key(|&(_, freq)| freq);

    // 10% of the smallest frequency, rounded down, but never less than 1 so
    // that every call makes progress.
    let min_freq = symbols_with_freq[0].1;
    let adjustment = (min_freq / 10).max(1);

    // Boost the lowest quarter of the symbols, and always at least one.
    let num_to_adjust = (symbols_with_freq.len() / 4).max(1);
    for (_, freq) in symbols_with_freq.iter_mut().take(num_to_adjust) {
        // Saturate rather than overflow for frequencies close to `u32::MAX`.
        *freq = freq.saturating_add(adjustment);
    }
}
fn generate_canonical_codes(depths: &[u8]) -> Vec<Option<(u16, u8)>> {
let mut codes_with_depths = vec![];
for (symbol, &depth) in depths.iter().enumerate() {
@@ -575,7 +643,7 @@ impl Script for Dsc {
}
// Returns the extension string for this script's custom output (presumably a
// file extension — verify against the `Script` trait).
// NOTE(review): this span comes from a rendered commit diff, so both the
// pre-commit value ("unk") and its replacement (the empty string) are shown;
// the real function returns only one of them — confirm against the file.
fn custom_output_extension(&self) -> &'static str {
"unk"
""
}
fn custom_export(&self, filename: &std::path::Path, _encoding: Encoding) -> Result<()> {