From 9f89dcabe10c30e9587821871ba96088679cb124 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 14 Aug 2024 17:35:07 -0700 Subject: [PATCH 01/19] add option to TracingOptions for enums with data to be flattened to structs --- .../src/internal/schema/tracing_options.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index 86d09474..b0dc8f9b 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -220,6 +220,15 @@ pub struct TracingOptions { /// Internal field to improve error messages for the different tracing /// functions pub(crate) tracing_mode: TracingMode, + + /// Whether to encode enums with data as structs + /// + /// If `false` enums with data are encoded as Union arrays. + /// If `true` enums with data are encoded as Structs. + /// + /// TODO: example + /// ``` + pub enums_with_data_as_structs: bool, } impl Default for TracingOptions { @@ -232,6 +241,7 @@ impl Default for TracingOptions { guess_dates: false, from_type_budget: 100, enums_without_data_as_strings: false, + enums_with_data_as_structs: false, overwrites: Overwrites::default(), sequence_as_large_list: true, string_as_large_utf8: true, @@ -299,6 +309,12 @@ impl TracingOptions { self } + /// Set [`enums_with_data_as_structs`](#structfield.enums_with_data_as_structs) + pub fn enums_with_data_as_structs(mut self, value: bool) -> Self { + self.enums_with_data_as_structs = value; + self + } + /// Add an overwrite to [`overwrites`](#structfield.overwrites) pub fn overwrite, F: Serialize>(mut self, path: P, field: F) -> Result { self.overwrites.0.insert( From 0a84386a4349dd5f91f0e325ba4a72fa3af4540c Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 14 Aug 2024 17:35:49 -0700 Subject: [PATCH 02/19] add test cases in from_samples to repro enums as flattened structs --- .../src/internal/schema/from_samples/mod.rs | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index ede6942c..9bfd6032 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -820,6 +820,13 @@ mod test { use super::*; + /// Dummy enum used only for testing + #[derive(Serialize)] + enum Number { + Real { value: f32 }, + Complex { i: f32, j: f32 }, + } + fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); @@ -920,4 +927,131 @@ mod test { expected, ); } + + #[test] + fn example_enum_as_union() { + let expected = json!({ + "name": "$", + "data_type": "Union", + "children": [ + { + "name": "Real", + "data_type": "Struct", + "children": [ + { + "name": "value", + "data_type": "F32", + } + ] + }, + { + "name": "Complex", + "data_type": "Struct", + "children": [ + { + "name": "i", + "data_type": "F32" + }, + { + "name": "j", + "data_type": "F32" + } + ] + } + ] + }); + + test_to_tracer( + &[ + Number::Real { value: 1.0 }, + Number::Complex { i: 0.5, j: 0.5 }, + ], + TracingOptions::default(), + expected, + ); + } + + #[test] + fn example_enum_as_struct() { + let expected = json!({ + "name": "$", + "data_type": "Struct", + "children": [ + { + "name": "real_value", + "data_type": "F32", + "nullable": true + }, + { + "name": "complex_i", + "data_type": "F32", + "nullable": true + }, + { + "name": "complex_j", + "data_type": "F32", + "nullable": true + } + ] + }); + + let opts = TracingOptions::default().enums_with_data_as_structs(true); + + test_to_tracer( + &[ + Number::Real { value: 1.0 }, + Number::Complex { i: 0.5, j: 0.5 }, + ], + opts, + expected, + ); + } + + #[test] + fn example_struct_with_nullable_fields() { + #[derive(Serialize, Default)] + struct Number { + real_value: Option, + complex_i: Option, + complex_j: Option, + } + + let expected = json!({ + "name": "$", + "data_type": "Struct", + "children": [ + { + "name": "real_value", + "data_type": "F32", + "nullable": true + }, + { + "name": "complex_i", + "data_type": "F32", + "nullable": true + }, + { + "name": "complex_j", + "data_type": "F32", + "nullable": true + } + ] + }); + + test_to_tracer( + &[ + Number { + real_value: Some(1.0), + ..Default::default() + }, + Number { + complex_i: Some(0.5), + complex_j: Some(0.5), + ..Default::default() + }, + ], + TracingOptions::default(), + expected, + ); + } } From e1d617a9b8b801f9a7466029eee3705a76c571de Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 14 Aug 2024 17:36:21 -0700 Subject: [PATCH 03/19] add tracer changes to create GenericField for enums as flattened structs --- serde_arrow/src/internal/schema/tracer.rs | 44 ++++++++++++++++++----- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index daa460af..d70a7f14 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -1065,18 +1065,46 @@ impl UnionTracer { } } - let mut fields = Vec::new(); - for (idx, variant) in self.variants.iter().enumerate() { - if let Some(variant) = variant { - fields.push((i8::try_from(idx)?, variant.tracer.to_field()?)); - } else { - fields.push((i8::try_from(idx)?, unknown_variant_field())); - }; + let data_type: DataType; + + if self.options.enums_with_data_as_structs { + let mut fields = Vec::new(); + + // For this option, we want to merge the variant children up one level, combining the names + // For each variant with name variant_name + // For each variant_field with field_name + // Add field {variant_name}_{field_name} -> variant_field.to_field() that is nullable + + for variant in &self.variants { + if let Some(variant) = variant { + // TODO: does this break if there are no child fields? + let schema = variant.tracer.to_schema()?; + for mut field in schema.fields { + field.name = format!("{}_{}", variant.name.to_lowercase(), field.name); + field.nullable = true; + fields.push(field) + } + } else { + fields.push(unknown_variant_field()) + }; + } + + data_type = DataType::Struct(fields); + } else { + let mut fields = Vec::new(); + for (idx, variant) in self.variants.iter().enumerate() { + if let Some(variant) = variant { + fields.push((i8::try_from(idx)?, variant.tracer.to_field()?)); + } else { + fields.push((i8::try_from(idx)?, unknown_variant_field())); + }; + } + data_type = DataType::Union(fields, UnionMode::Dense); } Ok(Field { name: self.name.to_owned(), - data_type: DataType::Union(fields, UnionMode::Dense), + data_type, nullable: self.nullable, metadata: HashMap::new(), }) From e35d0a9a96f2383bb19449f6ec9af9430a3cdf28 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 14 Aug 2024 18:43:17 -0700 Subject: [PATCH 04/19] change from_samples unit tests to be a single equality test --- .../src/internal/schema/from_samples/mod.rs | 149 ++++-------------- 1 file changed, 29 insertions(+), 120 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 9bfd6032..47aaf247 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -820,13 +820,6 @@ mod test { use super::*; - /// Dummy enum used only for testing - #[derive(Serialize)] - enum Number { - Real { value: f32 }, - Complex { i: f32, j: f32 }, - } - fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); @@ -929,129 +922,45 @@ mod test { } #[test] - fn example_enum_as_union() { - let expected = json!({ - "name": "$", - "data_type": "Union", - "children": [ - { - "name": "Real", - "data_type": "Struct", - "children": [ - { - "name": "value", - "data_type": "F32", - } - ] - }, - { - "name": "Complex", - "data_type": "Struct", - "children": [ - { - "name": "i", - "data_type": "F32" - }, - { - "name": "j", - "data_type": "F32" - } - ] - } - ] - }); - - test_to_tracer( - &[ - Number::Real { value: 1.0 }, - Number::Complex { i: 0.5, j: 0.5 }, - ], - TracingOptions::default(), - expected, - ); - } - - #[test] - fn example_enum_as_struct() { - let expected = json!({ - "name": "$", - "data_type": "Struct", - "children": [ - { - "name": "real_value", - "data_type": "F32", - "nullable": true - }, - { - "name": "complex_i", - "data_type": "F32", - "nullable": true - }, - { - "name": "complex_j", - "data_type": "F32", - "nullable": true - } - ] - }); - - let opts = TracingOptions::default().enums_with_data_as_structs(true); - - test_to_tracer( - &[ - Number::Real { value: 1.0 }, - Number::Complex { i: 0.5, j: 0.5 }, - ], - opts, - expected, - ); - } + fn example_enum_as_struct_equal_to_struct_with_nullable_fields() { + #[derive(Serialize)] + enum Number { + Real { value: f32 }, + Complex { i: f32, j: f32 }, + } - #[test] - fn example_struct_with_nullable_fields() { #[derive(Serialize, Default)] - struct Number { + struct StructNumber { real_value: Option, complex_i: Option, complex_j: Option, } - let expected = json!({ - "name": "$", - "data_type": "Struct", - "children": [ - { - "name": "real_value", - "data_type": "F32", - "nullable": true + let enum_items = [ + Number::Real { value: 1.0 }, + Number::Complex { i: 0.5, j: 0.5 }, + ]; + + let struct_items = [ + StructNumber { + real_value: Some(1.0), + ..Default::default() }, - { - "name": "complex_i", - "data_type": "F32", - "nullable": true + StructNumber { + complex_i: Some(0.5), + complex_j: Some(0.5), + ..Default::default() }, - { - "name": "complex_j", - "data_type": "F32", - "nullable": true - } - ] - }); + ]; - test_to_tracer( - &[ - Number { - real_value: Some(1.0), - ..Default::default() - }, - Number { - complex_i: Some(0.5), - complex_j: Some(0.5), - ..Default::default() - }, - ], - TracingOptions::default(), - expected, + let opts = TracingOptions::default().enums_with_data_as_structs(true); + + let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); + let struct_tracer = Tracer::from_samples(&struct_items, TracingOptions::default()).unwrap(); + + assert_eq!( + enum_tracer.to_field().unwrap(), + struct_tracer.to_field().unwrap() ); } } From 5d8941c069091d15598e376d79be8fe3e03dfa34 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 28 Aug 2024 19:12:14 -0700 Subject: [PATCH 05/19] add unit tests for edge cases, add basic test for from_type --- .../src/internal/schema/from_samples/mod.rs | 213 ++++++++++++++++-- .../src/internal/schema/from_type/mod.rs | 61 +++++ serde_arrow/src/internal/schema/strategy.rs | 9 + serde_arrow/src/internal/schema/tracer.rs | 14 +- .../src/internal/schema/tracing_options.rs | 10 +- 5 files changed, 275 insertions(+), 32 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 47aaf247..05aa86ad 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -813,13 +813,28 @@ mod impl_serialize_to_string { #[cfg(test)] mod test { + use std::collections::HashMap; + use serde::Serialize; use serde_json::{json, Value}; - use crate::internal::schema::{transmute_field, TracingOptions}; + use crate::{ + internal::{ + arrow::Field, + schema::{transmute_field, TracingOptions}, + }, + schema::STRATEGY_KEY, + }; use super::*; + fn enum_with_named_fields_metadata() -> HashMap { + HashMap::from([( + STRATEGY_KEY.to_string(), + Strategy::EnumsWithNamedFieldsAsStructs.to_string(), + )]) + } + fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); @@ -929,38 +944,192 @@ mod test { Complex { i: f32, j: f32 }, } - #[derive(Serialize, Default)] - struct StructNumber { - real_value: Option, - complex_i: Option, - complex_j: Option, - } - let enum_items = [ Number::Real { value: 1.0 }, Number::Complex { i: 0.5, j: 0.5 }, ]; - let struct_items = [ - StructNumber { - real_value: Some(1.0), - ..Default::default() + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); + + let expected_field = Field { + name: "$".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Real::value".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::i".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::j".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: enum_with_named_fields_metadata(), + }; + + assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + } + + #[test] + fn example_enum_as_struct_no_fields() { + #[derive(Serialize)] + enum Coin { + Heads, + Tails, + } + + let enum_items = [Coin::Heads, Coin::Tails]; + + // This should continue to maintain previously implemented behavior, serializing as a map + let opts = TracingOptions::default() + .enums_with_named_fields_as_structs(true) + .enums_without_data_as_strings(true); + + let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); + + let expected_field = Field { + name: "$".to_string(), + data_type: DataType::Dictionary( + Box::new(DataType::UInt32), + Box::new(DataType::LargeUtf8), + false, + ), + nullable: false, + metadata: HashMap::new(), + }; + + assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + } + + #[test] + #[should_panic] + fn example_enum_as_struct_no_fields_panics_when_opts_not_set() { + #[derive(Serialize)] + enum TrafficLight { + Red, + Yellow, + Green, + } + + let enum_items = [TrafficLight::Red, TrafficLight::Yellow, TrafficLight::Green]; + + // This should continue to maintain previously implemented behavior, + // throwing an error because we detect Unions with no fields + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + + Tracer::from_samples(&enum_items, opts) + .unwrap() + .to_field() + .unwrap(); + } + + #[test] + fn example_enum_as_struct_all_fields_nullable() { + #[derive(Serialize)] + enum Optionals { + Something { + more: Option, + less: Option, + }, + Else { + one: Option, + another: Option, + }, + } + + let enum_items = [ + Optionals::Something { + more: Some(1), + less: None, + }, + Optionals::Something { + more: None, + less: Some(0), + }, + Optionals::Else { + one: None, + another: Some(0), }, - StructNumber { - complex_i: Some(0.5), - complex_j: Some(0.5), - ..Default::default() + Optionals::Else { + one: Some(1), + another: None, }, ]; - let opts = TracingOptions::default().enums_with_data_as_structs(true); + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); + + let expected_field = Field { + name: "$".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Something::more".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Something::less".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Else::one".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Else::another".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: enum_with_named_fields_metadata(), + }; + assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + } + + #[test] + #[should_panic] + fn example_enum_as_struct_tuple_variants() { + #[derive(Serialize)] + enum Payment { + Cash(f32), // amount + Check(String, f32), // name, amount + CreditCard(String, f32, [u8; 16], String), // name, amount, cc number, exp + } + + let enum_items = [ + Payment::Cash(0.42), + Payment::Check("Bob".to_string(), 0.42), + Payment::CreditCard( + "Sue".to_string(), + 0.42, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6], + "01/2024".to_string(), + ), + ]; + + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); - let struct_tracer = Tracer::from_samples(&struct_items, TracingOptions::default()).unwrap(); - assert_eq!( - enum_tracer.to_field().unwrap(), - struct_tracer.to_field().unwrap() - ); + // Currently panics when `to_schema()` is called on the variant tracer + enum_tracer.to_field().unwrap(); } } diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index 38562591..193b6415 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -585,3 +585,64 @@ impl<'de, 'a> serde::de::Deserializer<'de> for IdentifierDeserializer<'a> { unimplemented!('de, deserialize_enum, _: &'static str, _: &'static [&'static str]); unimplemented!('de, deserialize_ignored_any); } + +#[cfg(test)] +mod test { + use std::collections::HashMap; + + use serde::{Deserialize, Serialize}; + + use crate::{ + internal::{ + arrow::{DataType, Field}, + schema::tracer::Tracer, + }, + schema::{Strategy, TracingOptions, STRATEGY_KEY}, + }; + + // TODO: combine these with the from_samples tests, dedup utility code, test all edge cases with from_type as well + #[test] + fn example_enum_as_struct_equal_to_struct_with_nullable_fields() { + // TODO: dedup? + #[derive(Serialize, Deserialize)] + enum Number { + Real { value: f32 }, + Complex { i: f32, j: f32 }, + } + + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + let enum_tracer = Tracer::from_type::(opts).unwrap(); + let metadata = HashMap::from([( + STRATEGY_KEY.to_string(), + Strategy::EnumsWithNamedFieldsAsStructs.to_string(), + )]); + + let expected_field = Field { + name: "$".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Real::value".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::i".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::j".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata, + }; + + assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + } +} diff --git a/serde_arrow/src/internal/schema/strategy.rs b/serde_arrow/src/internal/schema/strategy.rs index 29162d64..c79e2347 100644 --- a/serde_arrow/src/internal/schema/strategy.rs +++ b/serde_arrow/src/internal/schema/strategy.rs @@ -63,6 +63,14 @@ pub enum Strategy { /// polars does not support them) /// MapAsStruct, + /// Serialize Rust enums that contain named field data as flattened structs + /// + /// This strategy is a workaround for the fact that Unions are not supported in parquet. + /// Currently, only Serialization is supported. Deserialization is not. + /// When writing out the enum, it will be flattened into a Struct with + /// a list of Fields, where the names of those Fields are the field prefixed with + /// the name of the variant. + EnumsWithNamedFieldsAsStructs, /// Mark a variant as unknown /// /// This strategy applies only to fields with DataType Null. If @@ -79,6 +87,7 @@ impl std::fmt::Display for Strategy { Self::NaiveStrAsDate64 => write!(f, "NaiveStrAsDate64"), Self::TupleAsStruct => write!(f, "TupleAsStruct"), Self::MapAsStruct => write!(f, "MapAsStruct"), + Self::EnumsWithNamedFieldsAsStructs => write!(f, "EnumsWithNamedFieldsAsStructs"), Self::UnknownVariant => write!(f, "UnknownVariant"), } } diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index d70a7f14..02d177b3 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -1066,21 +1066,25 @@ impl UnionTracer { } let data_type: DataType; + let mut metadata = HashMap::new(); - if self.options.enums_with_data_as_structs { + if self.options.enums_with_named_fields_as_structs { + metadata.insert( + STRATEGY_KEY.to_string(), + Strategy::EnumsWithNamedFieldsAsStructs.to_string(), + ); let mut fields = Vec::new(); // For this option, we want to merge the variant children up one level, combining the names // For each variant with name variant_name // For each variant_field with field_name - // Add field {variant_name}_{field_name} -> variant_field.to_field() that is nullable + // Add field {variant_name}::{field_name} -> variant_field.to_field() that is nullable for variant in &self.variants { if let Some(variant) = variant { - // TODO: does this break if there are no child fields? let schema = variant.tracer.to_schema()?; for mut field in schema.fields { - field.name = format!("{}_{}", variant.name.to_lowercase(), field.name); + field.name = format!("{}::{}", variant.name, field.name); field.nullable = true; fields.push(field) } @@ -1106,7 +1110,7 @@ impl UnionTracer { name: self.name.to_owned(), data_type, nullable: self.nullable, - metadata: HashMap::new(), + metadata, }) } diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index b0dc8f9b..21847792 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -228,7 +228,7 @@ pub struct TracingOptions { /// /// TODO: example /// ``` - pub enums_with_data_as_structs: bool, + pub enums_with_named_fields_as_structs: bool, } impl Default for TracingOptions { @@ -241,7 +241,7 @@ impl Default for TracingOptions { guess_dates: false, from_type_budget: 100, enums_without_data_as_strings: false, - enums_with_data_as_structs: false, + enums_with_named_fields_as_structs: false, overwrites: Overwrites::default(), sequence_as_large_list: true, string_as_large_utf8: true, @@ -309,9 +309,9 @@ impl TracingOptions { self } - /// Set [`enums_with_data_as_structs`](#structfield.enums_with_data_as_structs) - pub fn enums_with_data_as_structs(mut self, value: bool) -> Self { - self.enums_with_data_as_structs = value; + /// Set [`enums_with_named_fields_as_structs`](#structfield.enums_with_named_fields_as_structs) + pub fn enums_with_named_fields_as_structs(mut self, value: bool) -> Self { + self.enums_with_named_fields_as_structs = value; self } From ec1ae8ed56754b1d39299d25f1c392c77db7d779 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Tue, 10 Sep 2024 18:54:17 -0700 Subject: [PATCH 06/19] missed an addition to internal schema checks --- serde_arrow/src/internal/schema/mod.rs | 5 ++++- serde_arrow/src/internal/schema/strategy.rs | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/schema/mod.rs b/serde_arrow/src/internal/schema/mod.rs index 830b2ce4..5f6b29c6 100644 --- a/serde_arrow/src/internal/schema/mod.rs +++ b/serde_arrow/src/internal/schema/mod.rs @@ -482,7 +482,10 @@ fn validate_time64_field(field: &Field, unit: TimeUnit) -> Result<()> { fn validate_struct_field(field: &Field, children: &[Field]) -> Result<()> { // NOTE: do not check number of children: arrow-rs can 0 children, arrow2 not match get_strategy_from_metadata(&field.metadata)? { - None | Some(Strategy::MapAsStruct) | Some(Strategy::TupleAsStruct) => {} + None + | Some(Strategy::MapAsStruct) + | Some(Strategy::TupleAsStruct) + | Some(Strategy::EnumsWithNamedFieldsAsStructs) => {} Some(strategy) => fail!("invalid strategy for Struct field: {strategy}"), } for child in children { diff --git a/serde_arrow/src/internal/schema/strategy.rs b/serde_arrow/src/internal/schema/strategy.rs index c79e2347..1f743de1 100644 --- a/serde_arrow/src/internal/schema/strategy.rs +++ b/serde_arrow/src/internal/schema/strategy.rs @@ -117,6 +117,7 @@ impl FromStr for Strategy { "NaiveStrAsDate64" => Ok(Self::NaiveStrAsDate64), "TupleAsStruct" => Ok(Self::TupleAsStruct), "MapAsStruct" => Ok(Self::MapAsStruct), + "EnumsWithNamedFieldsAsStructs" => Ok(Self::EnumsWithNamedFieldsAsStructs), "UnknownVariant" => Ok(Self::UnknownVariant), _ => fail!("Unknown strategy {s}"), } From 5cfcd01b01c6b0957ccfc78f5d0c6a0a2a92d5d8 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Tue, 10 Sep 2024 19:06:19 -0700 Subject: [PATCH 07/19] typo fix --- serde_arrow/src/internal/serialization/simple_serializer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/serialization/simple_serializer.rs b/serde_arrow/src/internal/serialization/simple_serializer.rs index 80183e0b..2de770b2 100644 --- a/serde_arrow/src/internal/serialization/simple_serializer.rs +++ b/serde_arrow/src/internal/serialization/simple_serializer.rs @@ -169,7 +169,7 @@ pub trait SimpleSerializer: Sized + Context { fn serialize_struct_start(&mut self, name: &'static str, len: usize) -> Result<()> { fail!( in self, - "serialize_start_start is not supported", + "serialize_struct_start is not supported", ) } From 52fecdd766915e46a0513b72c295bb44ed65b5a3 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Tue, 17 Sep 2024 17:54:14 -0700 Subject: [PATCH 08/19] add flattened union builder --- serde_arrow/src/internal/arrow/data_type.rs | 14 ++ serde_arrow/src/internal/schema/tracer.rs | 12 +- .../internal/serialization/array_builder.rs | 10 +- .../serialization/flattened_union_builder.rs | 173 ++++++++++++++++++ serde_arrow/src/internal/serialization/mod.rs | 1 + .../serialization/outer_sequence_builder.rs | 30 ++- 6 files changed, 234 insertions(+), 6 deletions(-) create mode 100644 serde_arrow/src/internal/serialization/flattened_union_builder.rs diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs index a17405bd..7f1cd180 100644 --- a/serde_arrow/src/internal/arrow/data_type.rs +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -12,6 +12,20 @@ pub struct Field { pub metadata: HashMap, } +impl Field { + pub fn from_flattened_enum(&self) -> bool { + self.name.contains("::") + } + + pub fn enum_variant_name(&self) -> Option<&str> { + if self.from_flattened_enum() { + self.name.split("::").next() + } else { + None + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[non_exhaustive] pub enum DataType { diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 02d177b3..08f51daa 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -112,7 +112,17 @@ impl Tracer { let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode); let fields = match root.data_type { - DataType::Struct(children) => children, + DataType::Struct(children) => { + if let Some(strategy) = root + .metadata.get(STRATEGY_KEY) { + if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() { + // TODO: combine with fail messaging below + fail!("Schema tracing is not directly supported for the root data Union. Consider using the `Item` / `Items` wrappers."); + } + } + + children + } DataType::Null => fail!("No records found to determine schema"), dt => fail!( concat!( diff --git a/serde_arrow/src/internal/serialization/array_builder.rs b/serde_arrow/src/internal/serialization/array_builder.rs index 75eb6ef5..7d3b2b68 100644 --- a/serde_arrow/src/internal/serialization/array_builder.rs +++ b/serde_arrow/src/internal/serialization/array_builder.rs @@ -13,10 +13,10 @@ use super::{ date64_builder::Date64Builder, decimal_builder::DecimalBuilder, dictionary_utf8_builder::DictionaryUtf8Builder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, - fixed_size_list_builder::FixedSizeListBuilder, float_builder::FloatBuilder, - int_builder::IntBuilder, list_builder::ListBuilder, map_builder::MapBuilder, - null_builder::NullBuilder, simple_serializer::SimpleSerializer, struct_builder::StructBuilder, - time_builder::TimeBuilder, union_builder::UnionBuilder, + fixed_size_list_builder::FixedSizeListBuilder, flattened_union_builder::FlattenedUnionBuilder, + float_builder::FloatBuilder, int_builder::IntBuilder, list_builder::ListBuilder, + map_builder::MapBuilder, null_builder::NullBuilder, simple_serializer::SimpleSerializer, + struct_builder::StructBuilder, time_builder::TimeBuilder, union_builder::UnionBuilder, unknown_variant_builder::UnknownVariantBuilder, utf8_builder::Utf8Builder, }; @@ -53,6 +53,7 @@ pub enum ArrayBuilder { LargeUtf8(Utf8Builder), DictionaryUtf8(DictionaryUtf8Builder), Union(UnionBuilder), + FlattenedUnion(FlattenedUnionBuilder), UnknownVariant(UnknownVariantBuilder), } @@ -90,6 +91,7 @@ macro_rules! dispatch { $wrapper::Struct($name) => $expr, $wrapper::DictionaryUtf8($name) => $expr, $wrapper::Union($name) => $expr, + $wrapper::FlattenedUnion($name) => $expr, $wrapper::UnknownVariant($name) => $expr, } }; diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs new file mode 100644 index 00000000..047ce438 --- /dev/null +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -0,0 +1,173 @@ +use std::collections::BTreeMap; + +use crate::internal::{ + arrow::{Array, StructArray}, + error::{fail, set_default, try_, Context, ContextSupport, Result}, + utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, +}; + +use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; + +#[derive(Debug, Clone)] +pub struct FlattenedUnionBuilder { + pub path: String, + pub fields: Vec, + pub seq: CountArray, +} + +impl FlattenedUnionBuilder { + pub fn new(path: String, fields: Vec) -> Self { + Self { + path, + fields, + seq: CountArray::new(true), + } + } + + pub fn take_self(&mut self) -> Self { + Self { + path: self.path.clone(), + fields: self.fields.clone(), + seq: self.seq.take(), + } + } + + pub fn take(&mut self) -> ArrayBuilder { + ArrayBuilder::FlattenedUnion(self.take_self()) + } + + pub fn is_nullable(&self) -> bool { + self.seq.validity.is_some() + } + + pub fn into_array(self) -> Result { + let mut fields = Vec::new(); + + for builder in self.fields.into_iter() { + let ArrayBuilder::Struct(builder) = builder else { + fail!("enum variant not built as a struct") + }; + + for (sub_builder, sub_meta) in builder.fields.into_iter() { + fields.push((sub_builder.into_array()?, sub_meta)); + } + } + + Ok(Array::Struct(StructArray { + len: fields.len(), + validity: self.seq.validity, + fields, + })) + } +} + +impl FlattenedUnionBuilder { + pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> { + // self.len += 1; + + let variant_index = variant_index as usize; + + // call push_none for any variant that was not selected + for (idx, builder) in self.fields.iter_mut().enumerate() { + if idx != variant_index { + builder.serialize_none()?; + self.seq.push_seq_none()?; + } + } + + let Some(variant_builder) = self.fields.get_mut(variant_index) else { + fail!("Could not find variant {variant_index} in Union"); + }; + + Ok(variant_builder) + } +} + +impl Context for FlattenedUnionBuilder { + fn annotate(&self, annotations: &mut BTreeMap) { + set_default(annotations, "field", &self.path); + set_default(annotations, "data_type", "Union(..)"); + } +} + +impl SimpleSerializer for FlattenedUnionBuilder { + // fn serialize_unit_variant( + // &mut self, + // _: &'static str, + // variant_index: u32, + // _: &'static str, + // ) -> Result<()> { + // let mut ctx = BTreeMap::new(); + // self.annotate(&mut ctx); + + // try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx) + // } + + // fn serialize_newtype_variant( + // &mut self, + // _: &'static str, + // variant_index: u32, + // _: &'static str, + // value: &V, + // ) -> Result<()> { + // let mut ctx = BTreeMap::new(); + // self.annotate(&mut ctx); + + // try_(|| { + // let variant_builder = self.serialize_variant(variant_index)?; + // value.serialize(Mut(variant_builder)) + // }) + // .ctx(&ctx) + // } + + fn serialize_struct_variant_start<'this>( + &'this mut self, + _: &'static str, + variant_index: u32, + variant: &'static str, + len: usize, + ) -> Result<&'this mut ArrayBuilder> { + let mut ctx = BTreeMap::new(); + self.annotate(&mut ctx); + self.seq.start_seq()?; + self.seq.push_seq_elements(1)?; + + try_(|| { + let variant_builder = self.serialize_variant(variant_index)?; + variant_builder.serialize_struct_start(variant, len)?; + Ok(variant_builder) + }) + .ctx(&ctx) + } + + // fn serialize_tuple_variant_start<'this>( + // &'this mut self, + // _: &'static str, + // variant_index: u32, + // variant: &'static str, + // len: usize, + // ) -> Result<&'this mut ArrayBuilder> { + // let mut ctx = BTreeMap::new(); + // self.annotate(&mut ctx); + + // try_(|| { + // let variant_builder = self.serialize_variant(variant_index)?; + // variant_builder.serialize_tuple_struct_start(variant, len)?; + // Ok(variant_builder) + // }) + // .ctx(&ctx) + // } +} + +// #[cfg(test)] +// mod tests { +// fn test_serialize_union() { +// #[derive(Serialize, Deserialize)] +// enum Number { +// Real { value: f32 }, +// Complex { i: f32, j: f32 }, +// } + +// let numbers = vec![]; +// } +// } diff --git a/serde_arrow/src/internal/serialization/mod.rs b/serde_arrow/src/internal/serialization/mod.rs index f6af48eb..8198e634 100644 --- a/serde_arrow/src/internal/serialization/mod.rs +++ b/serde_arrow/src/internal/serialization/mod.rs @@ -10,6 +10,7 @@ pub mod dictionary_utf8_builder; pub mod duration_builder; pub mod fixed_size_binary_builder; pub mod fixed_size_list_builder; +pub mod flattened_union_builder; pub mod float_builder; pub mod int_builder; pub mod list_builder; diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 58bbf1ba..1bffdba0 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -10,6 +10,7 @@ use crate::internal::{ binary_builder::BinaryBuilder, duration_builder::DurationBuilder, fixed_size_binary_builder::FixedSizeBinaryBuilder, fixed_size_list_builder::FixedSizeListBuilder, + flattened_union_builder::FlattenedUnionBuilder, }, utils::{btree_map, meta_from_field, ChildName, Mut}, }; @@ -226,7 +227,34 @@ fn build_builder(path: String, field: &Field) -> Result { .ctx(&ctx)?, ) } - T::Struct(children) => A::Struct(build_struct(path, children, field.nullable)?), + T::Struct(children) => { + if let Some(Strategy::EnumsWithNamedFieldsAsStructs) = + get_strategy_from_metadata(&field.metadata)? + { + let mut related_fields: HashMap<&str, Vec> = HashMap::new(); + let mut builders: Vec = Vec::new(); + + for field in children { + let Some(variant_name) = field.enum_variant_name() else { + // TODO: warning? fail! ? + continue; + }; + related_fields + .entry(variant_name) + .or_default() + .push(field.clone()); + } + + for (variant_name, fields) in related_fields { + let sub_struct_name = format!("{}.{}", path, variant_name); + builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take()); + } + + A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders)) + } else { + A::Struct(build_struct(path, children, field.nullable)?) + } + } T::Dictionary(key, value, _) => { let key_path = format!("{path}.key"); let key_field = Field { From fd144bff0e90fa771ce8a4b664cd698fcf97cebe Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Wed, 18 Sep 2024 14:21:40 -0700 Subject: [PATCH 09/19] wip: flattened union builder first working version --- serde_arrow/src/internal/arrow/data_type.rs | 37 +++++- serde_arrow/src/internal/schema/tracer.rs | 6 +- .../serialization/flattened_union_builder.rs | 111 +++++------------- .../serialization/outer_sequence_builder.rs | 34 ++++-- 4 files changed, 92 insertions(+), 96 deletions(-) diff --git a/serde_arrow/src/internal/arrow/data_type.rs b/serde_arrow/src/internal/arrow/data_type.rs index 7f1cd180..580db657 100644 --- a/serde_arrow/src/internal/arrow/data_type.rs +++ b/serde_arrow/src/internal/arrow/data_type.rs @@ -12,18 +12,49 @@ pub struct Field { pub metadata: HashMap, } +impl PartialOrd for Field { + fn partial_cmp(&self, other: &Self) -> Option { + self.name.partial_cmp(&other.name) + } +} + impl Field { - pub fn from_flattened_enum(&self) -> bool { + pub fn to_flattened_union_field(mut self, variant_name: &str) -> Self { + self.name = format!("{}::{}", variant_name, self.name); + self.nullable = true; + self + } + + fn from_flattened_union(&self) -> bool { self.name.contains("::") } - pub fn enum_variant_name(&self) -> Option<&str> { - if self.from_flattened_enum() { + pub fn union_variant_name(&self) -> Option<&str> { + if self.from_flattened_union() { self.name.split("::").next() } else { None } } + + pub fn union_field_name(&self) -> Option { + if self.from_flattened_union() { + Some( + self.name + .split("::") + .skip(1) + .fold(String::new(), |acc: String, e| { + if acc.is_empty() { + String::from(e) + } else { + format!("{acc}::{e}") + } + }), + ) + } else { + None + } + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 08f51daa..c4c749a5 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -1093,10 +1093,8 @@ impl UnionTracer { for variant in &self.variants { if let Some(variant) = variant { let schema = variant.tracer.to_schema()?; - for mut field in schema.fields { - field.name = format!("{}::{}", variant.name, field.name); - field.nullable = true; - fields.push(field) + for field in schema.fields { + fields.push(field.to_flattened_union_field(variant.name.as_str())) } } else { fields.push(unknown_variant_field()) diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index 047ce438..fd788e0b 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -1,9 +1,8 @@ use std::collections::BTreeMap; use crate::internal::{ - arrow::{Array, StructArray}, + arrow::{Array, FieldMeta, StructArray}, error::{fail, set_default, try_, Context, ContextSupport, Result}, - utils::array_ext::{ArrayExt, CountArray, SeqArrayExt}, }; use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; @@ -11,51 +10,51 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct FlattenedUnionBuilder { pub path: String, - pub fields: Vec, - pub seq: CountArray, + pub fields: Vec<(ArrayBuilder, FieldMeta)>, } impl FlattenedUnionBuilder { - pub fn new(path: String, fields: Vec) -> Self { - Self { - path, - fields, - seq: CountArray::new(true), - } - } - - pub fn take_self(&mut self) -> Self { - Self { - path: self.path.clone(), - fields: self.fields.clone(), - seq: self.seq.take(), - } + pub fn new(path: String, fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self { + Self { path, fields } } pub fn take(&mut self) -> ArrayBuilder { - ArrayBuilder::FlattenedUnion(self.take_self()) + ArrayBuilder::FlattenedUnion(Self { + path: self.path.clone(), + fields: self + .fields + .iter_mut() + .map(|(field, meta)| (field.take(), meta.clone())) + .collect(), + }) } pub fn is_nullable(&self) -> bool { - self.seq.validity.is_some() + false } pub fn into_array(self) -> Result { let mut fields = Vec::new(); + let mut num_elements = 0; - for builder in self.fields.into_iter() { + for (builder, meta) in self.fields.into_iter() { let ArrayBuilder::Struct(builder) = builder else { - fail!("enum variant not built as a struct") + fail!("enum variant not built as a struct") // TODO: better failure message }; - for (sub_builder, sub_meta) in builder.fields.into_iter() { + for (sub_builder, mut sub_meta) in builder.fields.into_iter() { + num_elements += 1; + // TODO: this mirrors the field name structure in the tracer but represents + // implementation details crossing boundaries. Is there another way? + // Currently necessary to allow struct field lookup to work correctly. + sub_meta.name = format!("{}::{}", meta.name, sub_meta.name); fields.push((sub_builder.into_array()?, sub_meta)); } } Ok(Array::Struct(StructArray { - len: fields.len(), - validity: self.seq.validity, + len: num_elements, + validity: None, // TODO: is this ok? fields, })) } @@ -63,19 +62,16 @@ impl FlattenedUnionBuilder { impl FlattenedUnionBuilder { pub fn serialize_variant(&mut self, variant_index: u32) -> Result<&mut ArrayBuilder> { - // self.len += 1; - let variant_index = variant_index as usize; - // call push_none for any variant that was not selected - for (idx, builder) in self.fields.iter_mut().enumerate() { + // don't serialize any variant not selected + for (idx, (builder, _meta)) in self.fields.iter_mut().enumerate() { if idx != variant_index { builder.serialize_none()?; - self.seq.push_seq_none()?; } } - let Some(variant_builder) = self.fields.get_mut(variant_index) else { + let Some((variant_builder, _variant_meta)) = self.fields.get_mut(variant_index) else { fail!("Could not find variant {variant_index} in Union"); }; @@ -86,40 +82,11 @@ impl FlattenedUnionBuilder { impl Context for FlattenedUnionBuilder { fn annotate(&self, annotations: &mut BTreeMap) { set_default(annotations, "field", &self.path); - set_default(annotations, "data_type", "Union(..)"); + set_default(annotations, "data_type", "Struct(..)"); } } impl SimpleSerializer for FlattenedUnionBuilder { - // fn serialize_unit_variant( - // &mut self, - // _: &'static str, - // variant_index: u32, - // _: &'static str, - // ) -> Result<()> { - // let mut ctx = BTreeMap::new(); - // self.annotate(&mut ctx); - - // try_(|| self.serialize_variant(variant_index)?.serialize_unit()).ctx(&ctx) - // } - - // fn serialize_newtype_variant( - // &mut self, - // _: &'static str, - // variant_index: u32, - // _: &'static str, - // value: &V, - // ) -> Result<()> { - // let mut ctx = BTreeMap::new(); - // self.annotate(&mut ctx); - - // try_(|| { - // let variant_builder = self.serialize_variant(variant_index)?; - // value.serialize(Mut(variant_builder)) - // }) - // .ctx(&ctx) - // } - fn serialize_struct_variant_start<'this>( &'this mut self, _: &'static str, @@ -129,8 +96,6 @@ impl SimpleSerializer for FlattenedUnionBuilder { ) -> Result<&'this mut ArrayBuilder> { let mut ctx = BTreeMap::new(); self.annotate(&mut ctx); - self.seq.start_seq()?; - self.seq.push_seq_elements(1)?; try_(|| { let variant_builder = self.serialize_variant(variant_index)?; @@ -139,26 +104,10 @@ impl SimpleSerializer for FlattenedUnionBuilder { }) .ctx(&ctx) } - - // fn serialize_tuple_variant_start<'this>( - // &'this mut self, - // _: &'static str, - // variant_index: u32, - // variant: &'static str, - // len: usize, - // ) -> Result<&'this mut ArrayBuilder> { - // let mut ctx = BTreeMap::new(); - // self.annotate(&mut ctx); - - // try_(|| { - // let variant_builder = self.serialize_variant(variant_index)?; - // variant_builder.serialize_tuple_struct_start(variant, len)?; - // Ok(variant_builder) - // }) - // .ctx(&ctx) - // } } +// TODO: add tests + // #[cfg(test)] // mod tests { // fn test_serialize_union() { diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 1bffdba0..30932981 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -3,7 +3,7 @@ use std::collections::{BTreeMap, HashMap}; use serde::Serialize; use crate::internal::{ - arrow::{DataType, Field, TimeUnit}, + arrow::{DataType, Field, FieldMeta, TimeUnit}, error::{fail, Context, ContextSupport, Result}, schema::{get_strategy_from_metadata, SerdeArrowSchema, Strategy}, serialization::{ @@ -231,23 +231,41 @@ fn build_builder(path: String, field: &Field) -> Result { if let Some(Strategy::EnumsWithNamedFieldsAsStructs) = get_strategy_from_metadata(&field.metadata)? { - let mut related_fields: HashMap<&str, Vec> = HashMap::new(); - let mut builders: Vec = Vec::new(); + let mut related_fields: BTreeMap<&str, Vec> = BTreeMap::new(); + let mut builders: Vec<(ArrayBuilder, FieldMeta)> = Vec::new(); for field in children { - let Some(variant_name) = field.enum_variant_name() else { - // TODO: warning? fail! ? + let Some(variant_name) = field.union_variant_name() else { + // TODO: failure message continue; }; + + let Some(field_name) = field.union_field_name() else { + // TODO: failure message + continue; + }; + + let mut new_field = field.clone(); + new_field.name = field_name; + related_fields .entry(variant_name) .or_default() - .push(field.clone()); + .push(new_field); } for (variant_name, fields) in related_fields { - let sub_struct_name = format!("{}.{}", path, variant_name); - builders.push(build_struct(sub_struct_name, fields.as_slice(), true)?.take()); + let builder = build_struct( + format!("{}.{}", path.to_owned(), variant_name), + fields.as_slice(), + true, + )? + .take(); + + let mut meta = meta_from_field(field.clone()); + meta.name = variant_name.to_owned(); + + builders.push((builder, meta)); } A::FlattenedUnion(FlattenedUnionBuilder::new(path, builders)) From 202c07865eaf714b9c2d1f2c17411c0f856c7e51 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Fri, 20 Sep 2024 17:08:30 -0700 Subject: [PATCH 10/19] fix num elements in struct and start adding unit tests for flattened union builder --- .../serialization/flattened_union_builder.rs | 162 ++++++++++++++++-- .../serialization/outer_sequence_builder.rs | 2 +- 2 files changed, 146 insertions(+), 18 deletions(-) diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index fd788e0b..d760a08d 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -35,7 +35,7 @@ impl FlattenedUnionBuilder { pub fn into_array(self) -> Result { let mut fields = Vec::new(); - let mut num_elements = 0; + let num_fields = self.fields.len(); for (builder, meta) in self.fields.into_iter() { let ArrayBuilder::Struct(builder) = builder else { @@ -43,18 +43,21 @@ impl FlattenedUnionBuilder { }; for (sub_builder, mut sub_meta) in builder.fields.into_iter() { - num_elements += 1; // TODO: this mirrors the field name structure in the tracer but represents // implementation details crossing boundaries. Is there another way? - // Currently necessary to allow struct field lookup to work correctly. + // Name change is currently needed for struct field lookup to work correctly. + sub_meta.name = format!("{}::{}", meta.name, sub_meta.name); fields.push((sub_builder.into_array()?, sub_meta)); } } Ok(Array::Struct(StructArray { - len: num_elements, - validity: None, // TODO: is this ok? + len: num_fields, + // TODO: is this ok to hardcode? + // assuming so because when testing manually, + // validity of struct with nullable fields was None + validity: None, fields, })) } @@ -106,17 +109,142 @@ impl SimpleSerializer for FlattenedUnionBuilder { } } -// TODO: add tests +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{ + internal::{ + array_builder::ArrayBuilder, + arrow::{DataType, Field}, + serialization::{self, outer_sequence_builder::build_builder}, + }, + schema::SerdeArrowSchema, + Serializer, + }; + use serde::{Deserialize, Serialize}; + + #[derive(Serialize, Deserialize)] + struct Number { + v: Value, + } + + #[derive(Serialize, Deserialize)] + enum Value { + Real { value: f32 }, + Complex { i: f32, j: f32 }, + Whole { value: usize }, + } + + fn number_field() -> Field { + Field { + name: "v".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Complex::i".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::j".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Real::value".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Whole::value".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: HashMap::from([( + "SERDE_ARROW:strategy".to_string(), + "EnumsWithNamedFieldsAsStructs".to_string(), + )]), + } + } + + fn number_data() -> Vec { + vec![ + Number { + v: Value::Real { value: 0.0 }, + }, + Number { + v: Value::Complex { i: 0.5, j: 0.5 }, + }, + Number { + v: Value::Whole { value: 5 }, + }, + ] + } + + #[test] + fn test_build_flattened_union_builder() { + let field = number_field(); + + let array_builder = + build_builder("$".to_string(), &field).expect("failed to build builder"); -// #[cfg(test)] -// mod tests { -// fn test_serialize_union() { -// #[derive(Serialize, Deserialize)] -// enum Number { -// Real { value: f32 }, -// Complex { i: f32, j: f32 }, -// } + let serialization::ArrayBuilder::FlattenedUnion(builder) = array_builder else { + panic!("did not create correct builder"); + }; -// let numbers = vec![]; -// } -// } + // Should be 3 struct builders: one for Real, one for Complex, one for Whole + assert_eq!( + builder.fields.len(), + 3, + "contained {} builder fields", + builder.fields.len() + ); + assert!( + builder + .fields + .iter() + .all(|(inner, _)| matches!(inner, serialization::ArrayBuilder::Struct(_))), + "some inner builders were not Struct builders" + ); + } + + #[test] + fn test_serialize_flattened_union_builder() { + let field = number_field(); + let data = number_data(); + let schema = SerdeArrowSchema { + fields: vec![field], + }; + + let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); + let serializer = Serializer::new(api_builder); + data.serialize(serializer) + .expect("failed to serialize") + .into_inner() + .to_arrow() + .expect("failed to serialize to arrow"); + } + + #[test] + fn test_record_batch_flattened_union_builder() { + let field = number_field(); + let data = number_data(); + let schema = SerdeArrowSchema { + fields: vec![field], + }; + + let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); + let serializer = Serializer::new(api_builder); + data.serialize(serializer) + .expect("failed to serialize") + .into_inner() + .to_record_batch() + .expect("failed to create record batch"); + } +} diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 30932981..00608ce9 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -122,7 +122,7 @@ fn build_struct(path: String, struct_fields: &[Field], nullable: bool) -> Result StructBuilder::new(path, fields, nullable) } -fn build_builder(path: String, field: &Field) -> Result { +pub(crate) fn build_builder(path: String, field: &Field) -> Result { use {ArrayBuilder as A, DataType as T}; let ctx: BTreeMap = btree_map!("field" => path.clone()); From 180c99cace7654c313f577acc065fa572b155f04 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Fri, 20 Sep 2024 17:51:54 -0700 Subject: [PATCH 11/19] fix row count --- .../serialization/flattened_union_builder.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index d760a08d..9a970c89 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -9,13 +9,18 @@ use super::{array_builder::ArrayBuilder, simple_serializer::SimpleSerializer}; #[derive(Debug, Clone)] pub struct FlattenedUnionBuilder { - pub path: String, - pub fields: Vec<(ArrayBuilder, FieldMeta)>, + path: String, + fields: Vec<(ArrayBuilder, FieldMeta)>, + row_count: usize, } impl FlattenedUnionBuilder { pub fn new(path: String, fields: Vec<(ArrayBuilder, FieldMeta)>) -> Self { - Self { path, fields } + Self { + path, + fields, + row_count: 0, + } } pub fn take(&mut self) -> ArrayBuilder { @@ -26,6 +31,7 @@ impl FlattenedUnionBuilder { .iter_mut() .map(|(field, meta)| (field.take(), meta.clone())) .collect(), + row_count: self.row_count, }) } @@ -35,7 +41,6 @@ impl FlattenedUnionBuilder { pub fn into_array(self) -> Result { let mut fields = Vec::new(); - let num_fields = self.fields.len(); for (builder, meta) in self.fields.into_iter() { let ArrayBuilder::Struct(builder) = builder else { @@ -53,12 +58,12 @@ impl FlattenedUnionBuilder { } Ok(Array::Struct(StructArray { - len: num_fields, + len: self.row_count, + fields, // TODO: is this ok to hardcode? // assuming so because when testing manually, // validity of struct with nullable fields was None validity: None, - fields, })) } } @@ -78,6 +83,8 @@ impl FlattenedUnionBuilder { fail!("Could not find variant {variant_index} in Union"); }; + self.row_count += 1; + Ok(variant_builder) } } From ebe3117ab5c813fb99259fbea764d3aaee4ef7bd Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Fri, 20 Sep 2024 18:13:09 -0700 Subject: [PATCH 12/19] fix schema/array column ordering issue by using BTreeMaps in the Tracer and Builder --- serde_arrow/src/internal/schema/tracer.rs | 10 ++++++---- .../internal/serialization/flattened_union_builder.rs | 9 ++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index c4c749a5..97a6818e 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -1083,7 +1083,7 @@ impl UnionTracer { STRATEGY_KEY.to_string(), Strategy::EnumsWithNamedFieldsAsStructs.to_string(), ); - let mut fields = Vec::new(); + let mut fields = BTreeMap::new(); // For this option, we want to merge the variant children up one level, combining the names // For each variant with name variant_name @@ -1094,14 +1094,16 @@ impl UnionTracer { if let Some(variant) = variant { let schema = variant.tracer.to_schema()?; for field in schema.fields { - fields.push(field.to_flattened_union_field(variant.name.as_str())) + let flat_field = field.to_flattened_union_field(variant.name.as_str()); + fields.insert(flat_field.name.to_string(), flat_field); } } else { - fields.push(unknown_variant_field()) + let uf = unknown_variant_field(); + fields.insert(uf.name, unknown_variant_field()); }; } - data_type = DataType::Struct(fields); + data_type = DataType::Struct(fields.into_values().collect()); } else { let mut fields = Vec::new(); for (idx, variant) in self.variants.iter().enumerate() { diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index 9a970c89..f5774028 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -40,7 +40,7 @@ impl FlattenedUnionBuilder { } pub fn into_array(self) -> Result { - let mut fields = Vec::new(); + let mut fields = BTreeMap::new(); for (builder, meta) in self.fields.into_iter() { let ArrayBuilder::Struct(builder) = builder else { @@ -53,13 +53,16 @@ impl FlattenedUnionBuilder { // Name change is currently needed for struct field lookup to work correctly. sub_meta.name = format!("{}::{}", meta.name, sub_meta.name); - fields.push((sub_builder.into_array()?, sub_meta)); + fields.insert( + sub_meta.name.to_owned(), + (sub_builder.into_array()?, sub_meta), + ); } } Ok(Array::Struct(StructArray { len: self.row_count, - fields, + fields: fields.into_values().collect(), // TODO: is this ok to hardcode? // assuming so because when testing manually, // validity of struct with nullable fields was None From 0104fa5fd83914265b7bf29da0e4e68cef4f4279 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Mon, 23 Sep 2024 12:38:09 -0700 Subject: [PATCH 13/19] fix ordering of fields in unit tests and replace continue with todo --- .../src/internal/schema/from_samples/mod.rs | 14 +++++++------- serde_arrow/src/internal/schema/from_type/mod.rs | 6 +++--- .../serialization/outer_sequence_builder.rs | 6 ++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 05aa86ad..53c2ae66 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -956,19 +956,19 @@ mod test { name: "$".to_string(), data_type: DataType::Struct(vec![ Field { - name: "Real::value".to_string(), + name: "Complex::i".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), }, Field { - name: "Complex::i".to_string(), + name: "Complex::j".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), }, Field { - name: "Complex::j".to_string(), + name: "Real::value".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), @@ -1074,25 +1074,25 @@ mod test { name: "$".to_string(), data_type: DataType::Struct(vec![ Field { - name: "Something::more".to_string(), + name: "Else::another".to_string(), data_type: DataType::UInt64, nullable: true, metadata: HashMap::new(), }, Field { - name: "Something::less".to_string(), + name: "Else::one".to_string(), data_type: DataType::UInt64, nullable: true, metadata: HashMap::new(), }, Field { - name: "Else::one".to_string(), + name: "Something::less".to_string(), data_type: DataType::UInt64, nullable: true, metadata: HashMap::new(), }, Field { - name: "Else::another".to_string(), + name: "Something::more".to_string(), data_type: DataType::UInt64, nullable: true, metadata: HashMap::new(), diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index 193b6415..d7500157 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -621,19 +621,19 @@ mod test { name: "$".to_string(), data_type: DataType::Struct(vec![ Field { - name: "Real::value".to_string(), + name: "Complex::i".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), }, Field { - name: "Complex::i".to_string(), + name: "Complex::j".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), }, Field { - name: "Complex::j".to_string(), + name: "Real::value".to_string(), data_type: DataType::Float32, nullable: true, metadata: HashMap::new(), diff --git a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs index 00608ce9..7d27dcf3 100644 --- a/serde_arrow/src/internal/serialization/outer_sequence_builder.rs +++ b/serde_arrow/src/internal/serialization/outer_sequence_builder.rs @@ -236,13 +236,11 @@ pub(crate) fn build_builder(path: String, field: &Field) -> Result for field in children { let Some(variant_name) = field.union_variant_name() else { - // TODO: failure message - continue; + todo!("union variant did not have a name"); }; let Some(field_name) = field.union_field_name() else { - // TODO: failure message - continue; + todo!("union field did not have a name"); }; let mut new_field = field.clone(); From 522140be67b2b09632083c92760f94f5064a3473 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Tue, 22 Oct 2024 13:49:45 -0700 Subject: [PATCH 14/19] Move flattened_union_builder tests to test_with_arrow directory --- serde_arrow/src/internal/schema/tracer.rs | 2 +- .../serialization/flattened_union_builder.rs | 140 ---------------- .../test_with_arrow/impls/flattened_union.rs | 157 ++++++++++++++++++ serde_arrow/src/test_with_arrow/impls/mod.rs | 1 + 4 files changed, 159 insertions(+), 141 deletions(-) create mode 100644 serde_arrow/src/test_with_arrow/impls/flattened_union.rs diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 97a6818e..9eabee38 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -106,7 +106,7 @@ impl Tracer { let root = self.to_field()?; if root.nullable { - fail!("The root type cannot be nullable"); + fail!("The root type cannot be nullable: {root:#?}"); } let tracing_mode = dispatch_tracer!(self, tracer => tracer.options.tracing_mode); diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index f5774028..0c87d12e 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -118,143 +118,3 @@ impl SimpleSerializer for FlattenedUnionBuilder { .ctx(&ctx) } } - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use crate::{ - internal::{ - array_builder::ArrayBuilder, - arrow::{DataType, Field}, - serialization::{self, outer_sequence_builder::build_builder}, - }, - schema::SerdeArrowSchema, - Serializer, - }; - use serde::{Deserialize, Serialize}; - - #[derive(Serialize, Deserialize)] - struct Number { - v: Value, - } - - #[derive(Serialize, Deserialize)] - enum Value { - Real { value: f32 }, - Complex { i: f32, j: f32 }, - Whole { value: usize }, - } - - fn number_field() -> Field { - Field { - name: "v".to_string(), - data_type: DataType::Struct(vec![ - Field { - name: "Complex::i".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Complex::j".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Real::value".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Whole::value".to_string(), - data_type: DataType::UInt64, - nullable: true, - metadata: HashMap::new(), - }, - ]), - nullable: false, - metadata: HashMap::from([( - "SERDE_ARROW:strategy".to_string(), - "EnumsWithNamedFieldsAsStructs".to_string(), - )]), - } - } - - fn number_data() -> Vec { - vec![ - Number { - v: Value::Real { value: 0.0 }, - }, - Number { - v: Value::Complex { i: 0.5, j: 0.5 }, - }, - Number { - v: Value::Whole { value: 5 }, - }, - ] - } - - #[test] - fn test_build_flattened_union_builder() { - let field = number_field(); - - let array_builder = - build_builder("$".to_string(), &field).expect("failed to build builder"); - - let serialization::ArrayBuilder::FlattenedUnion(builder) = array_builder else { - panic!("did not create correct builder"); - }; - - // Should be 3 struct builders: one for Real, one for Complex, one for Whole - assert_eq!( - builder.fields.len(), - 3, - "contained {} builder fields", - builder.fields.len() - ); - assert!( - builder - .fields - .iter() - .all(|(inner, _)| matches!(inner, serialization::ArrayBuilder::Struct(_))), - "some inner builders were not Struct builders" - ); - } - - #[test] - fn test_serialize_flattened_union_builder() { - let field = number_field(); - let data = number_data(); - let schema = SerdeArrowSchema { - fields: vec![field], - }; - - let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); - let serializer = Serializer::new(api_builder); - data.serialize(serializer) - .expect("failed to serialize") - .into_inner() - .to_arrow() - .expect("failed to serialize to arrow"); - } - - #[test] - fn test_record_batch_flattened_union_builder() { - let field = number_field(); - let data = number_data(); - let schema = SerdeArrowSchema { - fields: vec![field], - }; - - let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); - let serializer = Serializer::new(api_builder); - data.serialize(serializer) - .expect("failed to serialize") - .into_inner() - .to_record_batch() - .expect("failed to create record batch"); - } -} diff --git a/serde_arrow/src/test_with_arrow/impls/flattened_union.rs b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs new file mode 100644 index 00000000..e7444945 --- /dev/null +++ b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs @@ -0,0 +1,157 @@ +use std::collections::HashMap; + +use crate::{ + internal::{ + array_builder::ArrayBuilder, + arrow::{Array, DataType, Field}, + schema::{SchemaLike, TracingOptions}, + }, + schema::SerdeArrowSchema, + Serializer, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +struct Number { + v: Value, +} + +#[derive(Serialize, Deserialize)] +enum Value { + Real { value: f32 }, + Complex { i: f32, j: f32 }, + Whole { value: usize }, +} + +fn number_field() -> Field { + Field { + name: "v".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Complex::i".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::j".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Real::value".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Whole::value".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: HashMap::from([( + "SERDE_ARROW:strategy".to_string(), + "EnumsWithNamedFieldsAsStructs".to_string(), + )]), + } +} + +fn number_schema() -> SerdeArrowSchema { + let options = TracingOptions::default() + .allow_null_fields(true) + .enums_with_named_fields_as_structs(true); + + SerdeArrowSchema::from_type::(options).unwrap() +} + +fn number_data() -> Vec { + vec![ + Number { + v: Value::Real { value: 0.0 }, + }, + Number { + v: Value::Complex { i: 0.5, j: 0.5 }, + }, + Number { + v: Value::Whole { value: 5 }, + }, + ] +} + +#[test] +fn test_build_flattened_union_builder() { + let mut builder = ArrayBuilder::new(number_schema()).unwrap(); + + // One struct in the array + let arrays = builder.build_arrays().unwrap(); + + assert_eq!(arrays.len(), 1); + + let array = &arrays[0]; + + let Array::Struct(ref struct_array) = array else { + panic!("expected a struct array, found {array:#?}"); + }; + + // Should be a single struct array with 4 fields: Complex::i, Complex::j, Real::value, Whole::value + assert_eq!( + struct_array.fields.len(), + 4, + "contained {} fields", + struct_array.fields.len() + ); + + let (first_field, meta) = &struct_array.fields[0]; + assert_eq!(meta.name, "Complex::i"); + assert!(matches!(first_field, Array::Float32(_))); + + let (second_field, meta) = &struct_array.fields[1]; + assert_eq!(meta.name, "Complex::j"); + assert!(matches!(second_field, Array::Float32(_))); + + let (third_field, meta) = &struct_array.fields[2]; + assert_eq!(meta.name, "Real::value"); + assert!(matches!(third_field, Array::Float32(_))); + + let (fourth_field, meta) = &struct_array.fields[3]; + assert_eq!(meta.name, "Whole::value"); + assert!(matches!(fourth_field, Array::UInt64(_))); +} + +#[test] +fn test_serialize_flattened_union_builder() { + let field = number_field(); + let data = number_data(); + let schema = SerdeArrowSchema { + fields: vec![field], + }; + + let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); + let serializer = Serializer::new(api_builder); + data.serialize(serializer) + .expect("failed to serialize") + .into_inner() + .to_arrow() + .expect("failed to serialize to arrow"); +} + +#[test] +fn test_record_batch_flattened_union_builder() { + let field = number_field(); + let data = number_data(); + let schema = SerdeArrowSchema { + fields: vec![field], + }; + + let api_builder = ArrayBuilder::new(schema).expect("failed to create api array builder"); + let serializer = Serializer::new(api_builder); + data.serialize(serializer) + .expect("failed to serialize") + .into_inner() + .to_record_batch() + .expect("failed to create record batch"); +} diff --git a/serde_arrow/src/test_with_arrow/impls/mod.rs b/serde_arrow/src/test_with_arrow/impls/mod.rs index 7bbef5d5..dcf02fe4 100644 --- a/serde_arrow/src/test_with_arrow/impls/mod.rs +++ b/serde_arrow/src/test_with_arrow/impls/mod.rs @@ -6,6 +6,7 @@ mod chrono; mod dictionary; mod examples; mod fixed_size_list; +mod flattened_union; mod jiff; mod json_values; mod list; From 25755fbb09241d8c12302b5b2eed8cdd45d2f71a Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Thu, 24 Oct 2024 12:37:47 -0700 Subject: [PATCH 15/19] complex nested enum test case --- .../test_with_arrow/impls/flattened_union.rs | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/serde_arrow/src/test_with_arrow/impls/flattened_union.rs b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs index e7444945..bf898814 100644 --- a/serde_arrow/src/test_with_arrow/impls/flattened_union.rs +++ b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs @@ -155,3 +155,75 @@ fn test_record_batch_flattened_union_builder() { .to_record_batch() .expect("failed to create record batch"); } + +#[derive(Serialize, Deserialize)] +struct ComplexMessage { + data: MsgData, +} + +#[derive(Serialize, Deserialize)] +enum MsgData { + One { data: usize }, + Two { opts: MsgOptions }, +} + +#[derive(Serialize, Deserialize)] +struct MsgOptions { + loc: Location, +} + +#[derive(Serialize, Deserialize)] +enum Location { + Left, + Right, +} + +fn nested_enum_schema() -> SerdeArrowSchema { + let options = TracingOptions::default() + .allow_null_fields(true) + .enums_without_data_as_strings(true) + .enums_with_named_fields_as_structs(true); + + SerdeArrowSchema::from_type::(options).unwrap() +} + +fn nested_enum_data() -> Vec { + vec![ + ComplexMessage { + data: MsgData::One { data: 3 }, + }, + ComplexMessage { + data: MsgData::Two { + opts: MsgOptions { + loc: Location::Right, + }, + }, + }, + ] +} + +#[test] +fn test_flattened_union_with_nested_enum() { + let mut builder = ArrayBuilder::new(nested_enum_schema()).unwrap(); + + // One struct in the array + let arrays = builder.build_arrays().unwrap(); + + println!("{arrays:#?}"); + + assert_eq!(arrays.len(), 1); + + let array = &arrays[0]; + + let Array::Struct(ref struct_array) = array else { + panic!("expected a struct array, found {array:#?}"); + }; + + let serializer = Serializer::new(builder); + nested_enum_data() + .serialize(serializer) + .expect("failed to serialize") + .into_inner() + .to_arrow() + .expect("failed to serialize to arrow"); +} From c5b98b6fa838545886f510ffb25dfa598d742efa Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Mon, 28 Oct 2024 14:49:20 -0700 Subject: [PATCH 16/19] fix duplicated error logs and fix/force nested enum/dictionaries to be nullable --- .../src/internal/schema/extensions/mod.rs | 1 + .../src/internal/schema/extensions/utils.rs | 15 +++++- serde_arrow/src/internal/schema/tracer.rs | 53 +++++++++++-------- .../test_with_arrow/impls/flattened_union.rs | 2 +- 4 files changed, 48 insertions(+), 23 deletions(-) diff --git a/serde_arrow/src/internal/schema/extensions/mod.rs b/serde_arrow/src/internal/schema/extensions/mod.rs index 9b11016e..426368b9 100644 --- a/serde_arrow/src/internal/schema/extensions/mod.rs +++ b/serde_arrow/src/internal/schema/extensions/mod.rs @@ -5,6 +5,7 @@ mod variable_shape_tensor_field; pub use bool8_field::Bool8Field; pub use fixed_shape_tensor_field::FixedShapeTensorField; +pub(crate) use utils::fix_dictionaries; pub use variable_shape_tensor_field::VariableShapeTensorField; const _: () = { diff --git a/serde_arrow/src/internal/schema/extensions/utils.rs b/serde_arrow/src/internal/schema/extensions/utils.rs index aae4d1c0..99e8df55 100644 --- a/serde_arrow/src/internal/schema/extensions/utils.rs +++ b/serde_arrow/src/internal/schema/extensions/utils.rs @@ -1,4 +1,7 @@ -use crate::internal::error::{fail, Result}; +use crate::internal::{ + arrow::{DataType, Field}, + error::{fail, Result}, +}; pub fn check_dim_names(ndim: usize, dim_names: &[String]) -> Result<()> { if dim_names.len() != ndim { @@ -56,3 +59,13 @@ impl std::fmt::Display for DebugRepr { write!(f, "{:?}", self.0) } } + +pub(crate) fn fix_dictionaries(field: &mut Field) { + if matches!(field.data_type, DataType::Dictionary(_, _, _)) { + field.nullable = true; + } else if let DataType::Struct(children) = &mut field.data_type { + for child in children { + fix_dictionaries(child); + } + } +} diff --git a/serde_arrow/src/internal/schema/tracer.rs b/serde_arrow/src/internal/schema/tracer.rs index 9eabee38..c45a77a0 100644 --- a/serde_arrow/src/internal/schema/tracer.rs +++ b/serde_arrow/src/internal/schema/tracer.rs @@ -7,8 +7,8 @@ use crate::internal::{ arrow::{DataType, Field, UnionMode}, error::{fail, set_default, Context, Result}, schema::{ - DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, TracingMode, TracingOptions, - STRATEGY_KEY, + extensions::fix_dictionaries, DataTypeDisplay, Overwrites, SerdeArrowSchema, Strategy, + TracingMode, TracingOptions, STRATEGY_KEY, }, }; @@ -101,6 +101,30 @@ impl Tracer { Self::Unknown(UnknownTracer::new(name, path, options)) } + fn schema_tracing_error( + failed_data_type: impl std::fmt::Display, + tracing_mode: TracingMode, + ) -> Result { + fail!( + concat!( + "Schema tracing is not directly supported for the root data type {failed_data_type}. ", + "Only struct-like types are supported as root types in schema tracing. ", + "{mitigation}", + ), + failed_data_type = failed_data_type, + mitigation = match tracing_mode { + TracingMode::FromType => { + "Consider using the `Item` wrapper, i.e., `::from_type>()`." + } + TracingMode::FromSamples => { + "Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`." + } + TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.", + }, + + ) + } + /// Convert the traced schema into a schema object pub fn to_schema(&self) -> Result { let root = self.to_field()?; @@ -113,30 +137,16 @@ impl Tracer { let fields = match root.data_type { DataType::Struct(children) => { - if let Some(strategy) = root - .metadata.get(STRATEGY_KEY) { - if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() { - // TODO: combine with fail messaging below - fail!("Schema tracing is not directly supported for the root data Union. Consider using the `Item` / `Items` wrappers."); + if let Some(strategy) = root.metadata.get(STRATEGY_KEY) { + if *strategy == Strategy::EnumsWithNamedFieldsAsStructs.to_string() { + return Self::schema_tracing_error("Union", tracing_mode); } } children } DataType::Null => fail!("No records found to determine schema"), - dt => fail!( - concat!( - "Schema tracing is not directly supported for the root data type {dt}. ", - "Only struct-like types are supported as root types in schema tracing. ", - "{mitigation}", - ), - dt = DataTypeDisplay(&dt), - mitigation = match tracing_mode { - TracingMode::FromType => "Consider using the `Item` wrapper, i.e., `::from_type>()`.", - TracingMode::FromSamples => "Consider using the `Items` wrapper, i.e., `::from_samples(Items(samples))`.", - TracingMode::Unknown => "Consider using the `Item` / `Items` wrappers.", - }, - ), + dt => return Self::schema_tracing_error(DataTypeDisplay(&dt), tracing_mode), }; Ok(SerdeArrowSchema { fields }) @@ -1094,7 +1104,8 @@ impl UnionTracer { if let Some(variant) = variant { let schema = variant.tracer.to_schema()?; for field in schema.fields { - let flat_field = field.to_flattened_union_field(variant.name.as_str()); + let mut flat_field = field.to_flattened_union_field(variant.name.as_str()); + fix_dictionaries(&mut flat_field); fields.insert(flat_field.name.to_string(), flat_field); } } else { diff --git a/serde_arrow/src/test_with_arrow/impls/flattened_union.rs b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs index bf898814..6900be56 100644 --- a/serde_arrow/src/test_with_arrow/impls/flattened_union.rs +++ b/serde_arrow/src/test_with_arrow/impls/flattened_union.rs @@ -215,7 +215,7 @@ fn test_flattened_union_with_nested_enum() { let array = &arrays[0]; - let Array::Struct(ref struct_array) = array else { + let Array::Struct(ref _struct_array) = array else { panic!("expected a struct array, found {array:#?}"); }; From 7ed5ce5030d10cf982707df505dc31625010f044 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Mon, 28 Oct 2024 16:48:07 -0700 Subject: [PATCH 17/19] consolidate helpers and cleanup from_samples and from_type unit tests --- .../src/internal/schema/from_samples/mod.rs | 183 ++---------------- .../src/internal/schema/from_type/mod.rs | 87 ++++----- serde_arrow/src/internal/testing.rs | 176 ++++++++++++++++- 3 files changed, 227 insertions(+), 219 deletions(-) diff --git a/serde_arrow/src/internal/schema/from_samples/mod.rs b/serde_arrow/src/internal/schema/from_samples/mod.rs index 53c2ae66..50624eef 100644 --- a/serde_arrow/src/internal/schema/from_samples/mod.rs +++ b/serde_arrow/src/internal/schema/from_samples/mod.rs @@ -813,28 +813,16 @@ mod impl_serialize_to_string { #[cfg(test)] mod test { - use std::collections::HashMap; - use serde::Serialize; use serde_json::{json, Value}; - use crate::{ - internal::{ - arrow::Field, - schema::{transmute_field, TracingOptions}, - }, - schema::STRATEGY_KEY, + use crate::internal::{ + schema::{transmute_field, TracingOptions}, + testing::{Coin, Number, Optionals, Payment}, }; use super::*; - fn enum_with_named_fields_metadata() -> HashMap { - HashMap::from([( - STRATEGY_KEY.to_string(), - Strategy::EnumsWithNamedFieldsAsStructs.to_string(), - )]) - } - fn test_to_tracer(items: &T, options: TracingOptions, expected: Value) { let tracer = Tracer::from_samples(items, options).unwrap(); let field = tracer.to_field().unwrap(); @@ -938,97 +926,30 @@ mod test { #[test] fn example_enum_as_struct_equal_to_struct_with_nullable_fields() { - #[derive(Serialize)] - enum Number { - Real { value: f32 }, - Complex { i: f32, j: f32 }, - } - - let enum_items = [ - Number::Real { value: 1.0 }, - Number::Complex { i: 0.5, j: 0.5 }, - ]; - let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); - let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); - - let expected_field = Field { - name: "$".to_string(), - data_type: DataType::Struct(vec![ - Field { - name: "Complex::i".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Complex::j".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Real::value".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - ]), - nullable: false, - metadata: enum_with_named_fields_metadata(), - }; - - assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + let enum_tracer = Tracer::from_samples(Number::sample_items().as_slice(), opts).unwrap(); + assert_eq!(enum_tracer.to_field().unwrap(), Number::expected_field()); } #[test] fn example_enum_as_struct_no_fields() { - #[derive(Serialize)] - enum Coin { - Heads, - Tails, - } - - let enum_items = [Coin::Heads, Coin::Tails]; - // This should continue to maintain previously implemented behavior, serializing as a map let opts = TracingOptions::default() .enums_with_named_fields_as_structs(true) .enums_without_data_as_strings(true); - let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); - - let expected_field = Field { - name: "$".to_string(), - data_type: DataType::Dictionary( - Box::new(DataType::UInt32), - Box::new(DataType::LargeUtf8), - false, - ), - nullable: false, - metadata: HashMap::new(), - }; - - assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + let enum_tracer = Tracer::from_samples(Coin::sample_items(), opts).unwrap(); + assert_eq!(enum_tracer.to_field().unwrap(), Coin::expected_field()); } #[test] #[should_panic] fn example_enum_as_struct_no_fields_panics_when_opts_not_set() { - #[derive(Serialize)] - enum TrafficLight { - Red, - Yellow, - Green, - } - - let enum_items = [TrafficLight::Red, TrafficLight::Yellow, TrafficLight::Green]; - // This should continue to maintain previously implemented behavior, // throwing an error because we detect Unions with no fields let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); - Tracer::from_samples(&enum_items, opts) + Tracer::from_samples(Coin::sample_items(), opts) .unwrap() .to_field() .unwrap(); @@ -1036,98 +957,16 @@ mod test { #[test] fn example_enum_as_struct_all_fields_nullable() { - #[derive(Serialize)] - enum Optionals { - Something { - more: Option, - less: Option, - }, - Else { - one: Option, - another: Option, - }, - } - - let enum_items = [ - Optionals::Something { - more: Some(1), - less: None, - }, - Optionals::Something { - more: None, - less: Some(0), - }, - Optionals::Else { - one: None, - another: Some(0), - }, - Optionals::Else { - one: Some(1), - another: None, - }, - ]; - let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); - let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); - - let expected_field = Field { - name: "$".to_string(), - data_type: DataType::Struct(vec![ - Field { - name: "Else::another".to_string(), - data_type: DataType::UInt64, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Else::one".to_string(), - data_type: DataType::UInt64, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Something::less".to_string(), - data_type: DataType::UInt64, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Something::more".to_string(), - data_type: DataType::UInt64, - nullable: true, - metadata: HashMap::new(), - }, - ]), - nullable: false, - metadata: enum_with_named_fields_metadata(), - }; - - assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + let enum_tracer = Tracer::from_samples(Optionals::sample_items(), opts).unwrap(); + assert_eq!(enum_tracer.to_field().unwrap(), Optionals::expected_field()); } #[test] #[should_panic] fn example_enum_as_struct_tuple_variants() { - #[derive(Serialize)] - enum Payment { - Cash(f32), // amount - Check(String, f32), // name, amount - CreditCard(String, f32, [u8; 16], String), // name, amount, cc number, exp - } - - let enum_items = [ - Payment::Cash(0.42), - Payment::Check("Bob".to_string(), 0.42), - Payment::CreditCard( - "Sue".to_string(), - 0.42, - [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6], - "01/2024".to_string(), - ), - ]; - let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); - let enum_tracer = Tracer::from_samples(&enum_items, opts).unwrap(); + let enum_tracer = Tracer::from_samples(Payment::sample_items(), opts).unwrap(); // Currently panics when `to_schema()` is called on the variant tracer enum_tracer.to_field().unwrap(); diff --git a/serde_arrow/src/internal/schema/from_type/mod.rs b/serde_arrow/src/internal/schema/from_type/mod.rs index d7500157..7acdb77c 100644 --- a/serde_arrow/src/internal/schema/from_type/mod.rs +++ b/serde_arrow/src/internal/schema/from_type/mod.rs @@ -588,61 +588,56 @@ impl<'de, 'a> serde::de::Deserializer<'de> for IdentifierDeserializer<'a> { #[cfg(test)] mod test { - use std::collections::HashMap; - - use serde::{Deserialize, Serialize}; - use crate::{ internal::{ - arrow::{DataType, Field}, schema::tracer::Tracer, + testing::{Coin, Number, Optionals, Payment}, }, - schema::{Strategy, TracingOptions, STRATEGY_KEY}, + schema::TracingOptions, }; - // TODO: combine these with the from_samples tests, dedup utility code, test all edge cases with from_type as well #[test] fn example_enum_as_struct_equal_to_struct_with_nullable_fields() { - // TODO: dedup? - #[derive(Serialize, Deserialize)] - enum Number { - Real { value: f32 }, - Complex { i: f32, j: f32 }, - } - let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); let enum_tracer = Tracer::from_type::(opts).unwrap(); - let metadata = HashMap::from([( - STRATEGY_KEY.to_string(), - Strategy::EnumsWithNamedFieldsAsStructs.to_string(), - )]); - - let expected_field = Field { - name: "$".to_string(), - data_type: DataType::Struct(vec![ - Field { - name: "Complex::i".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Complex::j".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - Field { - name: "Real::value".to_string(), - data_type: DataType::Float32, - nullable: true, - metadata: HashMap::new(), - }, - ]), - nullable: false, - metadata, - }; - - assert_eq!(enum_tracer.to_field().unwrap(), expected_field); + assert_eq!(enum_tracer.to_field().unwrap(), Number::expected_field()); + } + + #[test] + fn example_enum_as_struct_no_fields() { + // This should continue to maintain previously implemented behavior, serializing as a map + let opts = TracingOptions::default() + .enums_with_named_fields_as_structs(true) + .enums_without_data_as_strings(true); + + let enum_tracer = Tracer::from_type::(opts).unwrap(); + assert_eq!(enum_tracer.to_field().unwrap(), Coin::expected_field()); + } + + #[test] + #[should_panic] + fn example_enum_as_struct_no_fields_panics_when_opts_not_set() { + // This should continue to maintain previously implemented behavior, + // throwing an error because we detect Unions with no fields + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + + Tracer::from_type::(opts).unwrap().to_field().unwrap(); + } + + #[test] + fn example_enum_as_struct_all_fields_nullable() { + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + let enum_tracer = Tracer::from_type::(opts).unwrap(); + assert_eq!(enum_tracer.to_field().unwrap(), Optionals::expected_field()); + } + + #[test] + #[should_panic] + fn example_enum_as_struct_tuple_variants() { + let opts = TracingOptions::default().enums_with_named_fields_as_structs(true); + let enum_tracer = Tracer::from_type::(opts).unwrap(); + + // Currently panics when `to_schema()` is called on the variant tracer + enum_tracer.to_field().unwrap(); } } diff --git a/serde_arrow/src/internal/testing.rs b/serde_arrow/src/internal/testing.rs index 1b6ac8b0..9fc20e82 100644 --- a/serde_arrow/src/internal/testing.rs +++ b/serde_arrow/src/internal/testing.rs @@ -1,10 +1,14 @@ //! Support for tests use core::str; +use std::collections::HashMap; use crate::internal::{ - arrow::{Array, BytesArray}, + arrow::{Array, BytesArray, DataType, Field}, error::{fail, Error, Result}, }; +use crate::schema::{Strategy, STRATEGY_KEY}; + +use serde::{Deserialize, Serialize}; pub fn assert_error_contains(actual: &Result, expected: &str) { let Err(actual) = actual else { @@ -75,3 +79,173 @@ where Ok(Some(str::from_utf8(data)?)) } + +fn enum_with_named_fields_metadata() -> HashMap { + HashMap::from([( + STRATEGY_KEY.to_string(), + Strategy::EnumsWithNamedFieldsAsStructs.to_string(), + )]) +} + +// Simple enum test structure for schema from_type/from_samples unit testing +#[derive(Serialize, Deserialize)] +pub(crate) enum Number { + Real { value: f32 }, + Complex { i: f32, j: f32 }, +} + +impl Number { + pub(crate) fn sample_items() -> Vec { + vec![ + Number::Real { value: 1.0 }, + Number::Complex { i: 0.5, j: 0.5 }, + ] + } + + pub(crate) fn expected_field() -> Field { + Field { + name: "$".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Complex::i".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Complex::j".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Real::value".to_string(), + data_type: DataType::Float32, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: enum_with_named_fields_metadata(), + } + } +} + +// No data test enum +#[derive(Serialize, Deserialize)] +pub(crate) enum Coin { + Heads, + Tails, +} + +impl Coin { + pub(crate) fn sample_items() -> Vec { + vec![Coin::Heads, Coin::Tails] + } + + pub(crate) fn expected_field() -> Field { + Field { + name: "$".to_string(), + data_type: DataType::Dictionary( + Box::new(DataType::UInt32), + Box::new(DataType::LargeUtf8), + false, + ), + nullable: false, + metadata: HashMap::new(), + } + } +} + +// Optional variant field test enum +#[derive(Serialize, Deserialize)] +pub(crate) enum Optionals { + Something { + more: Option, + less: Option, + }, + Else { + one: Option, + another: Option, + }, +} + +impl Optionals { + pub(crate) fn sample_items() -> Vec { + vec![ + Optionals::Something { + more: Some(1), + less: None, + }, + Optionals::Something { + more: None, + less: Some(0), + }, + Optionals::Else { + one: None, + another: Some(0), + }, + Optionals::Else { + one: Some(1), + another: None, + }, + ] + } + + pub(crate) fn expected_field() -> Field { + Field { + name: "$".to_string(), + data_type: DataType::Struct(vec![ + Field { + name: "Else::another".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Else::one".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Something::less".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + Field { + name: "Something::more".to_string(), + data_type: DataType::UInt64, + nullable: true, + metadata: HashMap::new(), + }, + ]), + nullable: false, + metadata: enum_with_named_fields_metadata(), + } + } +} + +// Tuple variant test enum +#[derive(Serialize, Deserialize)] +pub(crate) enum Payment { + Cash(f32), // amount + Check(String, f32), // name, amount + CreditCard(String, f32, [u8; 16], String), // name, amount, cc number, exp +} + +impl Payment { + pub(crate) fn sample_items() -> Vec { + vec![ + Payment::Cash(0.42), + Payment::Check("Bob".to_string(), 0.42), + Payment::CreditCard( + "Sue".to_string(), + 0.42, + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6], + "01/2024".to_string(), + ), + ] + } +} From 019b2a5d8cef7145814357e0d0115a9c143ebf55 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Mon, 28 Oct 2024 16:59:09 -0700 Subject: [PATCH 18/19] added tracing options rustdoc --- .../src/internal/schema/tracing_options.rs | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/serde_arrow/src/internal/schema/tracing_options.rs b/serde_arrow/src/internal/schema/tracing_options.rs index 21847792..276ab84b 100644 --- a/serde_arrow/src/internal/schema/tracing_options.rs +++ b/serde_arrow/src/internal/schema/tracing_options.rs @@ -226,7 +226,20 @@ pub struct TracingOptions { /// If `false` enums with data are encoded as Union arrays. /// If `true` enums with data are encoded as Structs. /// - /// TODO: example + /// ``` + /// # fn main() -> serde_arrow::Result<()> { + /// # use serde_arrow::_impl::arrow; + /// # use arrow::datatypes::{FieldRef, Field, DataType, TimeUnit}; + /// # use serde_arrow::schema::{SchemaLike, TracingOptions}; + /// # use serde::{Serialize, Deserialize}; + /// #[derive(Serialize, Deserialize)] + /// enum Number { + /// Real { value: f32 }, + /// Complex { i: f32, j: f32 }, + /// } + /// let options = TracingOptions::default().enums_with_named_fields_as_structs(true); + /// let fields = Tracer::from_type::(options)?; + /// # } /// ``` pub enums_with_named_fields_as_structs: bool, } From d6314b1696e43027ffe232decca3c268bc845408 Mon Sep 17 00:00:00 2001 From: Raj Sahae Date: Mon, 28 Oct 2024 17:05:42 -0700 Subject: [PATCH 19/19] slightly logging improvement in flattened union builder --- .../src/internal/serialization/flattened_union_builder.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/serde_arrow/src/internal/serialization/flattened_union_builder.rs b/serde_arrow/src/internal/serialization/flattened_union_builder.rs index 0c87d12e..3f647225 100644 --- a/serde_arrow/src/internal/serialization/flattened_union_builder.rs +++ b/serde_arrow/src/internal/serialization/flattened_union_builder.rs @@ -44,7 +44,7 @@ impl FlattenedUnionBuilder { for (builder, meta) in self.fields.into_iter() { let ArrayBuilder::Struct(builder) = builder else { - fail!("enum variant not built as a struct") // TODO: better failure message + fail!("Attempting to flatten a not-struct builder: {builder:?}"); }; for (sub_builder, mut sub_meta) in builder.fields.into_iter() { @@ -63,9 +63,8 @@ impl FlattenedUnionBuilder { Ok(Array::Struct(StructArray { len: self.row_count, fields: fields.into_values().collect(), - // TODO: is this ok to hardcode? - // assuming so because when testing manually, - // validity of struct with nullable fields was None + // assuming this is OK to hardcode because empirically, + // the validity of struct with nullable fields was always None validity: None, })) }