use arithmetic_coding::{Decoder, Encoder, Model}; use bitstream_io::{BigEndian, BitRead, BitReader, BitWrite, BitWriter}; use symbolic::Symbol; const PRECISION: u32 = 12; mod integer { use std::ops::Range; pub struct Model; #[derive(Debug, thiserror::Error)] #[error("invalid symbol: {0}")] pub struct Error(u8); impl arithmetic_coding::Model for Model { type B = u32; type Symbol = u8; type ValueError = Error; fn probability(&self, symbol: Option<&Self::Symbol>) -> Result, Error> { match symbol { None => Ok(0..1), Some(&1) => Ok(1..2), Some(&2) => Ok(2..3), Some(&3) => Ok(2..4), Some(x) => Err(Error(*x)), } } fn symbol(&self, value: u32) -> Option { match value { 0..1 => None, 1..2 => Some(1), 2..3 => Some(2), 3..4 => Some(3), _ => unreachable!(), } } fn max_denominator(&self) -> u32 { 4 } } } mod symbolic { use std::{convert::Infallible, ops::Range}; #[derive(Debug, PartialEq, Eq)] pub enum Symbol { A, B, C, } pub struct Model; impl arithmetic_coding::Model for Model { type B = u32; type Symbol = Symbol; type ValueError = Infallible; fn probability(&self, symbol: Option<&Self::Symbol>) -> Result, Infallible> { Ok(match symbol { None => 0..1, Some(&Symbol::A) => 1..2, Some(&Symbol::B) => 2..3, Some(&Symbol::C) => 3..4, }) } fn symbol(&self, value: u32) -> Option { match value { 0..1 => None, 1..2 => Some(Symbol::A), 2..3 => Some(Symbol::B), 3..4 => Some(Symbol::C), _ => unreachable!(), } } fn max_denominator(&self) -> u32 { 4 } } } #[test] fn round_trip() { let input1 = vec![Symbol::A, Symbol::B, Symbol::C]; let input2 = vec![2, 1, 1, 2, 2]; let buffer = encode2(symbolic::Model, &input1, integer::Model, &input2); let (output1, output2) = decode2(symbolic::Model, integer::Model, &buffer); assert_eq!(input1, output1); assert_eq!(input2, output2); } /// Encode two sets of symbols in sequence fn encode2(model1: M, input1: &[M::Symbol], model2: N, input2: &[N::Symbol]) -> Vec where M: Model, N: Model, { let mut bitwriter = BitWriter::endian(Vec::default(), BigEndian); let mut encoder1 = Encoder::with_precision(model1, &mut bitwriter, PRECISION); encode(&mut encoder1, input1); let mut encoder2 = encoder1.chain(model2); encode(&mut encoder2, input2); encoder2.flush().unwrap(); bitwriter.byte_align().unwrap(); bitwriter.into_writer() } /// Encode all symbols, followed by EOF. Doesn't flush the encoder (allowing /// more bits to be concatenated) fn encode(encoder: &mut Encoder, input: &[M::Symbol]) where M: Model, W: BitWrite, { for symbol in input { encoder.encode(Some(symbol)).unwrap(); } encoder.encode(None).unwrap(); } /// Decode two sets of symbols, in sequence fn decode2(model1: M, model2: N, buffer: &[u8]) -> (Vec, Vec) where M: Model, N: Model, { let bitreader = BitReader::endian(buffer, BigEndian); let mut decoder1 = Decoder::with_precision(model1, bitreader, PRECISION); let output1 = decode(&mut decoder1); let mut decoder2 = decoder1.chain(model2); let output2 = decode(&mut decoder2); (output1, output2) } /// Decode all symbols from a [`Decoder`] until EOF is reached fn decode(decoder: &mut Decoder) -> Vec where M: Model, R: BitRead, { decoder.decode_all().map(Result::unwrap).collect() }