Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/221 serialize enums as flattened structs #222

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
9f89dca
add option to TracingOptions for enums with data to be flattened to s…
raj-nimble Aug 15, 2024
0a84386
add test cases in from_samples to repro enums as flattened structs
raj-nimble Aug 15, 2024
e1d617a
add tracer changes to create GenericField for enums as flattened structs
raj-nimble Aug 15, 2024
e35d0a9
change from_samples unit tests to be a single equality test
raj-nimble Aug 15, 2024
5d8941c
add unit tests for edge cases, add basic test for from_type
raj-nimble Aug 29, 2024
ec1ae8e
missed an addition to internal schema checks
raj-nimble Sep 11, 2024
5cfcd01
typo fix
raj-nimble Sep 11, 2024
52fecdd
add flattened union builder
raj-nimble Sep 18, 2024
fd144bf
wip: flattened union builder first working version
raj-nimble Sep 18, 2024
202c078
fix num elements in struct and start adding unit tests for flattened …
raj-nimble Sep 21, 2024
180c99c
fix row count
raj-nimble Sep 21, 2024
ebe3117
fix schema/array column ordering issue by using BTreeMaps in the Trac…
raj-nimble Sep 21, 2024
0104fa5
fix ordering of fields in unit tests and replace continue with todo
raj-nimble Sep 23, 2024
522140b
Move flattened_union_builder tests to test_with_arrow directory
raj-nimble Oct 22, 2024
25755fb
complex nested enum test case
raj-nimble Oct 24, 2024
c5b98b6
fix duplicated error logs and fix/force nested enum/dictionaries to b…
raj-nimble Oct 28, 2024
7ed5ce5
consolidate helpers and cleanup from_samples and from_type unit tests
raj-nimble Oct 28, 2024
019b2a5
added tracing options rustdoc
raj-nimble Oct 28, 2024
d6314b1
slightly logging improvement in flattened union builder
raj-nimble Oct 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions serde_arrow/src/internal/arrow/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,51 @@ pub struct Field {
pub metadata: HashMap<String, String>,
}

impl PartialOrd for Field {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.name.partial_cmp(&other.name)
}
}

impl Field {
pub fn to_flattened_union_field(mut self, variant_name: &str) -> Self {
self.name = format!("{}::{}", variant_name, self.name);
self.nullable = true;
self
}

fn from_flattened_union(&self) -> bool {
self.name.contains("::")
}

pub fn union_variant_name(&self) -> Option<&str> {
if self.from_flattened_union() {
self.name.split("::").next()
} else {
None
}
}

pub fn union_field_name(&self) -> Option<String> {
if self.from_flattened_union() {
Some(
self.name
.split("::")
.skip(1)
.fold(String::new(), |acc: String, e| {
if acc.is_empty() {
String::from(e)
} else {
format!("{acc}::{e}")
}
}),
)
} else {
None
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DataType {
Expand Down
1 change: 1 addition & 0 deletions serde_arrow/src/internal/schema/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod variable_shape_tensor_field;

pub use bool8_field::Bool8Field;
pub use fixed_shape_tensor_field::FixedShapeTensorField;
pub(crate) use utils::fix_dictionaries;
pub use variable_shape_tensor_field::VariableShapeTensorField;

const _: () = {
Expand Down
15 changes: 14 additions & 1 deletion serde_arrow/src/internal/schema/extensions/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::internal::error::{fail, Result};
use crate::internal::{
arrow::{DataType, Field},
error::{fail, Result},
};

pub fn check_dim_names(ndim: usize, dim_names: &[String]) -> Result<()> {
if dim_names.len() != ndim {
Expand Down Expand Up @@ -56,3 +59,13 @@ impl<T: std::fmt::Debug> std::fmt::Display for DebugRepr<T> {
write!(f, "{:?}", self.0)
}
}

pub(crate) fn fix_dictionaries(field: &mut Field) {
if matches!(field.data_type, DataType::Dictionary(_, _, _)) {
field.nullable = true;
} else if let DataType::Struct(children) = &mut field.data_type {
for child in children {
fix_dictionaries(child);
}
}
}
53 changes: 52 additions & 1 deletion serde_arrow/src/internal/schema/from_samples/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,10 @@ mod test {
use serde::Serialize;
use serde_json::{json, Value};

use crate::internal::schema::{transmute_field, TracingOptions};
use crate::internal::{
schema::{transmute_field, TracingOptions},
testing::{Coin, Number, Optionals, Payment},
};

use super::*;

Expand Down Expand Up @@ -920,4 +923,52 @@ mod test {
expected,
);
}

#[test]
fn example_enum_as_struct_equal_to_struct_with_nullable_fields() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_samples(Number::sample_items().as_slice(), opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Number::expected_field());
}

#[test]
fn example_enum_as_struct_no_fields() {
// This should continue to maintain previously implemented behavior, serializing as a map
let opts = TracingOptions::default()
.enums_with_named_fields_as_structs(true)
.enums_without_data_as_strings(true);

let enum_tracer = Tracer::from_samples(Coin::sample_items(), opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Coin::expected_field());
}

#[test]
#[should_panic]
fn example_enum_as_struct_no_fields_panics_when_opts_not_set() {
// This should continue to maintain previously implemented behavior,
// throwing an error because we detect Unions with no fields
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);

Tracer::from_samples(Coin::sample_items(), opts)
.unwrap()
.to_field()
.unwrap();
}

#[test]
fn example_enum_as_struct_all_fields_nullable() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_samples(Optionals::sample_items(), opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Optionals::expected_field());
}

#[test]
#[should_panic]
fn example_enum_as_struct_tuple_variants() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_samples(Payment::sample_items(), opts).unwrap();

// Currently panics when `to_schema()` is called on the variant tracer
enum_tracer.to_field().unwrap();
}
}
56 changes: 56 additions & 0 deletions serde_arrow/src/internal/schema/from_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,59 @@ impl<'de, 'a> serde::de::Deserializer<'de> for IdentifierDeserializer<'a> {
unimplemented!('de, deserialize_enum, _: &'static str, _: &'static [&'static str]);
unimplemented!('de, deserialize_ignored_any);
}

#[cfg(test)]
mod test {
use crate::{
internal::{
schema::tracer::Tracer,
testing::{Coin, Number, Optionals, Payment},
},
schema::TracingOptions,
};

#[test]
fn example_enum_as_struct_equal_to_struct_with_nullable_fields() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_type::<Number>(opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Number::expected_field());
}

#[test]
fn example_enum_as_struct_no_fields() {
// This should continue to maintain previously implemented behavior, serializing as a map
let opts = TracingOptions::default()
.enums_with_named_fields_as_structs(true)
.enums_without_data_as_strings(true);

let enum_tracer = Tracer::from_type::<Coin>(opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Coin::expected_field());
}

#[test]
#[should_panic]
fn example_enum_as_struct_no_fields_panics_when_opts_not_set() {
// This should continue to maintain previously implemented behavior,
// throwing an error because we detect Unions with no fields
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);

Tracer::from_type::<Coin>(opts).unwrap().to_field().unwrap();
}

#[test]
fn example_enum_as_struct_all_fields_nullable() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_type::<Optionals>(opts).unwrap();
assert_eq!(enum_tracer.to_field().unwrap(), Optionals::expected_field());
}

#[test]
#[should_panic]
fn example_enum_as_struct_tuple_variants() {
let opts = TracingOptions::default().enums_with_named_fields_as_structs(true);
let enum_tracer = Tracer::from_type::<Payment>(opts).unwrap();

// Currently panics when `to_schema()` is called on the variant tracer
enum_tracer.to_field().unwrap();
}
}
5 changes: 4 additions & 1 deletion serde_arrow/src/internal/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,10 @@ fn validate_time64_field(field: &Field, unit: TimeUnit) -> Result<()> {
fn validate_struct_field(field: &Field, children: &[Field]) -> Result<()> {
// NOTE: do not check number of children: arrow-rs can 0 children, arrow2 not
match get_strategy_from_metadata(&field.metadata)? {
None | Some(Strategy::MapAsStruct) | Some(Strategy::TupleAsStruct) => {}
None
| Some(Strategy::MapAsStruct)
| Some(Strategy::TupleAsStruct)
| Some(Strategy::EnumsWithNamedFieldsAsStructs) => {}
Some(strategy) => fail!("invalid strategy for Struct field: {strategy}"),
}
for child in children {
Expand Down
10 changes: 10 additions & 0 deletions serde_arrow/src/internal/schema/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ pub enum Strategy {
/// polars does not support them)
///
MapAsStruct,
/// Serialize Rust enums that contain named field data as flattened structs
///
/// This strategy is a workaround for the fact that Unions are not supported in parquet.
/// Currently, only Serialization is supported. Deserialization is not.
/// When writing out the enum, it will be flattened into a Struct with
/// a list of Fields, where the names of those Fields are the field prefixed with
/// the name of the variant.
EnumsWithNamedFieldsAsStructs,
/// Mark a variant as unknown
///
/// This strategy applies only to fields with DataType Null. If
Expand All @@ -79,6 +87,7 @@ impl std::fmt::Display for Strategy {
Self::NaiveStrAsDate64 => write!(f, "NaiveStrAsDate64"),
Self::TupleAsStruct => write!(f, "TupleAsStruct"),
Self::MapAsStruct => write!(f, "MapAsStruct"),
Self::EnumsWithNamedFieldsAsStructs => write!(f, "EnumsWithNamedFieldsAsStructs"),
Self::UnknownVariant => write!(f, "UnknownVariant"),
}
}
Expand Down Expand Up @@ -108,6 +117,7 @@ impl FromStr for Strategy {
"NaiveStrAsDate64" => Ok(Self::NaiveStrAsDate64),
"TupleAsStruct" => Ok(Self::TupleAsStruct),
"MapAsStruct" => Ok(Self::MapAsStruct),
"EnumsWithNamedFieldsAsStructs" => Ok(Self::EnumsWithNamedFieldsAsStructs),
"UnknownVariant" => Ok(Self::UnknownVariant),
_ => fail!("Unknown strategy {s}"),
}
Expand Down
105 changes: 79 additions & 26 deletions serde_arrow/src/internal/schema/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::internal::{
arrow::{DataType, Field, UnionMode},
error::{fail, set_default, Context, Result},
schema::{
DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, TracingMode, TracingOptions,
STRATEGY_KEY,
extensions::fix_dictionaries, DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy,
TracingMode, TracingOptions, STRATEGY_KEY,
},
};

Expand Down Expand Up @@ -101,32 +101,52 @@ impl Tracer {
Self::Unknown(UnknownTracer::new(name, path, options))
}

fn schema_tracing_error(
failed_data_type: impl std::fmt::Display,
tracing_mode: TracingMode,
) -> Result<SerdeArrowSchema> {
fail!(
concat!(
"Schema tracing is not directly supported for the root data type {failed_data_type}. ",
"Only struct-like types are supported as root types in schema tracing. ",
"{mitigation}",
),
failed_data_type = failed_data_type,
mitigation = match tracing_mode {
TracingMode::FromType => {
"Consider using the `Item` wrapper, i.e., `::from_type<Item<T>>()`."
}
TracingMode::FromSamples => {
"Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`."
}
TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.",
},

)
}

/// Convert the traced schema into a schema object
pub fn to_schema(&self) -> Result<SerdeArrowSchema> {
let root = self.to_field()?;

if root.nullable {
fail!("The root type cannot be nullable");
fail!("The root type cannot be nullable: {root:#?}");
}

let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode);

let fields = match root.data_type {
DataType::Struct(children) => children,
DataType::Struct(children) => {
if let Some(strategy) = root.metadata.get(STRATEGY_KEY) {
if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() {
return Self::schema_tracing_error("Union", tracing_mode);
}
}

children
}
DataType::Null => fail!("No records found to determine schema"),
dt => fail!(
concat!(
"Schema tracing is not directly supported for the root data type {dt}. ",
"Only struct-like types are supported as root types in schema tracing. ",
"{mitigation}",
),
dt = DataTypeDisplay(&dt),
mitigation = match tracing_mode {
TracingMode::FromType => "Consider using the `Item` wrapper, i.e., `::from_type<Item<T>>()`.",
TracingMode::FromSamples => "Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`.",
TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.",
},
),
dt => return Self::schema_tracing_error(DataTypeDisplay(&dt), tracing_mode),
};

Ok(SerdeArrowSchema { fields })
Expand Down Expand Up @@ -1065,20 +1085,53 @@ impl UnionTracer {
}
}

let mut fields = Vec::new();
for (idx, variant) in self.variants.iter().enumerate() {
if let Some(variant) = variant {
fields.push((i8::try_from(idx)?, variant.tracer.to_field()?));
} else {
fields.push((i8::try_from(idx)?, unknown_variant_field()));
};
let data_type: DataType;
let mut metadata = HashMap::new();

if self.options.enums_with_named_fields_as_structs {
metadata.insert(
STRATEGY_KEY.to_string(),
Strategy::EnumsWithNamedFieldsAsStructs.to_string(),
);
let mut fields = BTreeMap::new();

// For this option, we want to merge the variant children up one level, combining the names
// For each variant with name variant_name
// For each variant_field with field_name
// Add field {variant_name}::{field_name} -> variant_field.to_field() that is nullable

for variant in &self.variants {
if let Some(variant) = variant {
let schema = variant.tracer.to_schema()?;
for field in schema.fields {
let mut flat_field = field.to_flattened_union_field(variant.name.as_str());
fix_dictionaries(&mut flat_field);
fields.insert(flat_field.name.to_string(), flat_field);
}
} else {
let uf = unknown_variant_field();
fields.insert(uf.name, unknown_variant_field());
};
}

data_type = DataType::Struct(fields.into_values().collect());
} else {
let mut fields = Vec::new();
for (idx, variant) in self.variants.iter().enumerate() {
if let Some(variant) = variant {
fields.push((i8::try_from(idx)?, variant.tracer.to_field()?));
} else {
fields.push((i8::try_from(idx)?, unknown_variant_field()));
};
}
data_type = DataType::Union(fields, UnionMode::Dense);
}

Ok(Field {
raj-nimble marked this conversation as resolved.
Show resolved Hide resolved
name: self.name.to_owned(),
data_type: DataType::Union(fields, UnionMode::Dense),
data_type,
nullable: self.nullable,
metadata: HashMap::new(),
metadata,
})
}

Expand Down
Loading