Skip to content

Commit

Permalink
Add support for repeated fields in nested messages (#24)
Browse files Browse the repository at this point in the history
* Add support for repeated fields in nested messages

* Clean up
  • Loading branch information
ChewingGlass authored Feb 16, 2024
1 parent cb0db13 commit 55ed28e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 19 deletions.
4 changes: 2 additions & 2 deletions protobuf-delta-lake-sink/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Context, Result};
use chrono::{NaiveDateTime, Utc};
use clap::Parser;
use datafusion::{arrow::array::StringArray, common::delta};
use datafusion::arrow::array::StringArray;
use deltalake::{
action::{self, Action, CommitInfo, SaveMode},
checkpoints, crate_version,
Expand Down Expand Up @@ -116,7 +116,7 @@ async fn main() -> Result<()> {
args.source_proto_name,
)
.await?;
let mut delta_fields = get_delta_schema(&descriptor, false);
let mut delta_fields = get_delta_schema(&descriptor);
if args.partition_timestamp_column.is_some() {
let date_field = SchemaField::new(
"date".to_string(),
Expand Down
156 changes: 143 additions & 13 deletions protobuf-delta-lake-sink/src/proto/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use deltalake::{
},
Schema, SchemaTypeStruct,
};
use protobuf::reflect::{EnumDescriptor, ReflectValueBox};
use protobuf::reflect::{EnumDescriptor, ReflectRepeatedRef, ReflectValueBox};
use protobuf::{
reflect::{FieldDescriptor, MessageDescriptor, ReflectValueRef, RuntimeType},
MessageDyn,
Expand All @@ -24,6 +24,7 @@ use super::get_delta_schema;

trait ReflectBuilder: ArrayBuilder {
fn append_value(&mut self, v: Option<ReflectValueRef>);
fn append_repeated_value(&mut self, v: Option<ReflectRepeatedRef>);
}

macro_rules! make_builder_wrapper {
Expand Down Expand Up @@ -75,6 +76,10 @@ impl ReflectBuilder for BinaryReflectBuilder {
.unwrap_or_default(),
)
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

impl ReflectBuilder for StringReflectBuilder {
Expand All @@ -88,6 +93,10 @@ impl ReflectBuilder for StringReflectBuilder {
.unwrap_or_default(),
)
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

impl ReflectBuilder for BoolReflectBuilder {
Expand All @@ -97,6 +106,10 @@ impl ReflectBuilder for BoolReflectBuilder {
.unwrap_or_default(),
)
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

pub struct EnumReflectBuilder {
Expand Down Expand Up @@ -147,6 +160,10 @@ impl ReflectBuilder for EnumReflectBuilder {
.unwrap_or_default(),
)
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

struct PrimitiveReflectBuilder<T: ArrowPrimitiveType> {
Expand Down Expand Up @@ -190,6 +207,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder<Int32Type> {
.unwrap_or_default(),
);
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

impl ReflectBuilder for PrimitiveReflectBuilder<Int64Type> {
Expand All @@ -202,6 +223,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder<Int64Type> {
.unwrap_or_default(),
);
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

pub struct U64ReflectBuilder {
Expand Down Expand Up @@ -263,6 +288,10 @@ impl ReflectBuilder for U64ReflectBuilder {
.unwrap_or_default(),
);
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

impl ReflectBuilder for PrimitiveReflectBuilder<Float32Type> {
Expand All @@ -272,6 +301,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder<Float32Type> {
.unwrap_or_default(),
);
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

impl ReflectBuilder for PrimitiveReflectBuilder<Float64Type> {
Expand All @@ -281,6 +314,96 @@ impl ReflectBuilder for PrimitiveReflectBuilder<Float64Type> {
.unwrap_or_default(),
);
}

fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

struct RepeatedReflectBuilder {
pub builder: Box<dyn ReflectArrayBuilder>,
pub offsets: BufferBuilder<i32>,
pub t: RuntimeType,
pub capacity: usize,
}

impl RepeatedReflectBuilder {
fn new(
capacity: usize,
builder: Box<dyn ReflectArrayBuilder>,
t: RuntimeType,
) -> RepeatedReflectBuilder {
let mut offsets = BufferBuilder::<i32>::new(0);
offsets.append(0);
RepeatedReflectBuilder {
builder,
offsets,
capacity,
t,
}
}
}

impl ArrayBuilder for RepeatedReflectBuilder {
fn len(&self) -> usize {
self.capacity
}

fn is_empty(&self) -> bool {
self.len() == 0
}

fn finish(&mut self) -> ArrayRef {
let field = Arc::new(Field::new("item", runtime_type_to_data_type(&self.t), true));
let data_type = DataType::List(field);
let values_arr = self.builder.finish();
let values_data = values_arr.to_data();
let array_data_builder = ArrayData::builder(data_type)
.len(self.capacity)
.add_buffer(self.offsets.finish())
.add_child_data(values_data)
.null_bit_buffer(None);
// .null_bit_buffer(Some(self.nulls.finish().values().clone().into_inner()));
let array_data = unsafe { array_data_builder.build_unchecked() };
Arc::new(GenericListArray::<i32>::from(array_data))
}

fn finish_cloned(&self) -> ArrayRef {
unimplemented!()
}

fn as_any(&self) -> &dyn Any {
self
}

fn as_any_mut(&mut self) -> &mut dyn Any {
self
}

fn into_box_any(self: Box<Self>) -> Box<dyn Any> {
self
}
}

impl ReflectBuilder for RepeatedReflectBuilder {
fn append_value(&mut self, _: Option<ReflectValueRef>) {
panic!("Operation not supported!");
}
fn append_repeated_value(&mut self, v: Option<ReflectRepeatedRef>) {
let messages = v
.iter()
.flat_map(|i| i.into_iter().collect::<Vec<_>>())
.collect::<Vec<_>>();

// self.nulls.append_value(v.is_none());

for value in messages.iter() {
self.builder.append_value(Some(value.clone()));
}

self.offsets
.append(i32::try_from(self.builder.len()).unwrap());
}
}

struct StructReflectBuilder {
Expand Down Expand Up @@ -347,10 +470,7 @@ impl StructReflectBuilder {

impl ReflectBuilder for StructReflectBuilder {
fn append_value(&mut self, v: Option<ReflectValueRef>) {
let message_ref = v
.map(|i| {
i.to_message().expect("Not a message")
});
let message_ref = v.map(|i| i.to_message().expect("Not a message"));
let message = message_ref.as_deref();
for (index, field) in self.descriptor.fields().enumerate() {
match field.runtime_field_type() {
Expand All @@ -359,14 +479,18 @@ impl ReflectBuilder for StructReflectBuilder {
builder.append_value(message.and_then(|m| field.get_singular(m)))
}
protobuf::reflect::RuntimeFieldType::Repeated(_) => {
// Do nothing
let builder = self.builders.get_mut(index).unwrap();
builder.append_repeated_value(message.map(|m| field.get_repeated(m)))
}
protobuf::reflect::RuntimeFieldType::Map(_, _) => {
panic!("Map fields are not supported")
}
};
}
}
fn append_repeated_value(&mut self, _: Option<ReflectRepeatedRef>) {
panic!("Operation not supported");
}
}

fn runtime_type_to_data_type(value: &RuntimeType) -> DataType {
Expand All @@ -382,7 +506,7 @@ fn runtime_type_to_data_type(value: &RuntimeType) -> DataType {
RuntimeType::VecU8 => DataType::Binary,
RuntimeType::Enum(_) => DataType::Binary,
RuntimeType::Message(m) => {
let fields = get_delta_schema(m, true);
let fields = get_delta_schema(m);
let schema = <deltalake::arrow::datatypes::Schema as TryFrom<&Schema>>::try_from(
&SchemaTypeStruct::new(fields),
)
Expand Down Expand Up @@ -428,20 +552,26 @@ fn get_builder(t: &RuntimeType, capacity: usize) -> Result<Box<dyn ReflectArrayB
enum_descriptor: enum_descriptor.clone(),
}),
RuntimeType::Message(m) => {
let schema = Schema::new(get_delta_schema(m, true));
let schema = Schema::new(get_delta_schema(m));
let arrow_schema =
<deltalake::arrow::datatypes::Schema as TryFrom<&Schema>>::try_from(&schema)?;
let builders = m
.clone()
.fields()
.flat_map(|field| match field.runtime_field_type() {
protobuf::reflect::RuntimeFieldType::Singular(t) => Some(get_builder(&t, capacity)),
protobuf::reflect::RuntimeFieldType::Repeated(_) => {
None
protobuf::reflect::RuntimeFieldType::Singular(t) => {
Some(get_builder(&t, capacity))
}
protobuf::reflect::RuntimeFieldType::Map(_, _) => {
None
protobuf::reflect::RuntimeFieldType::Repeated(t) => {
let builder: Box<dyn ReflectArrayBuilder> =
Box::new(RepeatedReflectBuilder::new(
capacity,
get_builder(&t, 0).ok().unwrap(),
t,
));
Some(Ok(builder))
}
protobuf::reflect::RuntimeFieldType::Map(_, _) => None,
})
.collect::<Result<Vec<Box<dyn ReflectArrayBuilder>>>>()?;
Box::new(StructReflectBuilder {
Expand Down
8 changes: 4 additions & 4 deletions protobuf-delta-lake-sink/src/proto/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub fn get_single_delta_schema(field_name: &str, field_type: RuntimeType) -> Sch
protobuf::reflect::RuntimeType::Message(m) => {
return SchemaField::new(
field_name.to_string(),
SchemaDataType::r#struct(SchemaTypeStruct::new(get_delta_schema(&m, true))),
SchemaDataType::r#struct(SchemaTypeStruct::new(get_delta_schema(&m))),
true,
HashMap::new(),
);
Expand All @@ -88,14 +88,14 @@ pub fn get_single_delta_schema(field_name: &str, field_type: RuntimeType) -> Sch
)
}

pub fn get_delta_schema(descriptor: &MessageDescriptor, nested: bool) -> Vec<SchemaField> {
pub fn get_delta_schema(descriptor: &MessageDescriptor) -> Vec<SchemaField> {
descriptor
.fields()
.flat_map(|f| {
let field_name = f.name();
let field_type = match f.runtime_field_type() {
protobuf::reflect::RuntimeFieldType::Singular(t) => Some(t),
protobuf::reflect::RuntimeFieldType::Repeated(t) if !nested => {
protobuf::reflect::RuntimeFieldType::Repeated(t) => {
return Some(SchemaField::new(
field_name.to_string(),
SchemaDataType::array(SchemaTypeArray::new(
Expand All @@ -106,7 +106,7 @@ pub fn get_delta_schema(descriptor: &MessageDescriptor, nested: bool) -> Vec<Sch
HashMap::new(),
));
}
_ => None
_ => None,
};
field_type.map(|t| get_single_delta_schema(field_name, t))
})
Expand Down

0 comments on commit 55ed28e

Please sign in to comment.