struct HuffmanTree [src]

Fields

max_bit_count: u4
symbol_count_minus_one: u8
nodes: [256]PrefixedSymbol

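Members

- PrefixedSymbol (struct): a symbol with its code prefix and weight
- Result (union): outcome of `query`, either a decoded symbol or a node index
- DecodeError (error set)
- query (fn)
- weightToBitCount (fn)
- decode (fn)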

Source

/// Huffman decoding table built from a list of per-symbol weights, in the
/// layout of a Zstandard Huffman tree description (RFC 8878 §4.2.1).
///
/// `nodes[0 .. symbol_count_minus_one + 1]` holds every symbol with a
/// non-zero weight, sorted by ascending weight (so the longest codes come
/// first), each paired with its code (`prefix`) and weight.
pub const HuffmanTree = struct {
    /// Maximum code length in bits; a weight-1 symbol uses this many bits.
    max_bit_count: u4,
    /// Number of populated entries in `nodes`, minus one.
    symbol_count_minus_one: u8,
    /// Weight-sorted symbols. Only the first `symbol_count_minus_one + 1`
    /// entries are fully initialized.
    nodes: [256]PrefixedSymbol,

    /// A symbol together with its code bits and its weight.
    pub const PrefixedSymbol = struct {
        symbol: u8,
        prefix: u16,
        weight: u4,
    };

    /// Outcome of `query`: either a decoded symbol, or the index of the
    /// next node (with a different weight) at which a lookup can continue.
    pub const Result = union(enum) {
        symbol: u8,
        index: usize,
    };

    /// Search for `prefix` among the nodes that share the weight of
    /// `nodes[index]`, scanning downward from `index`.
    ///
    /// Returns `.symbol` on a match. Returns `.index` of the first node
    /// (scanning downward) whose weight differs. Returns
    /// `error.HuffmanTreeIncomplete` if index 0 is reached without a match.
    pub fn query(self: HuffmanTree, index: usize, prefix: u16) error{HuffmanTreeIncomplete}!Result {
        var node = self.nodes[index];
        const weight = node.weight;
        var i: usize = index;
        while (node.weight == weight) {
            if (node.prefix == prefix) return .{ .symbol = node.symbol };
            if (i == 0) return error.HuffmanTreeIncomplete;
            i -= 1;
            node = self.nodes[i];
        }
        return .{ .index = i };
    }

    /// Convert a weight to a code length in bits.
    ///
    /// Weight 0 maps to 0 bits. Any other weight maps to
    /// `max_bit_count + 1 - weight`, so `weight` must not exceed
    /// `max_bit_count + 1`.
    pub fn weightToBitCount(weight: u4, max_bit_count: u4) u4 {
        return if (weight == 0) 0 else ((max_bit_count + 1) - weight);
    }

    pub const DecodeError = Reader.Error || error{
        MalformedHuffmanTree,
        MalformedFseTable,
        MalformedAccuracyLog,
        EndOfStream,
        MissingStartBit,
    };

    /// Read a Huffman tree description from `in` and build the tree.
    ///
    /// The first byte is a header:
    /// - `header < 128`: the next `header` bytes hold FSE-compressed weights.
    /// - Otherwise: `header - 127` weights follow as packed 4-bit values.
    ///
    /// Every byte consumed is deducted from `remaining`. Returns
    /// `error.EndOfStream` if `remaining` would be exceeded.
    pub fn decode(in: *Reader, remaining: *Limit) HuffmanTree.DecodeError!HuffmanTree {
        remaining.* = remaining.subtract(1) orelse return error.EndOfStream;
        const header = try in.takeByte();
        if (header < 128) {
            return decodeFse(in, remaining, header);
        } else {
            return decodeDirect(in, remaining, header - 127);
        }
    }

    /// Decode `encoded_symbol_count` weights stored two per byte, high
    /// nibble first. The final symbol's weight is implied and filled in by
    /// `build`.
    fn decodeDirect(
        in: *Reader,
        remaining: *Limit,
        encoded_symbol_count: usize,
    ) HuffmanTree.DecodeError!HuffmanTree {
        var weights: [256]u4 = undefined;
        // Round up: an odd count still occupies a whole final byte.
        const weights_byte_count = (encoded_symbol_count + 1) / 2;
        remaining.* = remaining.subtract(weights_byte_count) orelse return error.EndOfStream;
        for (0..weights_byte_count) |i| {
            const byte = try in.takeByte();
            weights[2 * i] = @as(u4, @intCast(byte >> 4));
            weights[2 * i + 1] = @as(u4, @intCast(byte & 0xF));
        }
        // One extra symbol whose weight `build` deduces.
        const symbol_count = encoded_symbol_count + 1;
        return build(&weights, symbol_count);
    }

    /// Decode FSE-compressed weights.
    ///
    /// First reads an FSE table (at most 256 symbols, accuracy log at most 6)
    /// from the `compressed_size` bytes. The rest of those bytes is the
    /// reversed bitstream of weights.
    fn decodeFse(
        in: *Reader,
        remaining: *Limit,
        compressed_size: usize,
    ) HuffmanTree.DecodeError!HuffmanTree {
        var weights: [256]u4 = undefined;
        remaining.* = remaining.subtract(compressed_size) orelse return error.EndOfStream;
        const compressed_buffer = try in.take(compressed_size);
        var bit_reader: BitReader = .{ .bytes = compressed_buffer };
        var entries: [1 << 6]Table.Fse = undefined;
        const table_size = try Table.decode(&bit_reader, 256, 6, &entries);
        const accuracy_log = std.math.log2_int_ceil(usize, table_size);
        const remaining_buffer = bit_reader.bytes[bit_reader.index..];
        const symbol_count = try assignWeights(remaining_buffer, accuracy_log, &entries, &weights);
        return build(&weights, symbol_count);
    }

    /// Decode weights from a reversed bitstream using two FSE states that
    /// alternate (even-indexed weights, then odd-indexed weights).
    ///
    /// Decoding stops when a state update runs out of bits. The other
    /// state's symbol is then emitted as the final decoded weight.
    /// Returns the number of decoded weights plus one, because the stream
    /// omits the last symbol, whose weight `build` deduces.
    fn assignWeights(
        huff_bits_buffer: []const u8,
        accuracy_log: u16,
        entries: *[1 << 6]Table.Fse,
        weights: *[256]u4,
    ) !usize {
        var huff_bits = try ReverseBitReader.init(huff_bits_buffer);

        var i: usize = 0;
        var even_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);
        var odd_state: u32 = try huff_bits.readBitsNoEof(u32, accuracy_log);

        while (i < 254) {
            const even_data = entries[even_state];
            var read_bits: u16 = 0;
            // NOTE(review): assumes readBits reports a short read through
            // `read_bits` instead of returning an error. Confirm against
            // ReverseBitReader.
            const even_bits = huff_bits.readBits(u32, even_data.bits, &read_bits) catch unreachable;
            weights[i] = std.math.cast(u4, even_data.symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            if (read_bits < even_data.bits) {
                // Bitstream exhausted: the odd state's symbol is the last weight.
                weights[i] = std.math.cast(u4, entries[odd_state].symbol) orelse return error.MalformedHuffmanTree;
                i += 1;
                break;
            }
            even_state = even_data.baseline + even_bits;

            read_bits = 0;
            const odd_data = entries[odd_state];
            const odd_bits = huff_bits.readBits(u32, odd_data.bits, &read_bits) catch unreachable;
            weights[i] = std.math.cast(u4, odd_data.symbol) orelse return error.MalformedHuffmanTree;
            i += 1;
            if (read_bits < odd_data.bits) {
                // `i` is always even (at most 254) here, so this guard never
                // fires. It is kept as a defensive bound.
                if (i == 255) return error.MalformedHuffmanTree;
                // Bitstream exhausted: the even state's symbol is the last weight.
                weights[i] = std.math.cast(u4, entries[even_state].symbol) orelse return error.MalformedHuffmanTree;
                i += 1;
                break;
            }
            odd_state = odd_data.baseline + odd_bits;
        } else return error.MalformedHuffmanTree;

        if (!huff_bits.isEmpty()) {
            return error.MalformedHuffmanTree;
        }

        return i + 1; // stream contains all but the last symbol
    }

    /// Sort the symbols `0 .. weight_sorted_prefixed_symbols.len` by
    /// ascending weight, drop zero-weight symbols, and assign prefixes.
    ///
    /// Within each weight group, consecutive prefixes are assigned. Moving
    /// to a larger weight (shorter code) shifts the running prefix down.
    /// Surviving entries are compacted to the front of the slice.
    /// Returns how many entries were written.
    fn assignSymbols(weight_sorted_prefixed_symbols: []PrefixedSymbol, weights: [256]u4) usize {
        for (0..weight_sorted_prefixed_symbols.len) |i| {
            weight_sorted_prefixed_symbols[i] = .{
                .symbol = @as(u8, @intCast(i)),
                .weight = undefined,
                .prefix = undefined,
            };
        }

        std.mem.sort(
            PrefixedSymbol,
            weight_sorted_prefixed_symbols,
            weights,
            lessThanByWeight,
        );

        var prefix: u16 = 0;
        var prefixed_symbol_count: usize = 0;
        var sorted_index: usize = 0;
        const symbol_count = weight_sorted_prefixed_symbols.len;
        while (sorted_index < symbol_count) {
            var symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
            const weight = weights[symbol];
            if (weight == 0) {
                sorted_index += 1;
                continue;
            }

            while (sorted_index < symbol_count) : ({
                sorted_index += 1;
                prefixed_symbol_count += 1;
                prefix += 1;
            }) {
                symbol = weight_sorted_prefixed_symbols[sorted_index].symbol;
                if (weights[symbol] != weight) {
                    // Next (heavier, shorter-code) group: drop the extra low
                    // bits of the last assigned prefix and advance by one.
                    prefix = ((prefix - 1) >> (weights[symbol] - weight)) + 1;
                    break;
                }
                // prefixed_symbol_count <= sorted_index, so compacting in
                // place never overwrites an entry not yet read.
                weight_sorted_prefixed_symbols[prefixed_symbol_count].symbol = symbol;
                weight_sorted_prefixed_symbols[prefixed_symbol_count].prefix = prefix;
                weight_sorted_prefixed_symbols[prefixed_symbol_count].weight = weight;
            }
        }
        return prefixed_symbol_count;
    }

    /// Deduce the implied weight of the last symbol, write it to
    /// `weights[symbol_count - 1]`, and build the tree.
    ///
    /// The explicit weights `weights[0 .. symbol_count - 1]` must have a
    /// non-zero sum of `2^(weight - 1)` below `2^11`. The gap between that
    /// sum and the next power of two must itself be a power of two, so the
    /// last symbol can complete the tree exactly. Otherwise returns
    /// `error.MalformedHuffmanTree`.
    fn build(weights: *[256]u4, symbol_count: usize) error{MalformedHuffmanTree}!HuffmanTree {
        var weight_power_sum_big: u32 = 0;
        for (weights[0 .. symbol_count - 1]) |value| {
            weight_power_sum_big += (@as(u16, 1) << value) >> 1;
        }
        // Codes are limited to 11 bits.
        if (weight_power_sum_big >= 1 << 11) return error.MalformedHuffmanTree;
        // With no non-zero explicit weight there is no valid tree to
        // complete. The reference zstd decoder treats this as corruption.
        if (weight_power_sum_big == 0) return error.MalformedHuffmanTree;
        const weight_power_sum = @as(u16, @intCast(weight_power_sum_big));

        // Advance to the next power of two (strictly greater, even if
        // weight_power_sum is already a power of two).
        const max_number_of_bits = std.math.log2_int(u16, weight_power_sum) + 1;
        const next_power_of_two = @as(u16, 1) << max_number_of_bits;
        // The last symbol must fill the remaining gap exactly. A gap that is
        // not a power of two would leave the tree incomplete.
        // leftover > 0 because next_power_of_two > weight_power_sum.
        const leftover = next_power_of_two - weight_power_sum;
        if (!std.math.isPowerOfTwo(leftover)) return error.MalformedHuffmanTree;
        weights[symbol_count - 1] = std.math.log2_int(u16, leftover) + 1;

        var weight_sorted_prefixed_symbols: [256]PrefixedSymbol = undefined;
        const prefixed_symbol_count = assignSymbols(weight_sorted_prefixed_symbols[0..symbol_count], weights.*);
        const tree: HuffmanTree = .{
            .max_bit_count = max_number_of_bits,
            .symbol_count_minus_one = @as(u8, @intCast(prefixed_symbol_count - 1)),
            .nodes = weight_sorted_prefixed_symbols,
        };
        return tree;
    }

    /// Order prefixed symbols by ascending weight of their `symbol`.
    ///
    /// Ties are not broken here. Equal-weight symbols keep their ascending
    /// symbol order only because the sort in `assignSymbols` is relied upon
    /// to be stable. With an unstable sort, add
    /// `if (weights[lhs.symbol] == weights[rhs.symbol]) return lhs.symbol < rhs.symbol;`.
    fn lessThanByWeight(
        weights: [256]u4,
        lhs: PrefixedSymbol,
        rhs: PrefixedSymbol,
    ) bool {
        return weights[lhs.symbol] < weights[rhs.symbol];
    }
};