Skip to content

Commit

Permalink
feat: Add Session to expose decoding steps
Browse files Browse the repository at this point in the history
  • Loading branch information
inflation committed Sep 6, 2024
1 parent 2fd7fb6 commit b4b7098
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ sweep.timestamp
build/
lcov.info
.pijul
compile_commands.json
.cache/
70 changes: 46 additions & 24 deletions jpegxl-rs/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ use crate::{
utils::check_valid_signature,
};

mod event;
pub use event::*;
mod result;
pub use result::*;
mod session;
pub use session::*;

/// Basic information
pub type BasicInfo = JxlBasicInfo;
Expand Down Expand Up @@ -87,7 +91,7 @@ impl Default for PixelFormat {
pub struct JxlDecoder<'pr, 'mm> {
/// Opaque pointer to the underlying decoder
#[builder(setter(skip))]
dec: *mut jpegxl_sys::decode::JxlDecoder,
ptr: *mut jpegxl_sys::decode::JxlDecoder,

/// Override desired pixel format
pub pixel_format: Option<PixelFormat>,
Expand Down Expand Up @@ -161,7 +165,7 @@ impl<'pr, 'mm> JxlDecoderBuilder<'pr, 'mm> {
///
/// # Errors
/// Return [`DecodeError::CannotCreateDecoder`] if it fails to create the decoder.
pub fn build(&self) -> Result<JxlDecoder<'pr, 'mm>, DecodeError> {
pub fn build(&mut self) -> Result<JxlDecoder<'pr, 'mm>, DecodeError> {
let mm = self.memory_manager.flatten();
let dec = unsafe {
mm.map_or_else(
Expand All @@ -175,7 +179,7 @@ impl<'pr, 'mm> JxlDecoderBuilder<'pr, 'mm> {
}

Ok(JxlDecoder {
dec,
ptr: dec,
pixel_format: self.pixel_format.flatten(),
skip_reorientation: self.skip_reorientation.flatten(),
unpremul_alpha: self.unpremul_alpha.flatten(),
Expand Down Expand Up @@ -217,22 +221,22 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
let next_in = data.as_ptr();
let avail_in = std::mem::size_of_val(data) as _;

check_dec_status(unsafe { JxlDecoderSetInput(self.dec, next_in, avail_in) })?;
unsafe { JxlDecoderCloseInput(self.dec) };
check_dec_status(unsafe { JxlDecoderSetInput(self.ptr, next_in, avail_in) })?;
unsafe { JxlDecoderCloseInput(self.ptr) };

let mut status;
loop {
use JxlDecoderStatus as s;

status = unsafe { JxlDecoderProcessInput(self.dec) };
status = unsafe { JxlDecoderProcessInput(self.ptr) };

match status {
s::NeedMoreInput | s::Error => return Err(DecodeError::GenericError),

// Get the basic info
s::BasicInfo => {
check_dec_status(unsafe {
JxlDecoderGetBasicInfo(self.dec, basic_info.as_mut_ptr())
JxlDecoderGetBasicInfo(self.ptr, basic_info.as_mut_ptr())
})?;

if let Some(pr) = self.parallel_runner {
Expand All @@ -252,7 +256,7 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
let buf = unsafe { reconstruct_jpeg_buffer.as_mut().unwrap_unchecked() };
buf.resize(self.init_jpeg_buffer, 0);
check_dec_status(unsafe {
JxlDecoderSetJPEGBuffer(self.dec, buf.as_mut_ptr(), buf.len())
JxlDecoderSetJPEGBuffer(self.ptr, buf.as_mut_ptr(), buf.len())
})?;
}

Expand All @@ -261,11 +265,11 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
// Safety: JpegNeedMoreOutput is only called when reconstruct_jpeg_buffer
// is not None
let buf = unsafe { reconstruct_jpeg_buffer.as_mut().unwrap_unchecked() };
let need_to_write = unsafe { JxlDecoderReleaseJPEGBuffer(self.dec) };
let need_to_write = unsafe { JxlDecoderReleaseJPEGBuffer(self.ptr) };

buf.resize(buf.len() + need_to_write, 0);
check_dec_status(unsafe {
JxlDecoderSetJPEGBuffer(self.dec, buf.as_mut_ptr(), buf.len())
JxlDecoderSetJPEGBuffer(self.ptr, buf.as_mut_ptr(), buf.len())
})?;
}

Expand All @@ -277,13 +281,13 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
s::FullImage => continue,
s::Success => {
if let Some(buf) = reconstruct_jpeg_buffer.as_mut() {
let remaining = unsafe { JxlDecoderReleaseJPEGBuffer(self.dec) };
let remaining = unsafe { JxlDecoderReleaseJPEGBuffer(self.ptr) };

buf.truncate(buf.len() - remaining);
buf.shrink_to_fit();
}

unsafe { JxlDecoderReset(self.dec) };
unsafe { JxlDecoderReset(self.ptr) };

let info = unsafe { basic_info.assume_init() };
return Ok(Metadata {
Expand Down Expand Up @@ -312,7 +316,7 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
fn setup_decoder(&self, icc: bool, reconstruct_jpeg: bool) -> Result<(), DecodeError> {
if let Some(runner) = self.parallel_runner {
check_dec_status(unsafe {
JxlDecoderSetParallelRunner(self.dec, runner.runner(), runner.as_opaque_ptr())
JxlDecoderSetParallelRunner(self.ptr, runner.runner(), runner.as_opaque_ptr())
})?;
}

Expand All @@ -329,22 +333,22 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {

events
};
check_dec_status(unsafe { JxlDecoderSubscribeEvents(self.dec, events) })?;
check_dec_status(unsafe { JxlDecoderSubscribeEvents(self.ptr, events) })?;

if let Some(val) = self.skip_reorientation {
check_dec_status(unsafe { JxlDecoderSetKeepOrientation(self.dec, val.into()) })?;
check_dec_status(unsafe { JxlDecoderSetKeepOrientation(self.ptr, val.into()) })?;
}
if let Some(val) = self.unpremul_alpha {
check_dec_status(unsafe { JxlDecoderSetUnpremultiplyAlpha(self.dec, val.into()) })?;
check_dec_status(unsafe { JxlDecoderSetUnpremultiplyAlpha(self.ptr, val.into()) })?;
}
if let Some(val) = self.render_spotcolors {
check_dec_status(unsafe { JxlDecoderSetRenderSpotcolors(self.dec, val.into()) })?;
check_dec_status(unsafe { JxlDecoderSetRenderSpotcolors(self.ptr, val.into()) })?;
}
if let Some(val) = self.coalescing {
check_dec_status(unsafe { JxlDecoderSetCoalescing(self.dec, val.into()) })?;
check_dec_status(unsafe { JxlDecoderSetCoalescing(self.ptr, val.into()) })?;
}
if let Some(val) = self.desired_intensity_target {
check_dec_status(unsafe { JxlDecoderSetDesiredIntensityTarget(self.dec, val) })?;
check_dec_status(unsafe { JxlDecoderSetDesiredIntensityTarget(self.ptr, val) })?;
}

Ok(())
Expand All @@ -353,13 +357,13 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {
fn get_icc_profile(&self, icc_profile: &mut Vec<u8>) -> Result<(), DecodeError> {
let mut icc_size = 0;
check_dec_status(unsafe {
JxlDecoderGetICCProfileSize(self.dec, JxlColorProfileTarget::Data, &mut icc_size)
JxlDecoderGetICCProfileSize(self.ptr, JxlColorProfileTarget::Data, &mut icc_size)
})?;
icc_profile.resize(icc_size, 0);

check_dec_status(unsafe {
JxlDecoderGetColorAsICCProfile(
self.dec,
self.ptr,
JxlColorProfileTarget::Data,
icc_profile.as_mut_ptr(),
icc_size,
Expand Down Expand Up @@ -401,18 +405,36 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {

let mut size = 0;
check_dec_status(unsafe {
JxlDecoderImageOutBufferSize(self.dec, &pixel_format, &mut size)
JxlDecoderImageOutBufferSize(self.ptr, &pixel_format, &mut size)
})?;
pixels.resize(size, 0);

check_dec_status(unsafe {
JxlDecoderSetImageOutBuffer(self.dec, &pixel_format, pixels.as_mut_ptr().cast(), size)
JxlDecoderSetImageOutBuffer(self.ptr, &pixel_format, pixels.as_mut_ptr().cast(), size)
})?;

unsafe { *format = pixel_format };
Ok(())
}

/// Start a new decoding session.
///
/// Later event will overwrite the previous one if they are the same.
///
/// # Arguments
///
/// * `events` - The events to subscribe to during the session.
///
/// # Errors
///
/// Returns a [`DecodeError`] if the decoding session encounters an error.
pub fn session<I>(&mut self, events: I) -> Result<Session<'_, 'pr, 'mm>, DecodeError>
where
I: IntoIterator<Item = Event>,
{
Session::new(self, events)
}

/// Decode a JPEG XL image
///
/// # Errors
Expand Down Expand Up @@ -496,7 +518,7 @@ impl<'pr, 'mm> JxlDecoder<'pr, 'mm> {

impl<'prl, 'mm> Drop for JxlDecoder<'prl, 'mm> {
fn drop(&mut self) {
unsafe { JxlDecoderDestroy(self.dec) };
unsafe { JxlDecoderDestroy(self.ptr) };
}
}

Expand Down
Loading

0 comments on commit b4b7098

Please sign in to comment.