From 55ed28eda8686b62834d60809c091f84e1e85904 Mon Sep 17 00:00:00 2001 From: Noah Prince <83885631+ChewingGlass@users.noreply.github.com> Date: Fri, 16 Feb 2024 17:08:47 +0900 Subject: [PATCH] Add support for repeated fields in nested messages (#24) * Add support for repeated fields in nested messages * Clean up --- protobuf-delta-lake-sink/src/main.rs | 4 +- protobuf-delta-lake-sink/src/proto/parse.rs | 156 +++++++++++++++++-- protobuf-delta-lake-sink/src/proto/schema.rs | 8 +- 3 files changed, 149 insertions(+), 19 deletions(-) diff --git a/protobuf-delta-lake-sink/src/main.rs b/protobuf-delta-lake-sink/src/main.rs index 3dcdcc2..8eab786 100644 --- a/protobuf-delta-lake-sink/src/main.rs +++ b/protobuf-delta-lake-sink/src/main.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Context, Result}; use chrono::{NaiveDateTime, Utc}; use clap::Parser; -use datafusion::{arrow::array::StringArray, common::delta}; +use datafusion::arrow::array::StringArray; use deltalake::{ action::{self, Action, CommitInfo, SaveMode}, checkpoints, crate_version, @@ -116,7 +116,7 @@ async fn main() -> Result<()> { args.source_proto_name, ) .await?; - let mut delta_fields = get_delta_schema(&descriptor, false); + let mut delta_fields = get_delta_schema(&descriptor); if args.partition_timestamp_column.is_some() { let date_field = SchemaField::new( "date".to_string(), diff --git a/protobuf-delta-lake-sink/src/proto/parse.rs b/protobuf-delta-lake-sink/src/proto/parse.rs index f37979f..0c9c16b 100644 --- a/protobuf-delta-lake-sink/src/proto/parse.rs +++ b/protobuf-delta-lake-sink/src/proto/parse.rs @@ -13,7 +13,7 @@ use deltalake::{ }, Schema, SchemaTypeStruct, }; -use protobuf::reflect::{EnumDescriptor, ReflectValueBox}; +use protobuf::reflect::{EnumDescriptor, ReflectRepeatedRef, ReflectValueBox}; use protobuf::{ reflect::{FieldDescriptor, MessageDescriptor, ReflectValueRef, RuntimeType}, MessageDyn, @@ -24,6 +24,7 @@ use super::get_delta_schema; trait ReflectBuilder: ArrayBuilder { fn append_value(&mut self, v: Option); + fn append_repeated_value(&mut self, v: Option); } macro_rules! make_builder_wrapper { @@ -75,6 +76,10 @@ impl ReflectBuilder for BinaryReflectBuilder { .unwrap_or_default(), ) } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } impl ReflectBuilder for StringReflectBuilder { @@ -88,6 +93,10 @@ impl ReflectBuilder for StringReflectBuilder { .unwrap_or_default(), ) } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } impl ReflectBuilder for BoolReflectBuilder { @@ -97,6 +106,10 @@ impl ReflectBuilder for BoolReflectBuilder { .unwrap_or_default(), ) } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } pub struct EnumReflectBuilder { @@ -147,6 +160,10 @@ impl ReflectBuilder for EnumReflectBuilder { .unwrap_or_default(), ) } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } struct PrimitiveReflectBuilder { @@ -190,6 +207,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder { .unwrap_or_default(), ); } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } impl ReflectBuilder for PrimitiveReflectBuilder { @@ -202,6 +223,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder { .unwrap_or_default(), ); } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } pub struct U64ReflectBuilder { @@ -263,6 +288,10 @@ impl ReflectBuilder for U64ReflectBuilder { .unwrap_or_default(), ); } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } impl ReflectBuilder for PrimitiveReflectBuilder { @@ -272,6 +301,10 @@ impl ReflectBuilder for PrimitiveReflectBuilder { .unwrap_or_default(), ); } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } impl ReflectBuilder for PrimitiveReflectBuilder { @@ -281,6 +314,96 @@ impl ReflectBuilder for PrimitiveReflectBuilder { .unwrap_or_default(), ); } + + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } +} + +struct RepeatedReflectBuilder { + pub builder: Box, + pub offsets: BufferBuilder, + pub t: RuntimeType, + pub capacity: usize, +} + +impl RepeatedReflectBuilder { + fn new( + capacity: usize, + builder: Box, + t: RuntimeType, + ) -> RepeatedReflectBuilder { + let mut offsets = BufferBuilder::::new(0); + offsets.append(0); + RepeatedReflectBuilder { + builder, + offsets, + capacity, + t, + } + } +} + +impl ArrayBuilder for RepeatedReflectBuilder { + fn len(&self) -> usize { + self.capacity + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn finish(&mut self) -> ArrayRef { + let field = Arc::new(Field::new("item", runtime_type_to_data_type(&self.t), true)); + let data_type = DataType::List(field); + let values_arr = self.builder.finish(); + let values_data = values_arr.to_data(); + let array_data_builder = ArrayData::builder(data_type) + .len(self.capacity) + .add_buffer(self.offsets.finish()) + .add_child_data(values_data) + .null_bit_buffer(None); + // .null_bit_buffer(Some(self.nulls.finish().values().clone().into_inner())); + let array_data = unsafe { array_data_builder.build_unchecked() }; + Arc::new(GenericListArray::::from(array_data)) + } + + fn finish_cloned(&self) -> ArrayRef { + unimplemented!() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_box_any(self: Box) -> Box { + self + } +} + +impl ReflectBuilder for RepeatedReflectBuilder { + fn append_value(&mut self, _: Option) { + panic!("Operation not supported!"); + } + fn append_repeated_value(&mut self, v: Option) { + let messages = v + .iter() + .flat_map(|i| i.into_iter().collect::>()) + .collect::>(); + + // self.nulls.append_value(v.is_none()); + + for value in messages.iter() { + self.builder.append_value(Some(value.clone())); + } + + self.offsets + .append(i32::try_from(self.builder.len()).unwrap()); + } } struct StructReflectBuilder { @@ -347,10 +470,7 @@ impl StructReflectBuilder { impl ReflectBuilder for StructReflectBuilder { fn append_value(&mut self, v: Option) { - let message_ref = v - .map(|i| { - i.to_message().expect("Not a message") - }); + let message_ref = v.map(|i| i.to_message().expect("Not a message")); let message = message_ref.as_deref(); for (index, field) in self.descriptor.fields().enumerate() { match field.runtime_field_type() { @@ -359,7 +479,8 @@ impl ReflectBuilder for StructReflectBuilder { builder.append_value(message.and_then(|m| field.get_singular(m))) } protobuf::reflect::RuntimeFieldType::Repeated(_) => { - // Do nothing + let builder = self.builders.get_mut(index).unwrap(); + builder.append_repeated_value(message.map(|m| field.get_repeated(m))) } protobuf::reflect::RuntimeFieldType::Map(_, _) => { panic!("Map fields are not supported") @@ -367,6 +488,9 @@ impl ReflectBuilder for StructReflectBuilder { }; } } + fn append_repeated_value(&mut self, _: Option) { + panic!("Operation not supported"); + } } fn runtime_type_to_data_type(value: &RuntimeType) -> DataType { @@ -382,7 +506,7 @@ fn runtime_type_to_data_type(value: &RuntimeType) -> DataType { RuntimeType::VecU8 => DataType::Binary, RuntimeType::Enum(_) => DataType::Binary, RuntimeType::Message(m) => { - let fields = get_delta_schema(m, true); + let fields = get_delta_schema(m); let schema = >::try_from( &SchemaTypeStruct::new(fields), ) @@ -428,20 +552,26 @@ fn get_builder(t: &RuntimeType, capacity: usize) -> Result { - let schema = Schema::new(get_delta_schema(m, true)); + let schema = Schema::new(get_delta_schema(m)); let arrow_schema = >::try_from(&schema)?; let builders = m .clone() .fields() .flat_map(|field| match field.runtime_field_type() { - protobuf::reflect::RuntimeFieldType::Singular(t) => Some(get_builder(&t, capacity)), - protobuf::reflect::RuntimeFieldType::Repeated(_) => { - None + protobuf::reflect::RuntimeFieldType::Singular(t) => { + Some(get_builder(&t, capacity)) } - protobuf::reflect::RuntimeFieldType::Map(_, _) => { - None + protobuf::reflect::RuntimeFieldType::Repeated(t) => { + let builder: Box = + Box::new(RepeatedReflectBuilder::new( + capacity, + get_builder(&t, 0).ok().unwrap(), + t, + )); + Some(Ok(builder)) } + protobuf::reflect::RuntimeFieldType::Map(_, _) => None, }) .collect::>>>()?; Box::new(StructReflectBuilder { diff --git a/protobuf-delta-lake-sink/src/proto/schema.rs b/protobuf-delta-lake-sink/src/proto/schema.rs index dd2a0d2..70736a0 100644 --- a/protobuf-delta-lake-sink/src/proto/schema.rs +++ b/protobuf-delta-lake-sink/src/proto/schema.rs @@ -73,7 +73,7 @@ pub fn get_single_delta_schema(field_name: &str, field_type: RuntimeType) -> Sch protobuf::reflect::RuntimeType::Message(m) => { return SchemaField::new( field_name.to_string(), - SchemaDataType::r#struct(SchemaTypeStruct::new(get_delta_schema(&m, true))), + SchemaDataType::r#struct(SchemaTypeStruct::new(get_delta_schema(&m))), true, HashMap::new(), ); @@ -88,14 +88,14 @@ pub fn get_single_delta_schema(field_name: &str, field_type: RuntimeType) -> Sch ) } -pub fn get_delta_schema(descriptor: &MessageDescriptor, nested: bool) -> Vec { +pub fn get_delta_schema(descriptor: &MessageDescriptor) -> Vec { descriptor .fields() .flat_map(|f| { let field_name = f.name(); let field_type = match f.runtime_field_type() { protobuf::reflect::RuntimeFieldType::Singular(t) => Some(t), - protobuf::reflect::RuntimeFieldType::Repeated(t) if !nested => { + protobuf::reflect::RuntimeFieldType::Repeated(t) => { return Some(SchemaField::new( field_name.to_string(), SchemaDataType::array(SchemaTypeArray::new( @@ -106,7 +106,7 @@ pub fn get_delta_schema(descriptor: &MessageDescriptor, nested: bool) -> Vec None + _ => None, }; field_type.map(|t| get_single_delta_schema(field_name, t)) })