Skip to content

Commit

Permalink
fix duplicated error logs and fix/force nested enum/dictionaries to b…
Browse files Browse the repository at this point in the history
…e nullable
  • Loading branch information
raj-nimble committed Oct 28, 2024
1 parent 4b4e7d2 commit 4a0ad54
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 23 deletions.
1 change: 1 addition & 0 deletions serde_arrow/src/internal/schema/extensions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod variable_shape_tensor_field;

pub use bool8_field::Bool8Field;
pub use fixed_shape_tensor_field::FixedShapeTensorField;
pub(crate) use utils::fix_dictionaries;
pub use variable_shape_tensor_field::VariableShapeTensorField;

const _: () = {
Expand Down
15 changes: 14 additions & 1 deletion serde_arrow/src/internal/schema/extensions/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::internal::error::{fail, Result};
use crate::internal::{
arrow::{DataType, Field},
error::{fail, Result},
};

pub fn check_dim_names(ndim: usize, dim_names: &[String]) -> Result<()> {
if dim_names.len() != ndim {
Expand Down Expand Up @@ -56,3 +59,13 @@ impl<T: std::fmt::Debug> std::fmt::Display for DebugRepr<T> {
write!(f, "{:?}", self.0)
}
}

pub(crate) fn fix_dictionaries(field: &mut Field) {
if matches!(field.data_type, DataType::Dictionary(_, _, _)) {
field.nullable = true;
} else if let DataType::Struct(children) = &mut field.data_type {
for child in children {
fix_dictionaries(child);
}
}
}
53 changes: 32 additions & 21 deletions serde_arrow/src/internal/schema/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::internal::{
arrow::{DataType, Field, UnionMode},
error::{fail, set_default, Context, Result},
schema::{
DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, TracingMode, TracingOptions,
STRATEGY_KEY,
extensions::fix_dictionaries, DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy,
TracingMode, TracingOptions, STRATEGY_KEY,
},
};

Expand Down Expand Up @@ -101,6 +101,30 @@ impl Tracer {
Self::Unknown(UnknownTracer::new(name, path, options))
}

fn schema_tracing_error(
failed_data_type: impl std::fmt::Display,
tracing_mode: TracingMode,
) -> Result<SerdeArrowSchema> {
fail!(
concat!(
"Schema tracing is not directly supported for the root data type {failed_data_type}. ",
"Only struct-like types are supported as root types in schema tracing. ",
"{mitigation}",
),
failed_data_type = failed_data_type,
mitigation = match tracing_mode {
TracingMode::FromType => {
"Consider using the `Item` wrapper, i.e., `::from_type<Item<T>>()`."
}
TracingMode::FromSamples => {
"Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`."
}
TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.",
},

)
}

/// Convert the traced schema into a schema object
pub fn to_schema(&self) -> Result<SerdeArrowSchema> {
let root = self.to_field()?;
Expand All @@ -113,30 +137,16 @@ impl Tracer {

let fields = match root.data_type {
DataType::Struct(children) => {
if let Some(strategy) = root
.metadata.get(STRATEGY_KEY) {
if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() {
// TODO: combine with fail messaging below
fail!("Schema tracing is not directly supported for the root data Union. Consider using the `Item` / `Items` wrappers.");
if let Some(strategy) = root.metadata.get(STRATEGY_KEY) {
if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() {
return Self::schema_tracing_error("Union", tracing_mode);
}
}

children
}
DataType::Null => fail!("No records found to determine schema"),
dt => fail!(
concat!(
"Schema tracing is not directly supported for the root data type {dt}. ",
"Only struct-like types are supported as root types in schema tracing. ",
"{mitigation}",
),
dt = DataTypeDisplay(&dt),
mitigation = match tracing_mode {
TracingMode::FromType => "Consider using the `Item` wrapper, i.e., `::from_type<Item<T>>()`.",
TracingMode::FromSamples => "Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`.",
TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.",
},
),
dt => return Self::schema_tracing_error(DataTypeDisplay(&dt), tracing_mode),
};

Ok(SerdeArrowSchema { fields })
Expand Down Expand Up @@ -1094,7 +1104,8 @@ impl UnionTracer {
if let Some(variant) = variant {
let schema = variant.tracer.to_schema()?;
for field in schema.fields {
let flat_field = field.to_flattened_union_field(variant.name.as_str());
let mut flat_field = field.to_flattened_union_field(variant.name.as_str());
fix_dictionaries(&mut flat_field);
fields.insert(flat_field.name.to_string(), flat_field);
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion serde_arrow/src/test_with_arrow/impls/flattened_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ fn test_flattened_union_with_nested_enum() {

let array = &arrays[0];

let Array::Struct(ref struct_array) = array else {
let Array::Struct(ref _struct_array) = array else {
panic!("expected a struct array, found {array:#?}");
};

Expand Down

0 comments on commit 4a0ad54

Please sign in to comment.