Skip to content

Commit

Permalink
protobuf: Normalize protobuf message names (MaterializeInc#9385)
Browse files Browse the repository at this point in the history
Create new wrapper struct `NormalizedProtobufMessageName` that ensures protobuf message names have a leading dot.

### Motivation
The Rust protoc library will panic if message names don't have a leading dot. This is fixed in MaterializeInc#9381, but I think we can do better by ensuring the message names are properly formatted right at the source.

I think this approach is an improvement because:
1. We're encoding the information that these strings have an extra constraint into the type system.  It's now impossible to create a `DecodedDescriptors` with a message name that doesn't have a leading dot.
2. We don't have to reason about when we need to normalize them (e.g. making sure we add a leading dot before passing to any external library).
3. A `NormalizedProtobufMessageName` is created without returning any errors, so it isn't any trickier or more painful to use than a `String`.
  • Loading branch information
cjubb39 authored Dec 2, 2021
1 parent d6c4aee commit 4502e6a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 13 deletions.
6 changes: 3 additions & 3 deletions src/dataflow-types/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use uuid::Uuid;

use expr::{GlobalId, MirRelationExpr, MirScalarExpr, OptimizedMirRelationExpr, PartitionId};
use interchange::avro::{self, DebeziumDeduplicationStrategy};
use interchange::protobuf;
use interchange::protobuf::{self, NormalizedProtobufMessageName};
use kafka_util::KafkaAddrs;
use repr::{ColumnName, ColumnType, Diff, RelationDesc, RelationType, Row, ScalarType, Timestamp};

Expand Down Expand Up @@ -366,7 +366,7 @@ impl DataEncoding {
DataEncoding::Protobuf(ProtobufEncoding {
descriptors,
message_name,
}) => protobuf::DecodedDescriptors::from_bytes(descriptors, message_name.into())?
}) => protobuf::DecodedDescriptors::from_bytes(descriptors, message_name.to_owned())?
.columns()
.iter()
.fold(RelationDesc::empty(), |desc, (name, ty)| {
Expand Down Expand Up @@ -451,7 +451,7 @@ pub struct AvroOcfEncoding {
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ProtobufEncoding {
pub descriptors: Vec<u8>,
pub message_name: String,
pub message_name: NormalizedProtobufMessageName,
}

/// Arguments necessary to define how to decode from CSV format
Expand Down
9 changes: 6 additions & 3 deletions src/interchange/benches/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
use criterion::{black_box, Criterion, Throughput};
use protobuf::{Message, MessageField};

use interchange::protobuf::{DecodedDescriptors, Decoder};
use interchange::protobuf::{DecodedDescriptors, Decoder, NormalizedProtobufMessageName};

use gen::benchmark::{Connector, Record, Value};

Expand Down Expand Up @@ -64,8 +64,11 @@ pub fn bench_protobuf(c: &mut Criterion) {
.expect("record failed to serialize to bytes");
let len = buf.len() as u64;
let mut decoder = Decoder::new(
DecodedDescriptors::from_bytes(gen::FILE_DESCRIPTOR_SET_DATA, ".bench.Record".to_string())
.unwrap(),
DecodedDescriptors::from_bytes(
gen::FILE_DESCRIPTOR_SET_DATA,
NormalizedProtobufMessageName::new(".bench.Record".to_string()),
)
.unwrap(),
);

let mut bg = c.benchmark_group("protobuf");
Expand Down
24 changes: 20 additions & 4 deletions src/interchange/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,26 @@ use protobuf::reflect::{
RuntimeFieldType, RuntimeTypeBox,
};
use protobuf::{CodedInputStream, Message, MessageDyn};
use serde::{Deserialize, Serialize};

use ore::str::StrExt;
use repr::{ColumnName, ColumnType, Datum, Row, ScalarType};

/// Wrapper type that ensures a protobuf message name is properly normalized.
///
/// "Normalized" here means the name starts with a leading dot (for example
/// `.bench.Record`). The inner `String` is private, so code outside this
/// module builds values through [`NormalizedProtobufMessageName::new`], which
/// prepends the dot when it is missing.
///
/// NOTE(review): the derived `Deserialize` impl presumably fills the inner
/// string directly without going through `new`, so a deserialized value is
/// only normalized if the serialized data already was. Confirm whether that
/// matters for previously persisted encodings.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct NormalizedProtobufMessageName(String);

impl NormalizedProtobufMessageName {
/// Create a new normalized protobuf message name. A leading dot will be
/// prepended to the provided message name if necessary.
pub fn new(mut message_name: String) -> Self {
if !message_name.starts_with('.') {
message_name = format!(".{}", message_name);
}
NormalizedProtobufMessageName(message_name)
}
}

/// A decoded description of the schema of a Protobuf message.
#[derive(Debug)]
pub struct DecodedDescriptors {
Expand All @@ -33,10 +49,10 @@ impl DecodedDescriptors {
/// Builds a `DecodedDescriptors` from an encoded [`FileDescriptorSet`]
/// and the fully qualified name of a message inside that file descriptor
/// set.
pub fn from_bytes(bytes: &[u8], mut message_name: String) -> Result<Self, anyhow::Error> {
if !message_name.starts_with('.') {
message_name = format!(".{}", message_name);
}
pub fn from_bytes(
bytes: &[u8],
NormalizedProtobufMessageName(message_name): NormalizedProtobufMessageName,
) -> Result<Self, anyhow::Error> {
let fds =
FileDescriptorSet::parse_from_bytes(bytes).context("parsing file descriptor set")?;
let fds = FileDescriptor::new_dynamic_fds(fds.file);
Expand Down
11 changes: 8 additions & 3 deletions src/sql/src/plan/statement/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::time::Duration;
use anyhow::{anyhow, bail};
use aws_arn::ARN;
use globset::GlobBuilder;
use interchange::protobuf::NormalizedProtobufMessageName;
use itertools::Itertools;
use log::{debug, error};
use regex::Regex;
Expand Down Expand Up @@ -1055,13 +1056,17 @@ fn get_encoding_inner<T: sql_parser::ast::AstInfo>(
{
let value = DataEncoding::Protobuf(ProtobufEncoding {
descriptors: strconv::parse_bytes(&value.schema)?,
message_name: value.message_name.clone(),
message_name: NormalizedProtobufMessageName::new(
value.message_name.clone(),
),
});
if let Some(key) = key {
return Ok(SourceDataEncoding::KeyValue {
key: DataEncoding::Protobuf(ProtobufEncoding {
descriptors: strconv::parse_bytes(&key.schema)?,
message_name: key.message_name.clone(),
message_name: NormalizedProtobufMessageName::new(
key.message_name.clone(),
),
}),
value,
});
Expand All @@ -1084,7 +1089,7 @@ fn get_encoding_inner<T: sql_parser::ast::AstInfo>(

DataEncoding::Protobuf(ProtobufEncoding {
descriptors,
message_name: message_name.to_owned(),
message_name: NormalizedProtobufMessageName::new(message_name.to_owned()),
})
}
},
Expand Down

0 comments on commit 4502e6a

Please sign in to comment.