Skip to content

Commit

Permalink
wip: flattened union builder first working version
Browse files Browse the repository at this point in the history
  • Loading branch information
raj-nimble committed Sep 18, 2024
1 parent 5a815e3 commit 49d3cfd
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 96 deletions.
37 changes: 34 additions & 3 deletions serde_arrow/src/internal/arrow/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,49 @@ pub struct Field {
pub metadata: HashMap<String, String>,
}

/// Orders fields by their `name` alone; no other attribute of the field takes
/// part in the comparison.
impl PartialOrd for Field {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Strings are totally ordered, so the result is always `Some(..)`.
        let by_name = self.name.cmp(&other.name);
        Some(by_name)
    }
}

impl Field {
pub fn from_flattened_enum(&self) -> bool {
pub fn to_flattened_union_field(mut self, variant_name: &str) -> Self {
self.name = format!("{}::{}", variant_name, self.name);
self.nullable = true;
self
}

/// Returns `true` when the field name contains the `::` separator that
/// `to_flattened_union_field` inserts between variant name and field name.
fn from_flattened_union(&self) -> bool {
    self.name.split_once("::").is_some()
}

pub fn enum_variant_name(&self) -> Option<&str> {
if self.from_flattened_enum() {
/// Returns the variant part of a flattened-union field name, i.e. everything
/// before the first `::`.
///
/// Yields `None` when the name contains no `::` separator.
pub fn union_variant_name(&self) -> Option<&str> {
    self.name.split_once("::").map(|(variant, _rest)| variant)
}

/// Returns the field part of a flattened-union field name, i.e. everything
/// after the first `::`, as an owned string.
///
/// Yields `None` when the name contains no `::` separator.
///
/// The remainder is returned verbatim, so it may itself contain `::` or be
/// empty: `"V::a::b"` gives `"a::b"`, `"V::"` gives `""` and `"V::::x"` gives
/// `"::x"`. This keeps the method an exact inverse of the
/// `"{variant}::{name}"` formatting used to build flattened names.
pub fn union_field_name(&self) -> Option<String> {
    // Splitting once on the first separator avoids re-joining the segments,
    // which previously lost empty leading segments and reallocated the
    // accumulator for every segment.
    self.name
        .split_once("::")
        .map(|(_variant, field_name)| field_name.to_owned())
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
Expand Down
6 changes: 2 additions & 4 deletions serde_arrow/src/internal/schema/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1086,10 +1086,8 @@ impl UnionTracer {
for variant in &self.variants {
if let Some(variant) = variant {
let schema = variant.tracer.to_schema()?;
for mut field in schema.fields {
field.name = format!("{}::{}", variant.name, field.name);
field.nullable = true;
fields.push(field)
for field in schema.fields {
fields.push(field.to_flattened_union_field(variant.name.as_str()))
}
} else {
fields.push(unknown_variant_field())
Expand Down
111 changes: 30 additions & 81 deletions serde_arrow/src/internal/serialization/flattened_union_builder.rs
Original file line number Diff line number Diff line change
@@ -1,81 +1,77 @@
use std::collections::BTreeMap;

use crate::internal::{
arrow::{Array, StructArray},
arrow::{Array, FieldMeta, StructArray},
error::{fail, set_default, try_, Context, ContextSupport, Result},
utils::array_ext::{ArrayExt, CountArray, SeqArrayExt},
};

use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer};

#[derive(Debug, Clone)]
pub struct FlattenedUnionBuilder {
pub path: String,
pub fields: Vec<ArrayBuilder>,
pub seq: CountArray,
pub fields: Vec<(ArrayBuilder, FieldMeta)>,
}

impl FlattenedUnionBuilder {
pub fn new(path: String, fields: Vec<ArrayBuilder>) -> Self {
Self {
path,
fields,
seq: CountArray::new(true),
}
}

pub fn take_self(&mut self) -> Self {
Self {
path: self.path.clone(),
fields: self.fields.clone(),
seq: self.seq.take(),
}
pub fn new(path: String, fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self {
Self { path, fields }
}

pub fn take(&mut self) -> ArrayBuilder {
ArrayBuilder::FlattenedUnion(self.take_self())
ArrayBuilder::FlattenedUnion(Self {
path: self.path.clone(),
fields: self
.fields
.iter_mut()
.map(|(field, meta)| (field.take(), meta.clone()))
.collect(),
})
}

pub fn is_nullable(&self) -> bool {
self.seq.validity.is_some()
false
}

/// Consumes the builder and returns one `Array::Struct` that holds the child
/// arrays of every variant builder, in the order of `self.fields`.
///
/// The metadata name of each child is rewritten to `"{variant}::{child}"`,
/// where `{variant}` is the `name` of the metadata stored next to the variant
/// builder.
///
/// # Errors
///
/// Bails out through `fail!` when a variant builder is not
/// `ArrayBuilder::Struct`, and propagates any error from converting a child
/// builder into an array.
///
/// NOTE(review): several statements below appear twice, an earlier form
/// directly followed by its replacement. The comments describe the
/// replacement (tuple-based) form.
pub fn into_array(self) -> Result<Array> {
let mut fields = Vec::new();
// Incremented once per child field, so it ends up as the number of child arrays.
let mut num_elements = 0;

for builder in self.fields.into_iter() {
for (builder, meta) in self.fields.into_iter() {
// Every variant is expected to have been built as a struct builder.
let ArrayBuilder::Struct(builder) = builder else {
fail!("enum variant not built as a struct")
fail!("enum variant not built as a struct") // TODO: better failure message
};

for (sub_builder, sub_meta) in builder.fields.into_iter() {
for (sub_builder, mut sub_meta) in builder.fields.into_iter() {
num_elements += 1;
// TODO: this mirrors the field name structure in the tracer but represents
// implementation details crossing boundaries. Is there another way?
// Currently necessary to allow struct field lookup to work correctly.
sub_meta.name = format!("{}::{}", meta.name, sub_meta.name);
fields.push((sub_builder.into_array()?, sub_meta));
}
}

// NOTE(review): `num_elements` counts child arrays (columns), not rows. If
// `StructArray::len` is meant to be the number of rows, this value is wrong
// whenever the two differ - confirm against the `StructArray` definition.
Ok(Array::Struct(StructArray {
len: fields.len(),
validity: self.seq.validity,
len: num_elements,
validity: None, // TODO: is this ok?
fields,
}))
}
}

impl FlattenedUnionBuilder {
pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> {
// self.len += 1;

let variant_index = variant_index as usize;

// call push_none for any variant that was not selected
for (idx, builder) in self.fields.iter_mut().enumerate() {
// don't serialize any variant not selected
for (idx, (builder, _meta)) in self.fields.iter_mut().enumerate() {
if idx != variant_index {
builder.serialize_none()?;
self.seq.push_seq_none()?;
}
}

let Some(variant_builder) = self.fields.get_mut(variant_index) else {
let Some((variant_builder, _variant_meta)) = self.fields.get_mut(variant_index) else {
fail!("Could not find variant {variant_index} in Union");
};

Expand All @@ -86,40 +82,11 @@ impl FlattenedUnionBuilder {
impl Context for FlattenedUnionBuilder {
fn annotate(&self, annotations: &mut BTreeMap<String, String>) {
set_default(annotations, "field", &self.path);
set_default(annotations, "data_type", "Union(..)");
set_default(annotations, "data_type", "Struct(..)");
}
}

impl SimpleSerializer for FlattenedUnionBuilder {
// fn serialize_unit_variant(
// &mut self,
// _: &'static str,
// variant_index: u32,
// _: &'static str,
// ) -> Result<()> {
// let mut ctx = BTreeMap::new();
// self.annotate(&mut ctx);

// try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx)
// }

// fn serialize_newtype_variant<V: serde::Serialize + ?Sized>(
// &mut self,
// _: &'static str,
// variant_index: u32,
// _: &'static str,
// value: &V,
// ) -> Result<()> {
// let mut ctx = BTreeMap::new();
// self.annotate(&mut ctx);

// try_(|| {
// let variant_builder = self.serialize_variant(variant_index)?;
// value.serialize(Mut(variant_builder))
// })
// .ctx(&ctx)
// }

fn serialize_struct_variant_start<'this>(
&'this mut self,
_: &'static str,
Expand All @@ -129,8 +96,6 @@ impl SimpleSerializer for FlattenedUnionBuilder {
) -> Result<&'this mut ArrayBuilder> {
let mut ctx = BTreeMap::new();
self.annotate(&mut ctx);
self.seq.start_seq()?;
self.seq.push_seq_elements(1)?;

try_(|| {
let variant_builder = self.serialize_variant(variant_index)?;
Expand All @@ -139,26 +104,10 @@ impl SimpleSerializer for FlattenedUnionBuilder {
})
.ctx(&ctx)
}

// fn serialize_tuple_variant_start<'this>(
// &'this mut self,
// _: &'static str,
// variant_index: u32,
// variant: &'static str,
// len: usize,
// ) -> Result<&'this mut ArrayBuilder> {
// let mut ctx = BTreeMap::new();
// self.annotate(&mut ctx);

// try_(|| {
// let variant_builder = self.serialize_variant(variant_index)?;
// variant_builder.serialize_tuple_struct_start(variant, len)?;
// Ok(variant_builder)
// })
// .ctx(&ctx)
// }
}

// TODO: add tests

// #[cfg(test)]
// mod tests {
// fn test_serialize_union() {
Expand Down
34 changes: 26 additions & 8 deletions serde_arrow/src/internal/serialization/outer_sequence_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashMap};
use serde::Serialize;

use crate::internal::{
arrow::{DataType, Field, TimeUnit},
arrow::{DataType, Field, FieldMeta, TimeUnit},
error::{fail, Context, ContextSupport, Result},
schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy},
serialization::{
Expand Down Expand Up @@ -231,23 +231,41 @@ fn build_builder(path: String, field: &Field) -> Result<ArrayBuilder> {
if let Some(Strategy::EnumsWithNamedFieldsAsStructs) =
get_strategy_from_metadata(&field.metadata)?
{
let mut related_fields: HashMap<&str, Vec<Field>> = HashMap::new();
let mut builders: Vec<ArrayBuilder> = Vec::new();
let mut related_fields: BTreeMap<&str, Vec<Field>> = BTreeMap::new();
let mut builders: Vec<(ArrayBuilder, FieldMeta)> = Vec::new();

for field in children {
let Some(variant_name) = field.enum_variant_name() else {
// TODO: warning? fail! ?
let Some(variant_name) = field.union_variant_name() else {
// TODO: failure message
continue;
};

let Some(field_name) = field.union_field_name() else {
// TODO: failure message
continue;
};

let mut new_field = field.clone();
new_field.name = field_name;

related_fields
.entry(variant_name)
.or_default()
.push(field.clone());
.push(new_field);
}

for (variant_name, fields) in related_fields {
let sub_struct_name = format!("{}.{}", path, variant_name);
builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take());
let builder = build_struct(
format!("{}.{}", path.to_owned(), variant_name),
fields.as_slice(),
true,
)?
.take();

let mut meta = meta_from_field(field.clone());
meta.name = variant_name.to_owned();

builders.push((builder, meta));
}

A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders))
Expand Down

0 comments on commit 49d3cfd

Please sign in to comment.