Skip to content

Commit

Permalink
add flattened union builder
Browse files Browse the repository at this point in the history
  • Loading branch information
raj-nimble committed Sep 18, 2024
1 parent 7d66061 commit 5a815e3
Show file tree
Hide file tree
Showing 6 changed files with 234 additions and 6 deletions.
14 changes: 14 additions & 0 deletions serde_arrow/src/internal/arrow/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ pub struct Field {
pub metadata: HashMap<String, String>,
}

impl Field {
    /// Returns `true` when the field name contains the `::` separator, i.e.
    /// when [`Field::enum_variant_name`] yields a variant name.
    pub fn from_flattened_enum(&self) -> bool {
        self.enum_variant_name().is_some()
    }

    /// Returns the part of the field name that precedes the first `::`.
    ///
    /// Yields `None` when the name contains no `::` separator. The returned
    /// prefix may be empty if the name starts with `::`.
    pub fn enum_variant_name(&self) -> Option<&str> {
        let (variant, _rest) = self.name.split_once("::")?;
        Some(variant)
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum DataType {
Expand Down
12 changes: 11 additions & 1 deletion serde_arrow/src/internal/schema/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,17 @@ impl Tracer {
let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode);

let fields = match root.data_type {
DataType::Struct(children) => children,
DataType::Struct(children) => {
if let Some(strategy) = root
.metadata.get(STRATEGY_KEY) {
if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() {
// TODO: combine with fail messaging below
fail!("Schema tracing is not directly supported for the root data Union. Consider using the `Item` / `Items` wrappers.");
}
}

children
}
DataType::Null => fail!("No records found to determine schema"),
dt => fail!(
concat!(
Expand Down
10 changes: 6 additions & 4 deletions serde_arrow/src/internal/serialization/array_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use super::{
date64_builder::Date64Builder, decimal_builder::DecimalBuilder,
dictionary_utf8_builder::DictionaryUtf8Builder, duration_builder::DurationBuilder,
fixed_size_binary_builder::FixedSizeBinaryBuilder,
fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder,
int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder,
null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder,
time_builder::TimeBuilder, union_builder::UnionBuilder,
fixed_size_list_builder::FixedSizeListBuilder, flattened_union_builder::FlattenedUnionBuilder,
float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder,
map_builder::MapBuilder, null_builder::NullBuilder, simple_serializer::SimpleSerializer,
struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder,
unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder,
};

Expand Down Expand Up @@ -53,6 +53,7 @@ pub enum ArrayBuilder {
LargeUtf8(Utf8Builder<i64>),
DictionaryUtf8(DictionaryUtf8Builder),
Union(UnionBuilder),
FlattenedUnion(FlattenedUnionBuilder),
UnknownVariant(UnknownVariantBuilder),
}

Expand Down Expand Up @@ -90,6 +91,7 @@ macro_rules! dispatch {
$wrapper::Struct($name) => $expr,
$wrapper::DictionaryUtf8($name) => $expr,
$wrapper::Union($name) => $expr,
$wrapper::FlattenedUnion($name) => $expr,
$wrapper::UnknownVariant($name) => $expr,
}
};
Expand Down
173 changes: 173 additions & 0 deletions serde_arrow/src/internal/serialization/flattened_union_builder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use std::collections::BTreeMap;

use crate::internal::{
arrow::{Array, StructArray},
error::{fail, set_default, try_, Context, ContextSupport, Result},
utils::array_ext::{ArrayExt, CountArray, SeqArrayExt},
};

use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer};

/// Builder that serializes enum struct variants into a single struct array.
///
/// Every variant is backed by its own struct builder; when the array is
/// finalized the child columns of all variants are concatenated into one flat
/// list of struct fields.
#[derive(Debug, Clone)]
pub struct FlattenedUnionBuilder {
    /// Path of this builder, reported as `field` in error annotations.
    pub path: String,
    /// One builder per enum variant, looked up by variant index. Each entry
    /// must be an `ArrayBuilder::Struct` by the time `into_array` is called.
    pub fields: Vec<ArrayBuilder>,
    /// Row bookkeeping (element count and validity) of the flattened struct.
    pub seq: CountArray,
}

impl FlattenedUnionBuilder {
    /// Creates a builder for the given path with one child builder per variant.
    ///
    /// The row bookkeeping is created with `CountArray::new(true)`
    /// (presumably: nullable, i.e. with a validity bitmap — confirm against
    /// `CountArray`).
    pub fn new(path: String, fields: Vec<ArrayBuilder>) -> Self {
        Self {
            path,
            fields,
            seq: CountArray::new(true),
        }
    }

    /// Moves the accumulated state into a new builder and returns it.
    ///
    /// The child builders are taken rather than cloned so that they are reset
    /// together with `seq`; cloning them would leave already-serialized rows
    /// behind in `self` while its row bookkeeping starts over.
    pub fn take_self(&mut self) -> Self {
        Self {
            path: self.path.clone(),
            fields: self
                .fields
                .iter_mut()
                .map(|builder| builder.take())
                .collect(),
            seq: self.seq.take(),
        }
    }

    /// Same as [`Self::take_self`], wrapped in `ArrayBuilder::FlattenedUnion`.
    pub fn take(&mut self) -> ArrayBuilder {
        ArrayBuilder::FlattenedUnion(self.take_self())
    }

    /// Returns `true` when the row bookkeeping carries a validity bitmap.
    pub fn is_nullable(&self) -> bool {
        self.seq.validity.is_some()
    }

    /// Consumes the builder and produces a struct array containing the child
    /// columns of every variant, in variant order.
    ///
    /// # Errors
    ///
    /// Fails if a variant builder is not an `ArrayBuilder::Struct` or if
    /// building any child array fails.
    pub fn into_array(self) -> Result<Array> {
        let mut fields = Vec::new();

        for builder in self.fields.into_iter() {
            let ArrayBuilder::Struct(builder) = builder else {
                fail!("enum variant not built as a struct")
            };

            for (sub_builder, sub_meta) in builder.fields.into_iter() {
                fields.push((sub_builder.into_array()?, sub_meta));
            }
        }

        Ok(Array::Struct(StructArray {
            // `len` is the number of rows, not the number of child columns.
            len: self.seq.len,
            validity: self.seq.validity,
            fields,
        }))
    }
}

impl FlattenedUnionBuilder {
    /// Selects the builder for `variant_index` and pushes a null into the
    /// builders of all other variants so that every child stays aligned.
    ///
    /// The row itself is NOT recorded in `seq` here; the caller registers it
    /// (see `serialize_struct_variant_start`). Pushing to `seq` per
    /// non-selected variant would add extra rows for a single value.
    ///
    /// # Errors
    ///
    /// Fails, without touching any child builder, if `variant_index` does not
    /// refer to a known variant. Also fails if pushing a null into a sibling
    /// builder fails.
    pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> {
        let variant_index = variant_index as usize;

        // Validate first so an unknown variant cannot leave the siblings with
        // a dangling null row.
        if variant_index >= self.fields.len() {
            fail!("Could not find variant {variant_index} in Union");
        }

        for (idx, builder) in self.fields.iter_mut().enumerate() {
            if idx != variant_index {
                builder.serialize_none()?;
            }
        }

        // In bounds: checked above.
        Ok(&mut self.fields[variant_index])
    }
}

impl Context for FlattenedUnionBuilder {
    /// Contributes this builder's annotations through `set_default`: the
    /// builder path under `field` and a fixed `Union(..)` marker under
    /// `data_type`.
    fn annotate(&self, annotations: &mut BTreeMap<String, String>) {
        let entries = [("field", self.path.as_str()), ("data_type", "Union(..)")];

        for (key, value) in entries {
            set_default(annotations, key, value);
        }
    }
}

impl SimpleSerializer for FlattenedUnionBuilder {
    // NOTE: only struct variants (variants with named fields) are overridden
    // here. Unit, newtype and tuple variants are left to whatever the
    // `SimpleSerializer` trait provides by default (presumably an error —
    // confirm against the trait definition).

    /// Starts serializing a struct variant as one row of the flattened struct.
    ///
    /// Records the row in `seq`, pushes nulls into the builders of all other
    /// variants via `serialize_variant`, starts a struct on the selected
    /// variant builder and returns that builder so the caller can serialize
    /// the variant's fields into it.
    ///
    /// # Errors
    ///
    /// Any failure — including failures of the row bookkeeping — is returned
    /// with this builder's context annotations attached.
    fn serialize_struct_variant_start<'this>(
        &'this mut self,
        _: &'static str,
        variant_index: u32,
        variant: &'static str,
        len: usize,
    ) -> Result<&'this mut ArrayBuilder> {
        let mut ctx = BTreeMap::new();
        self.annotate(&mut ctx);

        try_(|| {
            // Kept inside `try_` so bookkeeping errors are annotated like
            // every other error raised by this method.
            self.seq.start_seq()?;
            self.seq.push_seq_elements(1)?;

            let variant_builder = self.serialize_variant(variant_index)?;
            variant_builder.serialize_struct_start(variant, len)?;
            Ok(variant_builder)
        })
        .ctx(&ctx)
    }
}

// #[cfg(test)]
// mod tests {
// fn test_serialize_union() {
// #[derive(Serialize, Deserialize)]
// enum Number {
// Real { value: f32 },
// Complex { i: f32, j: f32 },
// }

// let numbers = vec![];
// }
// }
1 change: 1 addition & 0 deletions serde_arrow/src/internal/serialization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod dictionary_utf8_builder;
pub mod duration_builder;
pub mod fixed_size_binary_builder;
pub mod fixed_size_list_builder;
pub mod flattened_union_builder;
pub mod float_builder;
pub mod int_builder;
pub mod list_builder;
Expand Down
30 changes: 29 additions & 1 deletion serde_arrow/src/internal/serialization/outer_sequence_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::internal::{
binary_builder::BinaryBuilder, duration_builder::DurationBuilder,
fixed_size_binary_builder::FixedSizeBinaryBuilder,
fixed_size_list_builder::FixedSizeListBuilder,
flattened_union_builder::FlattenedUnionBuilder,
},
utils::{btree_map, meta_from_field, ChildName, Mut},
};
Expand Down Expand Up @@ -226,7 +227,34 @@ fn build_builder(path: String, field: &Field) -> Result<ArrayBuilder> {
.ctx(&ctx)?,
)
}
T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?),
T::Struct(children) => {
if let Some(Strategy::EnumsWithNamedFieldsAsStructs) =
get_strategy_from_metadata(&field.metadata)?
{
let mut related_fields: HashMap<&str, Vec<Field>> = HashMap::new();
let mut builders: Vec<ArrayBuilder> = Vec::new();

for field in children {
let Some(variant_name) = field.enum_variant_name() else {
// TODO: warning? fail! ?
continue;
};
related_fields
.entry(variant_name)
.or_default()
.push(field.clone());
}

for (variant_name, fields) in related_fields {
let sub_struct_name = format!("{}.{}", path, variant_name);
builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take());
}

A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders))
} else {
A::Struct(build_struct(path, children, field.nullable)?)
}
}
T::Dictionary(key, value, _) => {
let key_path = format!("{path}.key");
let key_field = Field {
Expand Down

0 comments on commit 5a815e3

Please sign in to comment.