diff --git a/crates/polars-arrow/src/io/ipc/read/stream.rs b/crates/polars-arrow/src/io/ipc/read/stream.rs
index 357c5da29168..6248b4d90afe 100644
--- a/crates/polars-arrow/src/io/ipc/read/stream.rs
+++ b/crates/polars-arrow/src/io/ipc/read/stream.rs
@@ -29,7 +29,7 @@ pub struct StreamMetadata {
 }
 
 /// Reads the metadata of the stream
-pub fn read_stream_metadata<R: std::io::Read>(reader: &mut R) -> PolarsResult<StreamMetadata> {
+pub fn read_stream_metadata(reader: &mut dyn std::io::Read) -> PolarsResult<StreamMetadata> {
     // determine metadata length
     let mut meta_size: [u8; 4] = [0; 4];
     reader.read_exact(&mut meta_size)?;
@@ -48,10 +48,7 @@ pub fn read_stream_metadata<R: std::io::Read>(reader: &mut R) -> PolarsResult<StreamMetadata>
diff --git a/crates/polars-core/src/frame/chunks.rs b/crates/polars-core/src/frame/chunks.rs
--- a/crates/polars-core/src/frame/chunks.rs
+++ b/crates/polars-core/src/frame/chunks.rs
@@ ... @@ impl TryFrom<(RecordBatch, &ArrowSchema)> for DataFrame {
@@ -51,3 +51,71 @@ impl DataFrame {
         }
     }
 }
+
+/// Split DataFrame into chunks in preparation for writing. The chunks have a
+/// maximum number of rows per chunk to ensure reasonable memory efficiency when
+/// reading the resulting file, and a minimum size per chunk to ensure
+/// reasonable performance when writing.
+pub fn chunk_df_for_writing(
+    df: &mut DataFrame,
+    row_group_size: usize,
+) -> PolarsResult<std::borrow::Cow<'_, DataFrame>> {
+    // ensures all chunks are aligned.
+    df.align_chunks_par();
+
+    // Accumulate many small chunks to the row group size.
+    // See: #16403
+    if !df.get_columns().is_empty()
+        && df.get_columns()[0]
+            .as_materialized_series()
+            .chunk_lengths()
+            .take(5)
+            .all(|len| len < row_group_size)
+    {
+        fn finish(scratch: &mut Vec<DataFrame>, new_chunks: &mut Vec<DataFrame>) {
+            let mut new = accumulate_dataframes_vertical_unchecked(scratch.drain(..));
+            new.as_single_chunk_par();
+            new_chunks.push(new);
+        }
+
+        let mut new_chunks = Vec::with_capacity(df.first_col_n_chunks()); // upper limit;
+        let mut scratch = vec![];
+        let mut remaining = row_group_size;
+
+        for df in df.split_chunks() {
+            remaining = remaining.saturating_sub(df.height());
+            scratch.push(df);
+
+            if remaining == 0 {
+                remaining = row_group_size;
+                finish(&mut scratch, &mut new_chunks);
+            }
+        }
+        if !scratch.is_empty() {
+            finish(&mut scratch, &mut new_chunks);
+        }
+        return Ok(std::borrow::Cow::Owned(
+            accumulate_dataframes_vertical_unchecked(new_chunks),
+        ));
+    }
+
+    let n_splits = df.height() / row_group_size;
+    let result = if n_splits > 0 {
+        let mut splits = split_df_as_ref(df, n_splits, false);
+
+        for df in splits.iter_mut() {
+            // If the chunks are small enough, writing many small chunks
+            // leads to slow writing performance, so in that case we
+            // merge them.
+            let n_chunks = df.first_col_n_chunks();
+            if n_chunks > 1 && (df.estimated_size() / n_chunks < 128 * 1024) {
+                df.as_single_chunk_par();
+            }
+        }
+
+        std::borrow::Cow::Owned(accumulate_dataframes_vertical_unchecked(splits))
+    } else {
+        std::borrow::Cow::Borrowed(df)
+    };
+    Ok(result)
+}
diff --git a/crates/polars-core/src/frame/column/mod.rs b/crates/polars-core/src/frame/column/mod.rs
index 125fc62bb99e..a1426bd225cb 100644
--- a/crates/polars-core/src/frame/column/mod.rs
+++ b/crates/polars-core/src/frame/column/mod.rs
@@ -871,6 +871,19 @@ impl Column {
         }
     }
 
+    /// Returns whether the flags were set
+    pub fn set_flags(&mut self, flags: StatisticsFlags) -> bool {
+        match self {
+            Column::Series(s) => {
+                s.set_flags(flags);
+                true
+            },
+            // @partition-opt
+            Column::Partitioned(_) => false,
+            Column::Scalar(_) => false,
+        }
+    }
+
     pub fn vec_hash(&self, build_hasher: PlRandomState, buf: &mut Vec<u64>) -> PolarsResult<()> {
         // @scalar-opt?
         self.as_materialized_series().vec_hash(build_hasher, buf)
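The relocated `chunk_df_for_writing` is re-exported below from `polars_core::frame` so the IPC serde code can reuse it. A minimal calling sketch (illustrative only; `write_batches` is not part of the diff):

    use polars_core::frame::chunk_df_for_writing;
    use polars_core::prelude::*;

    fn write_batches(df: &mut DataFrame) -> PolarsResult<()> {
        // Merge tiny chunks and cap chunks at 512 * 512 rows, as the
        // DataFrame serializer below does.
        let df = chunk_df_for_writing(df, 512 * 512)?;
        // One Arrow record batch per chunk.
        for batch in df.iter_chunks(CompatLevel::newest(), true) {
            let _ = batch; // hand off to an IPC/Parquet writer here
        }
        Ok(())
    }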
diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs
index 35df9e3ba96f..1a057ddef800 100644
--- a/crates/polars-core/src/frame/mod.rs
+++ b/crates/polars-core/src/frame/mod.rs
@@ -20,6 +20,7 @@ use crate::{HEAD_DEFAULT_LENGTH, TAIL_DEFAULT_LENGTH};
 #[cfg(feature = "dataframe_arithmetic")]
 mod arithmetic;
 mod chunks;
+pub use chunks::chunk_df_for_writing;
 pub mod column;
 pub mod explode;
 mod from;
@@ -3578,41 +3579,4 @@ mod test {
         assert_eq!(df.get_column_names(), &["a", "b", "c"]);
         Ok(())
     }
-
-    #[cfg(feature = "serde")]
-    #[test]
-    fn test_deserialize_height_validation_8751() {
-        // Construct an invalid DataFrame directly from the inner fields as the `new_unchecked_*` functions
-        // have debug assertions
-
-        use polars_utils::pl_serialize;
-
-        let df = DataFrame {
-            height: 2,
-            columns: vec![
-                Int64Chunked::full("a".into(), 1, 2).into_column(),
-                Int64Chunked::full("b".into(), 1, 1).into_column(),
-            ],
-            cached_schema: OnceLock::new(),
-        };
-
-        // We rely on the fact that the serialization doesn't check the heights of all columns
-        let serialized = serde_json::to_string(&df).unwrap();
-        let err = serde_json::from_str::<DataFrame>(&serialized).unwrap_err();
-
-        assert!(err.to_string().contains(
-            "successful parse invalid data: lengths don't match: could not create a new DataFrame:",
-        ));
-
-        let serialized = pl_serialize::SerializeOptions::default()
-            .serialize_to_bytes(&df)
-            .unwrap();
-        let err = pl_serialize::SerializeOptions::default()
-            .deserialize_from_reader::<DataFrame, _>(serialized.as_slice())
-            .unwrap_err();
-
-        assert!(err.to_string().contains(
-            "successful parse invalid data: lengths don't match: could not create a new DataFrame:",
-        ));
-    }
 }
diff --git a/crates/polars-core/src/serde/df.rs b/crates/polars-core/src/serde/df.rs
index 52d6a0ee6eae..052f606a76d8 100644
--- a/crates/polars-core/src/serde/df.rs
+++ b/crates/polars-core/src/serde/df.rs
@@ -1,35 +1,144 @@
-use polars_error::PolarsError;
+use std::sync::Arc;
+
+use arrow::datatypes::Metadata;
+use arrow::io::ipc::read::{read_stream_metadata, StreamReader, StreamState};
+use arrow::io::ipc::write::WriteOptions;
+use polars_error::{polars_err, to_compute_err, PolarsResult};
+use polars_utils::format_pl_smallstr;
+use polars_utils::pl_serialize::deserialize_map_bytes;
+use polars_utils::pl_str::PlSmallStr;
 use serde::de::Error;
 use serde::*;
 
-use crate::prelude::{Column, DataFrame};
-
-// utility to ensure we serde to a struct
-// {
-//      columns: Vec<Column>
-// }
-// that ensures it differentiates between Vec<Column>
-// and is backwards compatible
-#[derive(Deserialize)]
-struct Util {
-    columns: Vec<Column>,
-}
+use crate::chunked_array::flags::StatisticsFlags;
+use crate::config;
+use crate::frame::chunk_df_for_writing;
+use crate::prelude::{CompatLevel, DataFrame, SchemaExt};
+use crate::utils::accumulate_dataframes_vertical_unchecked;
 
-#[derive(Serialize)]
-struct UtilBorrowed<'a> {
-    columns: &'a [Column],
-}
+const FLAGS_KEY: PlSmallStr = PlSmallStr::from_static("_PL_FLAGS");
 
-impl<'de> Deserialize<'de> for DataFrame {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        let parsed = <Util>::deserialize(deserializer)?;
-        DataFrame::new(parsed.columns).map_err(|e| {
-            let e = PolarsError::ComputeError(format!("successful parse invalid data: {e}").into());
-            D::Error::custom::<PolarsError>(e)
-        })
+impl DataFrame {
+    pub fn serialize_into_writer(&mut self, writer: &mut dyn std::io::Write) -> PolarsResult<()> {
+        let schema = self.schema();
+
+        if schema.iter_values().any(|x| x.is_object()) {
+            return Err(polars_err!(
+                ComputeError:
+                "serializing data of type Object is not supported",
+            ));
+        }
+
+        let mut ipc_writer =
+            arrow::io::ipc::write::StreamWriter::new(writer, WriteOptions { compression: None });
+
+        ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from_iter(
+            self.get_columns().iter().map(|c| {
+                (
+                    format_pl_smallstr!("{}{}", FLAGS_KEY, c.name()),
+                    PlSmallStr::from(c.get_flags().bits().to_string()),
+                )
+            }),
+        )));
+
+        ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from([(
+            FLAGS_KEY,
+            serde_json::to_string(
+                &self
+                    .iter()
+                    .map(|s| s.get_flags().bits())
+                    .collect::<Vec<u32>>(),
+            )
+            .map_err(to_compute_err)?
+            .into(),
+        )])));
+
+        ipc_writer.start(&schema.to_arrow(CompatLevel::newest()), None)?;
+
+        for batch in chunk_df_for_writing(self, 512 * 512)?.iter_chunks(CompatLevel::newest(), true)
+        {
+            ipc_writer.write(&batch, None)?;
+        }
+
+        ipc_writer.finish()?;
+
+        Ok(())
+    }
+
+    pub fn serialize_to_bytes(&mut self) -> PolarsResult<Vec<u8>> {
+        let mut buf = vec![];
+        self.serialize_into_writer(&mut buf)?;
+
+        Ok(buf)
+    }
+
+    pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult<Self> {
+        let mut md = read_stream_metadata(reader)?;
+        let arrow_schema = md.schema.clone();
+
+        let custom_metadata = md.custom_schema_metadata.take();
+
+        let reader = StreamReader::new(reader, md, None);
+        let dfs = reader
+            .into_iter()
+            .map_while(|batch| match batch {
+                Ok(StreamState::Some(batch)) => Some(DataFrame::try_from((batch, &arrow_schema))),
+                Ok(StreamState::Waiting) => None,
+                Err(e) => Some(Err(e)),
+            })
+            .collect::<PolarsResult<Vec<_>>>()?;
+
+        let mut df = accumulate_dataframes_vertical_unchecked(dfs);
+
+        // Set custom metadata (fallible)
+        (|| {
+            let custom_metadata = custom_metadata?;
+            let flags = custom_metadata.get(&FLAGS_KEY)?;
+
+            let flags: PolarsResult<Vec<u32>> = serde_json::from_str(flags).map_err(to_compute_err);
+
+            let verbose = config::verbose();
+
+            if let Err(e) = &flags {
+                if verbose {
+                    eprintln!("DataFrame::read_ipc: Error parsing metadata flags: {}", e);
+                }
+            }
+
+            let flags = flags.ok()?;
+
+            if flags.len() != df.width() {
+                if verbose {
+                    eprintln!(
+                        "DataFrame::read_ipc: Metadata flags width mismatch: {} != {}",
+                        flags.len(),
+                        df.width()
+                    );
+                }
+
+                return None;
+            }
+
+            let mut n_set = 0;
+
+            for (c, v) in unsafe { df.get_columns_mut() }.iter_mut().zip(flags) {
+                if let Some(flags) = StatisticsFlags::from_bits(v) {
+                    n_set += c.set_flags(flags) as usize;
+                }
+            }
+
+            if verbose {
+                eprintln!(
+                    "DataFrame::read_ipc: Loaded metadata for {} / {} columns",
+                    n_set,
+                    df.width()
+                );
+            }
+
+            Some(())
+        })();
+
+        Ok(df)
+    }
 }
 
@@ -38,9 +147,26 @@ impl Serialize for DataFrame {
     where
         S: Serializer,
     {
-        UtilBorrowed {
-            columns: &self.columns,
-        }
-        .serialize(serializer)
+        use serde::ser::Error;
+
+        let mut bytes = vec![];
+        self.clone()
+            .serialize_into_writer(&mut bytes)
+            .map_err(S::Error::custom)?;
+
+        serializer.serialize_bytes(bytes.as_slice())
+    }
+}
+
+impl<'de> Deserialize<'de> for DataFrame {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserialize_map_bytes(deserializer, &mut |b| {
+            let v = &mut b.as_ref();
+            Self::deserialize_from_reader(v)
+        })?
+        .map_err(D::Error::custom)
     }
 }
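With the new `serde/df.rs`, a DataFrame round-trips through an uncompressed Arrow IPC stream, and statistics flags travel in the `_PL_FLAGS` custom schema metadata. A usage sketch (illustrative; only the two methods come from the patch):

    use polars_core::prelude::*;

    fn roundtrip(df: &mut DataFrame) -> PolarsResult<DataFrame> {
        // Any compression now has to happen outside this call.
        let bytes = df.serialize_to_bytes()?;
        // Accepts any std::io::Read; flag restoration is best-effort.
        DataFrame::deserialize_from_reader(&mut bytes.as_slice())
    }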
diff --git a/crates/polars-core/src/serde/series.rs b/crates/polars-core/src/serde/series.rs
index a14778d250da..0df7c19fde2a 100644
--- a/crates/polars-core/src/serde/series.rs
+++ b/crates/polars-core/src/serde/series.rs
@@ -1,17 +1,38 @@
-use std::fmt::Formatter;
-
-use arrow::datatypes::Metadata;
-use arrow::io::ipc::read::{read_stream_metadata, StreamReader, StreamState};
-use arrow::io::ipc::write::WriteOptions;
-use serde::de::{Error as DeError, Visitor};
+use polars_utils::pl_serialize::deserialize_map_bytes;
+use serde::de::Error;
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
-use crate::chunked_array::flags::StatisticsFlags;
-use crate::config;
 use crate::prelude::*;
-use crate::utils::accumulate_dataframes_vertical;
 
-const FLAGS_KEY: PlSmallStr = PlSmallStr::from_static("_PL_FLAGS");
+impl Series {
+    pub fn serialize_into_writer(&self, writer: &mut dyn std::io::Write) -> PolarsResult<()> {
+        let mut df =
+            unsafe { DataFrame::new_no_checks_height_from_first(vec![self.clone().into_column()]) };
+
+        df.serialize_into_writer(writer)
+    }
+
+    pub fn serialize_to_bytes(&self) -> PolarsResult<Vec<u8>> {
+        let mut buf = vec![];
+        self.serialize_into_writer(&mut buf)?;
+
+        Ok(buf)
+    }
+
+    pub fn deserialize_from_reader(reader: &mut dyn std::io::Read) -> PolarsResult<Self> {
+        let df = DataFrame::deserialize_from_reader(reader)?;
+
+        if df.width() != 1 {
+            polars_bail!(
+                ShapeMismatch:
+                "expected only 1 column when deserializing Series from IPC, got columns: {:?}",
+                df.schema().iter_names().collect::<Vec<_>>()
+            )
+        }
+
+        Ok(df.take_columns().swap_remove(0).take_materialized_series())
+    }
+}
 
 impl Serialize for Series {
     fn serialize<S>(
@@ -23,50 +44,11 @@ impl Serialize for Series {
     {
         use serde::ser::Error;
 
-        if self.dtype().is_object() {
-            return Err(polars_err!(
-                ComputeError:
-                "serializing data of type Object is not supported",
-            ))
-            .map_err(S::Error::custom);
-        }
-
-        let bytes = vec![];
-        let mut bytes = std::io::Cursor::new(bytes);
-        let mut ipc_writer = arrow::io::ipc::write::StreamWriter::new(
-            &mut bytes,
-            WriteOptions {
-                // Compression should be done on an outer level
-                compression: Some(arrow::io::ipc::write::Compression::ZSTD),
-            },
-        );
-
-        let df = unsafe {
-            DataFrame::new_no_checks_height_from_first(vec![self.rechunk().into_column()])
-        };
-
-        ipc_writer.set_custom_schema_metadata(Arc::new(Metadata::from([(
-            FLAGS_KEY,
-            PlSmallStr::from(self.get_flags().bits().to_string()),
-        )])));
-
-        ipc_writer
-            .start(
-                &ArrowSchema::from_iter([Field {
-                    name: self.name().clone(),
-                    dtype: self.dtype().clone(),
-                }
-                .to_arrow(CompatLevel::newest())]),
-                None,
-            )
-            .map_err(S::Error::custom)?;
-
-        for batch in df.iter_chunks(CompatLevel::newest(), false) {
-            ipc_writer.write(&batch, None).map_err(S::Error::custom)?;
-        }
-
-        ipc_writer.finish().map_err(S::Error::custom)?;
-        serializer.serialize_bytes(bytes.into_inner().as_slice())
+        serializer.serialize_bytes(
+            self.serialize_to_bytes()
+                .map_err(S::Error::custom)?
+                .as_slice(),
+        )
     }
 }
 
@@ -75,76 +57,10 @@ impl<'de> Deserialize<'de> for Series {
     where
         D: Deserializer<'de>,
     {
-        struct SeriesVisitor;
-
-        impl<'de> Visitor<'de> for SeriesVisitor {
-            type Value = Series;
-
-            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
-                formatter.write_str("bytes (IPC)")
-            }
-
-            fn visit_bytes<E>(self, mut v: &[u8]) -> Result<Self::Value, E>
-            where
-                E: DeError,
-            {
-                let mut md = read_stream_metadata(&mut v).map_err(E::custom)?;
-                let arrow_schema = md.schema.clone();
-
-                let custom_metadata = md.custom_schema_metadata.take();
-
-                let reader = StreamReader::new(v, md, None);
-                let dfs = reader
-                    .into_iter()
-                    .map_while(|batch| match batch {
-                        Ok(StreamState::Some(batch)) => {
-                            Some(DataFrame::try_from((batch, &arrow_schema)))
-                        },
-                        Ok(StreamState::Waiting) => None,
-                        Err(e) => Some(Err(e)),
-                    })
-                    .collect::<PolarsResult<Vec<_>>>()
-                    .map_err(E::custom)?;
-
-                let df = accumulate_dataframes_vertical(dfs).map_err(E::custom)?;
-
-                if df.width() != 1 {
-                    return Err(polars_err!(
-                        ShapeMismatch:
-                        "expected only 1 column when deserializing Series from IPC, got columns: {:?}",
-                        df.schema().iter_names().collect::<Vec<_>>()
-                    )).map_err(E::custom);
-                }
-
-                let mut s = df.take_columns().swap_remove(0).take_materialized_series();
-
-                if let Some(custom_metadata) = custom_metadata {
-                    if let Some(flags) = custom_metadata.get(&FLAGS_KEY) {
-                        if let Ok(v) = flags.parse::<u32>() {
-                            if let Some(flags) = StatisticsFlags::from_bits(v) {
-                                s.set_flags(flags);
-                            }
-                        } else if config::verbose() {
-                            eprintln!("Series::Deserialize: Failed to parse as u8: {:?}", flags)
-                        }
-                    }
-                }
-
-                Ok(s)
-            }
-
-            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
-            where
-                A: serde::de::SeqAccess<'de>,
-            {
-                // This is not ideal, but we hit here if the serialization format is JSON.
-                let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose())
-                    .collect::<Result<Vec<u8>, A::Error>>()?;
-
-                self.visit_bytes(&bytes)
-            }
-        }
-
-        deserializer.deserialize_bytes(SeriesVisitor)
+        deserialize_map_bytes(deserializer, &mut |b| {
+            let v = &mut b.as_ref();
+            Self::deserialize_from_reader(v)
+        })?
+        .map_err(D::Error::custom)
     }
 }
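`Series` serde becomes a thin wrapper over the DataFrame path: wrap in a single-column frame on write, unwrap with a width check on read. A sketch (illustrative):

    use polars_core::prelude::*;

    fn series_roundtrip(s: &Series) -> PolarsResult<Series> {
        let bytes = s.serialize_to_bytes()?;
        // Errors with ShapeMismatch if the stream has more than one column.
        Series::deserialize_from_reader(&mut bytes.as_slice())
    }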
diff --git a/crates/polars-io/src/ipc/ipc_stream.rs b/crates/polars-io/src/ipc/ipc_stream.rs
index c3d2f353759b..76cb85e62d61 100644
--- a/crates/polars-io/src/ipc/ipc_stream.rs
+++ b/crates/polars-io/src/ipc/ipc_stream.rs
@@ -40,6 +40,7 @@ use arrow::datatypes::Metadata;
 use arrow::io::ipc::read::{StreamMetadata, StreamState};
 use arrow::io::ipc::write::WriteOptions;
 use arrow::io::ipc::{read, write};
+use polars_core::frame::chunk_df_for_writing;
 use polars_core::prelude::*;
 
 use crate::prelude::*;
diff --git a/crates/polars-io/src/parquet/write/writer.rs b/crates/polars-io/src/parquet/write/writer.rs
index 13885316d9d7..5d7d2ddafdb6 100644
--- a/crates/polars-io/src/parquet/write/writer.rs
+++ b/crates/polars-io/src/parquet/write/writer.rs
@@ -2,6 +2,7 @@ use std::io::Write;
 use std::sync::Mutex;
 
 use arrow::datatypes::PhysicalType;
+use polars_core::frame::chunk_df_for_writing;
 use polars_core::prelude::*;
 use polars_parquet::write::{
     to_parquet_schema, transverse, CompressionOptions, Encoding, FileWriter, StatisticsOptions,
@@ -11,7 +12,6 @@ use polars_parquet::write::{
 use super::batched_writer::BatchedWriter;
 use super::options::ParquetCompression;
 use super::ParquetWriteOptions;
-use crate::prelude::chunk_df_for_writing;
 use crate::shared::schema_to_arrow_checked;
 
 impl ParquetWriteOptions {
diff --git a/crates/polars-io/src/utils/other.rs b/crates/polars-io/src/utils/other.rs
index e6f51b5f3b6d..ceec5dc46217 100644
--- a/crates/polars-io/src/utils/other.rs
+++ b/crates/polars-io/src/utils/other.rs
@@ -1,13 +1,9 @@
-#[cfg(any(feature = "ipc_streaming", feature = "parquet"))]
-use std::borrow::Cow;
 use std::io::Read;
 #[cfg(target_os = "emscripten")]
 use std::io::{Seek, SeekFrom};
 
 use once_cell::sync::Lazy;
 use polars_core::prelude::*;
-#[cfg(any(feature = "ipc_streaming", feature = "parquet"))]
-use polars_core::utils::{accumulate_dataframes_vertical_unchecked, split_df_as_ref};
 use polars_utils::mmap::{MMapSemaphore, MemSlice};
 use regex::{Regex, RegexBuilder};
 
@@ -206,75 +202,6 @@ pub fn materialize_projection(
     }
 }
 
-/// Split DataFrame into chunks in preparation for writing. The chunks have a
-/// maximum number of rows per chunk to ensure reasonable memory efficiency when
-/// reading the resulting file, and a minimum size per chunk to ensure
-/// reasonable performance when writing.
-#[cfg(any(feature = "ipc_streaming", feature = "parquet"))]
-pub(crate) fn chunk_df_for_writing(
-    df: &mut DataFrame,
-    row_group_size: usize,
-) -> PolarsResult<Cow<'_, DataFrame>> {
-    // ensures all chunks are aligned.
-    df.align_chunks_par();
-
-    // Accumulate many small chunks to the row group size.
-    // See: #16403
-    if !df.get_columns().is_empty()
-        && df.get_columns()[0]
-            .as_materialized_series()
-            .chunk_lengths()
-            .take(5)
-            .all(|len| len < row_group_size)
-    {
-        fn finish(scratch: &mut Vec<DataFrame>, new_chunks: &mut Vec<DataFrame>) {
-            let mut new = accumulate_dataframes_vertical_unchecked(scratch.drain(..));
-            new.as_single_chunk_par();
-            new_chunks.push(new);
-        }
-
-        let mut new_chunks = Vec::with_capacity(df.first_col_n_chunks()); // upper limit;
-        let mut scratch = vec![];
-        let mut remaining = row_group_size;
-
-        for df in df.split_chunks() {
-            remaining = remaining.saturating_sub(df.height());
-            scratch.push(df);
-
-            if remaining == 0 {
-                remaining = row_group_size;
-                finish(&mut scratch, &mut new_chunks);
-            }
-        }
-        if !scratch.is_empty() {
-            finish(&mut scratch, &mut new_chunks);
-        }
-        return Ok(Cow::Owned(accumulate_dataframes_vertical_unchecked(
-            new_chunks,
-        )));
-    }
-
-    let n_splits = df.height() / row_group_size;
-    let result = if n_splits > 0 {
-        let mut splits = split_df_as_ref(df, n_splits, false);
-
-        for df in splits.iter_mut() {
-            // If the chunks are small enough, writing many small chunks
-            // leads to slow writing performance, so in that case we
-            // merge them.
-            let n_chunks = df.first_col_n_chunks();
-            if n_chunks > 1 && (df.estimated_size() / n_chunks < 128 * 1024) {
-                df.as_single_chunk_par();
-            }
-        }
-
-        Cow::Owned(accumulate_dataframes_vertical_unchecked(splits))
-    } else {
-        Cow::Borrowed(df)
-    };
-    Ok(result)
-}
-
 #[cfg(test)]
 mod tests {
     use super::FLOAT_RE;
diff --git a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs
index dba897a8eeea..71db1a74601a 100644
--- a/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs
+++ b/crates/polars-parquet/src/parquet/metadata/column_chunk_metadata.rs
@@ -63,10 +63,13 @@ fn deserialize_column_chunk<'de, D>(deserializer: D) -> std::result::Result<ColumnChunk, D::Error>
 where
     D: Deserializer<'de>,
 {
-    let buf = Vec::<u8>::deserialize(deserializer)?;
-    let mut cursor = Cursor::new(&buf[..]);
-    let mut protocol = TCompactInputProtocol::new(&mut cursor, usize::MAX);
-    ColumnChunk::read_from_in_protocol(&mut protocol).map_err(D::Error::custom)
+    use polars_utils::pl_serialize::deserialize_map_bytes;
+
+    deserialize_map_bytes(deserializer, &mut |b| {
+        let mut b = b.as_ref();
+        let mut protocol = TCompactInputProtocol::new(&mut b, usize::MAX);
+        ColumnChunk::read_from_in_protocol(&mut protocol).map_err(D::Error::custom)
+    })?
 }
 
 // Represents common operations for a column chunk.
diff --git a/crates/polars-plan/src/client/mod.rs b/crates/polars-plan/src/client/mod.rs
index c42e481cd1f6..67ce7bbced72 100644
--- a/crates/polars-plan/src/client/mod.rs
+++ b/crates/polars-plan/src/client/mod.rs
@@ -12,9 +12,7 @@ pub fn prepare_cloud_plan(dsl: DslPlan) -> PolarsResult<Vec<u8>> {
 
     // Serialize the plan.
     let mut writer = Vec::new();
-    pl_serialize::SerializeOptions::default()
-        .with_compression(true)
-        .serialize_into_writer(&mut writer, &dsl)?;
+    pl_serialize::SerializeOptions::default().serialize_into_writer(&mut writer, &dsl)?;
 
     Ok(writer)
 }
diff --git a/crates/polars-plan/src/dsl/expr_dyn_fn.rs b/crates/polars-plan/src/dsl/expr_dyn_fn.rs
index 039f00b4f2ba..da7ace09b721 100644
--- a/crates/polars-plan/src/dsl/expr_dyn_fn.rs
+++ b/crates/polars-plan/src/dsl/expr_dyn_fn.rs
@@ -2,6 +2,8 @@ use std::fmt::Formatter;
 use std::ops::Deref;
 use std::sync::Arc;
 
+#[cfg(feature = "serde")]
+use polars_utils::pl_serialize::deserialize_map_bytes;
 #[cfg(feature = "serde")]
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
@@ -43,7 +45,7 @@ impl<T: Serialize + Clone> Serialize for LazySerde<T> {
     {
         match self {
             Self::Deserialized(t) => t.serialize(serializer),
-            Self::Bytes(b) => serializer.serialize_bytes(b),
+            Self::Bytes(b) => b.serialize(serializer),
         }
     }
 }
@@ -54,8 +56,8 @@ impl<'a, T: Deserialize<'a> + Clone> Deserialize<'a> for LazySerde<T> {
     where
         D: Deserializer<'a>,
     {
-        let buf = Vec::<u8>::deserialize(deserializer)?;
-        Ok(Self::Bytes(bytes::Bytes::from(buf)))
+        let buf = bytes::Bytes::deserialize(deserializer)?;
+        Ok(Self::Bytes(buf))
     }
 }
 
@@ -69,17 +71,17 @@ impl<'a> Deserialize<'a> for SpecialEq<Arc<dyn ColumnsUdf>> {
         use serde::de::Error;
         #[cfg(feature = "python")]
         {
-            let buf = Vec::<u8>::deserialize(deserializer)?;
-
-            if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
-                let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
-                    .map_err(|e| D::Error::custom(format!("{e}")))?;
-                Ok(SpecialEq::new(udf))
-            } else {
-                Err(D::Error::custom(
-                    "deserialization not supported for this 'opaque' function",
-                ))
-            }
+            deserialize_map_bytes(deserializer, &mut |buf| {
+                if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
+                    let udf = python_udf::PythonUdfExpression::try_deserialize(&buf)
+                        .map_err(|e| D::Error::custom(format!("{e}")))?;
+                    Ok(SpecialEq::new(udf))
+                } else {
+                    Err(D::Error::custom(
+                        "deserialization not supported for this 'opaque' function",
+                    ))
+                }
+            })?
         }
         #[cfg(not(feature = "python"))]
         {
@@ -403,17 +405,17 @@ impl<'a> Deserialize<'a> for GetOutput {
         use serde::de::Error;
         #[cfg(feature = "python")]
         {
-            let buf = Vec::<u8>::deserialize(deserializer)?;
-
-            if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
-                let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
-                    .map_err(|e| D::Error::custom(format!("{e}")))?;
-                Ok(SpecialEq::new(get_output))
-            } else {
-                Err(D::Error::custom(
-                    "deserialization not supported for this output field",
-                ))
-            }
+            deserialize_map_bytes(deserializer, &mut |buf| {
+                if buf.starts_with(python_udf::PYTHON_SERDE_MAGIC_BYTE_MARK) {
+                    let get_output = python_udf::PythonGetOutput::try_deserialize(&buf)
+                        .map_err(|e| D::Error::custom(format!("{e}")))?;
+                    Ok(SpecialEq::new(get_output))
+                } else {
+                    Err(D::Error::custom(
+                        "deserialization not supported for this output field",
+                    ))
+                }
+            })?
         }
         #[cfg(not(feature = "python"))]
         {
diff --git a/crates/polars-python/src/dataframe/serde.rs b/crates/polars-python/src/dataframe/serde.rs
index 48dd22fdc0d6..3b6cd596cec8 100644
--- a/crates/polars-python/src/dataframe/serde.rs
+++ b/crates/polars-python/src/dataframe/serde.rs
@@ -3,10 +3,7 @@ use std::ops::Deref;
 
 use polars::prelude::*;
 use polars_io::mmap::ReaderBytes;
-use polars_utils::pl_serialize;
 use pyo3::prelude::*;
-use pyo3::pybacked::PyBackedBytes;
-use pyo3::types::PyBytes;
 
 use super::PyDataFrame;
 use crate::error::PyPolarsErr;
@@ -15,62 +12,40 @@ use crate::file::{get_file_like, get_mmap_bytes_reader};
 
 #[pymethods]
 impl PyDataFrame {
-    #[cfg(feature = "ipc_streaming")]
-    fn __getstate__<'py>(&self, py: Python<'py>) -> Bound<'py, PyBytes> {
-        // Used in pickle/pickling
-        PyBytes::new(
-            py,
-            &pl_serialize::SerializeOptions::default()
-                .with_compression(true)
-                .serialize_to_bytes(&self.df)
-                .unwrap(),
-        )
-    }
+    /// Serialize into binary data.
+    fn serialize_binary(&mut self, py: Python, py_f: PyObject) -> PyResult<()> {
+        let file = get_file_like(py_f, true)?;
+        let mut writer = BufWriter::new(file);
 
-    #[cfg(feature = "ipc_streaming")]
-    fn __setstate__(&mut self, state: &Bound<PyAny>) -> PyResult<()> {
-        // Used in pickle/pickling
-        match state.extract::<PyBackedBytes>() {
-            Ok(s) => pl_serialize::SerializeOptions::default()
-                .with_compression(true)
-                .deserialize_from_reader(&*s)
-                .map(|df| {
-                    self.df = df;
-                })
-                .map_err(|e| PyPolarsErr::from(e).into()),
-            Err(e) => Err(e),
-        }
+        py.allow_threads(|| {
+            self.df
+                .serialize_into_writer(&mut writer)
+                .map_err(|e| PyPolarsErr::from(e).into())
+        })
     }
 
-    /// Serialize into binary data.
-    fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> {
-        let file = get_file_like(py_f, true)?;
-        let writer = BufWriter::new(file);
-        pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .serialize_into_writer(writer, &self.df)
-            .map_err(|err| ComputeError::new_err(err.to_string()))
+    /// Deserialize a file-like object containing binary data into a DataFrame.
+    #[staticmethod]
+    fn deserialize_binary(py: Python, py_f: PyObject) -> PyResult<Self> {
+        let file = get_file_like(py_f, false)?;
+        let mut file = BufReader::new(file);
+
+        py.allow_threads(|| {
+            DataFrame::deserialize_from_reader(&mut file)
+                .map_err(|e| PyPolarsErr::from(e).into())
+                .map(|x| x.into())
+        })
     }
 
     /// Serialize into a JSON string.
     #[cfg(feature = "json")]
-    pub fn serialize_json(&mut self, py_f: PyObject) -> PyResult<()> {
+    pub fn serialize_json(&mut self, py: Python, py_f: PyObject) -> PyResult<()> {
         let file = get_file_like(py_f, true)?;
         let writer = BufWriter::new(file);
-        serde_json::to_writer(writer, &self.df)
-            .map_err(|err| ComputeError::new_err(err.to_string()))
-    }
-
-    /// Deserialize a file-like object containing binary data into a DataFrame.
-    #[staticmethod]
-    fn deserialize_binary(py_f: PyObject) -> PyResult<Self> {
-        let file = get_file_like(py_f, false)?;
-        let file = BufReader::new(file);
-        let df: DataFrame = pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .deserialize_from_reader(file)
-            .map_err(|err| ComputeError::new_err(err.to_string()))?;
-        Ok(df.into())
+        py.allow_threads(|| {
+            serde_json::to_writer(writer, &self.df)
+                .map_err(|err| ComputeError::new_err(err.to_string()))
+        })
     }
 
     /// Deserialize a file-like object containing JSON string data into a DataFrame.
diff --git a/crates/polars-python/src/expr/serde.rs b/crates/polars-python/src/expr/serde.rs
index 08685baed417..3096a36e07ec 100644
--- a/crates/polars-python/src/expr/serde.rs
+++ b/crates/polars-python/src/expr/serde.rs
@@ -1,4 +1,4 @@
-use std::io::{BufReader, BufWriter, Cursor};
+use std::io::{BufReader, BufWriter};
 
 use polars::lazy::prelude::Expr;
 use polars_utils::pl_serialize;
@@ -17,7 +17,6 @@ impl PyExpr {
         // Used in pickle/pickling
         let mut writer: Vec<u8> = vec![];
         pl_serialize::SerializeOptions::default()
-            .with_compression(true)
             .serialize_into_writer(&mut writer, &self.inner)
             .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
 
@@ -26,12 +25,9 @@ impl PyExpr {
     fn __setstate__(&mut self, state: &Bound<PyBytes>) -> PyResult<()> {
         // Used in pickle/pickling
         let bytes = state.extract::<PyBackedBytes>()?;
-        let cursor = Cursor::new(&*bytes);
         self.inner = pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .deserialize_from_reader(cursor)
+            .deserialize_from_reader(&*bytes)
             .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
         Ok(())
     }
@@ -41,7 +37,6 @@ impl PyExpr {
         let file = get_file_like(py_f, true)?;
         let writer = BufWriter::new(file);
         pl_serialize::SerializeOptions::default()
-            .with_compression(true)
             .serialize_into_writer(writer, &self.inner)
             .map_err(|err| ComputeError::new_err(err.to_string()))
     }
@@ -61,7 +56,6 @@ impl PyExpr {
         let file = get_file_like(py_f, false)?;
         let reader = BufReader::new(file);
         let expr: Expr = pl_serialize::SerializeOptions::default()
-            .with_compression(true)
             .deserialize_from_reader(reader)
             .map_err(|err| ComputeError::new_err(err.to_string()))?;
         Ok(expr.into())
diff --git a/crates/polars-python/src/lazyframe/serde.rs b/crates/polars-python/src/lazyframe/serde.rs
index 2164785ddb24..dd7c6b9dc648 100644
--- a/crates/polars-python/src/lazyframe/serde.rs
+++ b/crates/polars-python/src/lazyframe/serde.rs
@@ -2,11 +2,8 @@ use std::io::{BufReader, BufWriter};
 
 use polars_utils::pl_serialize;
 use pyo3::prelude::*;
-use pyo3::pybacked::PyBackedBytes;
-use pyo3::types::PyBytes;
 
 use super::PyLazyFrame;
-use crate::error::PyPolarsErr;
 use crate::exceptions::ComputeError;
 use crate::file::get_file_like;
 use crate::prelude::*;
@@ -14,72 +11,49 @@ use crate::prelude::*;
 #[pymethods]
 #[allow(clippy::should_implement_trait)]
 impl PyLazyFrame {
-    fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
-        // Used in pickle/pickling
-        let mut writer: Vec<u8> = vec![];
-
-        pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .serialize_into_writer(&mut writer, &self.ldf.logical_plan)
-            .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
-
-        Ok(PyBytes::new(py, &writer))
-    }
-
-    fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
-        // Used in pickle/pickling
-        match state.extract::<PyBackedBytes>(py) {
-            Ok(s) => {
-                let lp: DslPlan = pl_serialize::SerializeOptions::default()
-                    .with_compression(true)
-                    .deserialize_from_reader(&*s)
-                    .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
-                self.ldf = LazyFrame::from(lp);
-                Ok(())
-            },
-            Err(e) => Err(e),
-        }
-    }
-
     /// Serialize into binary data.
-    fn serialize_binary(&self, py_f: PyObject) -> PyResult<()> {
+    fn serialize_binary(&self, py: Python, py_f: PyObject) -> PyResult<()> {
         let file = get_file_like(py_f, true)?;
         let writer = BufWriter::new(file);
-        pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .serialize_into_writer(writer, &self.ldf.logical_plan)
-            .map_err(|err| ComputeError::new_err(err.to_string()))
+        py.allow_threads(|| {
+            pl_serialize::SerializeOptions::default()
+                .serialize_into_writer(writer, &self.ldf.logical_plan)
+                .map_err(|err| ComputeError::new_err(err.to_string()))
+        })
     }
 
     /// Serialize into a JSON string.
     #[cfg(feature = "json")]
-    fn serialize_json(&self, py_f: PyObject) -> PyResult<()> {
+    fn serialize_json(&self, py: Python, py_f: PyObject) -> PyResult<()> {
         let file = get_file_like(py_f, true)?;
         let writer = BufWriter::new(file);
-        serde_json::to_writer(writer, &self.ldf.logical_plan)
-            .map_err(|err| ComputeError::new_err(err.to_string()))
+        py.allow_threads(|| {
+            serde_json::to_writer(writer, &self.ldf.logical_plan)
+                .map_err(|err| ComputeError::new_err(err.to_string()))
+        })
     }
 
     /// Deserialize a file-like object containing binary data into a LazyFrame.
     #[staticmethod]
-    fn deserialize_binary(py_f: PyObject) -> PyResult<Self> {
+    fn deserialize_binary(py: Python, py_f: PyObject) -> PyResult<Self> {
         let file = get_file_like(py_f, false)?;
         let reader = BufReader::new(file);
-        let lp: DslPlan = pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .deserialize_from_reader(reader)
-            .map_err(|err| ComputeError::new_err(err.to_string()))?;
+        let lp: DslPlan = py.allow_threads(|| {
+            pl_serialize::SerializeOptions::default()
+                .deserialize_from_reader(reader)
+                .map_err(|err| ComputeError::new_err(err.to_string()))
+        })?;
         Ok(LazyFrame::from(lp).into())
     }
 
     /// Deserialize a file-like object containing JSON string data into a LazyFrame.
     #[staticmethod]
     #[cfg(feature = "json")]
-    fn deserialize_json(py_f: PyObject) -> PyResult<Self> {
+    fn deserialize_json(py: Python, py_f: PyObject) -> PyResult<Self> {
         // it is faster to first read to memory and then parse: https://github.com/serde-rs/json/issues/160
         // so don't bother with files.
         let mut json = String::new();
-        let _ = get_file_like(py_f, false)?
+        get_file_like(py_f, false)?
             .read_to_string(&mut json)
             .unwrap();
@@ -91,8 +65,10 @@ impl PyLazyFrame {
         // in this scope.
         let json = unsafe { std::mem::transmute::<&'_ str, &'static str>(json.as_str()) };
 
-        let lp = serde_json::from_str::<DslPlan>(json)
-            .map_err(|err| ComputeError::new_err(err.to_string()))?;
+        let lp = py.allow_threads(|| {
+            serde_json::from_str::<DslPlan>(json)
+                .map_err(|err| ComputeError::new_err(err.to_string()))
+        })?;
         Ok(LazyFrame::from(lp).into())
     }
 }
diff --git a/crates/polars-python/src/series/general.rs b/crates/polars-python/src/series/general.rs
index 408217963bde..d363dabb396f 100644
--- a/crates/polars-python/src/series/general.rs
+++ b/crates/polars-python/src/series/general.rs
@@ -2,7 +2,6 @@ use polars_core::chunked_array::cast::CastOptions;
 use polars_core::series::IsSorted;
 use polars_core::utils::flatten::flatten_series;
 use polars_row::RowEncodingOptions;
-use polars_utils::pl_serialize;
 use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyBytes;
@@ -390,14 +389,10 @@ impl PySeries {
 
     fn __getstate__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
         // Used in pickle/pickling
-        let mut buf: Vec<u8> = vec![];
-
-        pl_serialize::SerializeOptions::default()
-            .with_compression(true)
-            .serialize_into_writer(&mut buf, &self.series)
-            .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
-
-        Ok(PyBytes::new(py, &buf))
+        Ok(PyBytes::new(
+            py,
+            &py.allow_threads(|| self.series.serialize_to_bytes().map_err(PyPolarsErr::from))?,
+        ))
     }
 
     fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
@@ -405,14 +400,11 @@ impl PySeries {
         use pyo3::pybacked::PyBackedBytes;
         match state.extract::<PyBackedBytes>(py) {
-            Ok(s) => {
-                let s: Series = pl_serialize::SerializeOptions::default()
-                    .with_compression(true)
-                    .deserialize_from_reader(&*s)
-                    .map_err(|e| PyPolarsErr::Other(format!("{}", e)))?;
+            Ok(s) => py.allow_threads(|| {
+                let s = Series::deserialize_from_reader(&mut &*s).map_err(PyPolarsErr::from)?;
                 self.series = s;
                 Ok(())
-            },
+            }),
             Err(e) => Err(e),
         }
     }
diff --git a/crates/polars-utils/src/pl_serialize.rs b/crates/polars-utils/src/pl_serialize.rs
index d0ab267daa3a..04079943a2be 100644
--- a/crates/polars-utils/src/pl_serialize.rs
+++ b/crates/polars-utils/src/pl_serialize.rs
@@ -100,6 +100,61 @@ where
     Ok(v)
 }
 
+/// Potentially avoids copying memory compared to a naive `Vec::<u8>::deserialize`.
+///
+/// This is essentially boilerplate for visiting bytes without copying where possible.
+pub fn deserialize_map_bytes<'de, D, O>(
+    deserializer: D,
+    func: &mut (dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>) -> O),
+) -> Result<O, D::Error>
+where
+    D: serde::de::Deserializer<'de>,
+{
+    // Lets us avoid monomorphizing the visitor
+    let mut out: Option<O> = None;
+    struct V<'f>(&'f mut (dyn for<'b> FnMut(std::borrow::Cow<'b, [u8]>)));
+
+    deserializer.deserialize_bytes(V(&mut |v| drop(out.replace(func(v)))))?;
+
+    return Ok(out.unwrap());
+
+    impl<'de> serde::de::Visitor<'de> for V<'_> {
+        type Value = ();
+
+        fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+            formatter.write_str("deserialize_map_bytes")
+        }
+
+        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
+        where
+            E: serde::de::Error,
+        {
+            self.0(std::borrow::Cow::Borrowed(v));
+            Ok(())
+        }
+
+        fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
+        where
+            E: serde::de::Error,
+        {
+            self.0(std::borrow::Cow::Owned(v));
+            Ok(())
+        }
+
+        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+        where
+            A: serde::de::SeqAccess<'de>,
+        {
+            // This is not ideal, but we hit here if the serialization format is JSON.
+            let bytes = std::iter::from_fn(|| seq.next_element::<u8>().transpose())
+                .collect::<Result<Vec<u8>, A::Error>>()?;
+
+            self.0(std::borrow::Cow::Owned(bytes));
+            Ok(())
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     #[test]
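`deserialize_map_bytes` is the shared primitive the call sites above migrate to: it borrows the input via `visit_bytes` when the format supports zero-copy, and only allocates for `visit_byte_buf` or JSON sequences. A typical call site looks like this (the `Wrapper` type is hypothetical):

    use polars_utils::pl_serialize::deserialize_map_bytes;

    struct Wrapper(Vec<u8>);

    impl<'de> serde::Deserialize<'de> for Wrapper {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::de::Deserializer<'de>,
        {
            // The closure receives a Cow<[u8]>: borrowed when possible,
            // owned otherwise. Copy out only what must be kept.
            deserialize_map_bytes(deserializer, &mut |b| Wrapper(b.into_owned()))
        }
    }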
diff --git a/crates/polars-utils/src/python_function.rs b/crates/polars-utils/src/python_function.rs
index 8fc69a774f7c..949fc41b6af8 100644
--- a/crates/polars-utils/src/python_function.rs
+++ b/crates/polars-utils/src/python_function.rs
@@ -8,6 +8,8 @@ pub use serde_wrap::{
     SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
 };
 
+use crate::pl_serialize::deserialize_map_bytes;
+
 #[derive(Debug)]
 pub struct PythonFunction(pub PyObject);
 
@@ -60,10 +62,9 @@ impl<'a> serde::Deserialize<'a> for PythonFunction {
         D: serde::Deserializer<'a>,
     {
         use serde::de::Error;
-        let bytes = Vec::<u8>::deserialize(deserializer)?;
-        let v = Self::try_deserialize_bytes(bytes.as_slice())
-            .map_err(|e| D::Error::custom(e.to_string()));
-        v
+        deserialize_map_bytes(deserializer, &mut |bytes| {
+            Self::try_deserialize_bytes(&bytes).map_err(|e| D::Error::custom(e.to_string()))
+        })?
     }
 }
 
@@ -132,6 +133,8 @@ mod serde_wrap {
     use once_cell::sync::Lazy;
     use polars_error::PolarsResult;
 
+    use crate::pl_serialize::deserialize_map_bytes;
+
     pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
     /// [minor, micro]
     pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version);
@@ -170,44 +173,45 @@ mod serde_wrap {
             D: serde::Deserializer<'a>,
         {
            use serde::de::Error;
-            let bytes = Vec::<u8>::deserialize(deserializer)?;
-
-            let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
-                return Err(D::Error::custom(
-                    "unexpected EOF when reading serialized pyobject version",
-                ));
-            };
-
-            if magic != SERDE_MAGIC_BYTE_MARK {
-                return Err(D::Error::custom(
-                    "serialized pyobject did not begin with magic byte mark",
-                ));
-            }
-
-            let bytes = rem;
-
-            let [a, b, rem @ ..] = bytes else {
-                return Err(D::Error::custom(
-                    "unexpected EOF when reading serialized pyobject metadata",
-                ));
-            };
-
-            let py3_version = [*a, *b];
-
-            if py3_version != *PYTHON3_VERSION {
-                return Err(D::Error::custom(format!(
-                    "python version that pyobject was serialized with {:?} \
-                    differs from system python version {:?}",
-                    (3, py3_version[0], py3_version[1]),
-                    (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
-                )));
-            }
-
-            let bytes = rem;
-
-            T::try_deserialize_bytes(bytes)
-                .map(Self)
-                .map_err(|e| D::Error::custom(e.to_string()))
+
+            deserialize_map_bytes(deserializer, &mut |bytes| {
+                let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
+                    return Err(D::Error::custom(
+                        "unexpected EOF when reading serialized pyobject version",
+                    ));
+                };
+
+                if magic != SERDE_MAGIC_BYTE_MARK {
+                    return Err(D::Error::custom(
+                        "serialized pyobject did not begin with magic byte mark",
+                    ));
+                }
+
+                let bytes = rem;
+
+                let [a, b, rem @ ..] = bytes else {
+                    return Err(D::Error::custom(
+                        "unexpected EOF when reading serialized pyobject metadata",
+                    ));
+                };
+
+                let py3_version = [*a, *b];
+
+                if py3_version != *PYTHON3_VERSION {
+                    return Err(D::Error::custom(format!(
+                        "python version that pyobject was serialized with {:?} \
+                        differs from system python version {:?}",
+                        (3, py3_version[0], py3_version[1]),
+                        (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
+                    )));
+                }
+
+                let bytes = rem;
+
+                T::try_deserialize_bytes(bytes)
+                    .map(Self)
+                    .map_err(|e| D::Error::custom(e.to_string()))
+            })?
         }
     }
 }
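The byte layout validated above is unchanged: magic marker, two Python-version bytes, then the pickled payload. Sketched as a standalone function (illustrative; the real writer lives in the serialize path, which this diff does not touch):

    // MAGIC || [py3 minor, py3 micro] || payload
    fn frame_pyobject_bytes(payload: &[u8], py3_version: [u8; 2]) -> Vec<u8> {
        const MAGIC: &[u8] = b"PLPYFN";
        let mut out = Vec::with_capacity(MAGIC.len() + 2 + payload.len());
        out.extend_from_slice(MAGIC);
        out.extend_from_slice(&py3_version); // a mismatch is rejected on read
        out.extend_from_slice(payload);
        out
    }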
diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py
index 5ccf9db5b210..7014b6e74554 100644
--- a/py-polars/polars/dataframe/frame.py
+++ b/py-polars/polars/dataframe/frame.py
@@ -1154,11 +1154,11 @@ def __ge__(self, other: Any) -> DataFrame:
     def __le__(self, other: Any) -> DataFrame:
         return self._comp(other, "lt_eq")
 
-    def __getstate__(self) -> list[Series]:
-        return self.get_columns()
+    def __getstate__(self) -> bytes:
+        return self.serialize()
 
-    def __setstate__(self, state: list[Series]) -> None:
-        self._df = DataFrame(state)._df
+    def __setstate__(self, state: bytes) -> None:
+        self._df = self.deserialize(BytesIO(state))._df
 
     def __mul__(self, other: DataFrame | Series | int | float) -> DataFrame:
         if isinstance(other, DataFrame):
@@ -2635,8 +2635,8 @@ def serialize(
         ...     }
         ... )
         >>> bytes = df.serialize()
-        >>> bytes  # doctest: +ELLIPSIS
-        b'x\x01bb@\x80\x15...'
+        >>> type(bytes)
+        <class 'bytes'>
 
         The bytes can later be deserialized back into a DataFrame.
diff --git a/py-polars/polars/expr/meta.py b/py-polars/polars/expr/meta.py
index 6c97082cbcbb..71302643640a 100644
--- a/py-polars/polars/expr/meta.py
+++ b/py-polars/polars/expr/meta.py
@@ -331,8 +331,8 @@ def serialize(
 
         >>> expr = pl.col("foo").sum().over("bar")
         >>> bytes = expr.meta.serialize()
-        >>> bytes  # doctest: +ELLIPSIS
-        b'x\x01\x02L\x80\x81...'
+        >>> type(bytes)
+        <class 'bytes'>
 
         The bytes can later be deserialized back into an `Expr` object.
diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py
index ebef241ce3c3..5453daa3995c 100644
--- a/py-polars/polars/lazyframe/frame.py
+++ b/py-polars/polars/lazyframe/frame.py
@@ -328,11 +328,10 @@ def _from_pyldf(cls, ldf: PyLazyFrame) -> LazyFrame:
         return self
 
     def __getstate__(self) -> bytes:
-        return self._ldf.__getstate__()
+        return self.serialize()
 
     def __setstate__(self, state: bytes) -> None:
-        self._ldf = LazyFrame()._ldf  # Initialize with a dummy
-        self._ldf.__setstate__(state)
+        self._ldf = self.deserialize(BytesIO(state))._ldf
 
     @classmethod
     def _scan_python_function(
diff --git a/py-polars/tests/unit/io/cloud/test_credential_provider.py b/py-polars/tests/unit/io/cloud/test_credential_provider.py
index 5afbc343fdec..b88568cad3dc 100644
--- a/py-polars/tests/unit/io/cloud/test_credential_provider.py
+++ b/py-polars/tests/unit/io/cloud/test_credential_provider.py
@@ -71,14 +71,11 @@ def __call__(self) -> pl.CredentialProviderFunctionReturn:
 
 
 def test_scan_credential_provider_serialization_pyversion() -> None:
-    import zlib
-
     lf = pl.scan_parquet(
         "s3://bucket/path", credential_provider=pl.CredentialProviderAWS()
     )
 
     serialized = lf.serialize()
-    serialized = zlib.decompress(serialized)
     serialized = bytearray(serialized)
 
     # We can't monkeypatch sys.python_version so we just mutate the output
@@ -93,8 +90,6 @@ def test_scan_credential_provider_serialization_pyversion() -> None:
         serialized[i] = 255
         serialized[i + 1] = 254
 
-    serialized = zlib.compress(serialized)
-
     with pytest.raises(ComputeError, match=r"python version.*(3, 255, 254).*differs.*"):
         lf = pl.LazyFrame.deserialize(io.BytesIO(serialized))