Skip to content

Commit

Permalink
Merge pull request #1 from HarshMN2345/HarshMN2345-patch-1
Browse files Browse the repository at this point in the history
Update grpc.rs
  • Loading branch information
HarshMN2345 authored Dec 3, 2024
2 parents 41861db + 19dabaf commit 5978b97
Showing 1 changed file with 113 additions and 111 deletions.
224 changes: 113 additions & 111 deletions src/core/blueprint/operators/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,81 +64,69 @@ fn validate_schema(
field_schema: FieldSchema,
operation: &ProtobufOperation,
name: &str,
) -> Valid<(), BlueprintError> {
) -> Valid<(), String> {
let input_type = &operation.input_type;
let output_type = &operation.output_type;

let input_type = match JsonSchema::try_from(input_type) {
Ok(input_schema) => Valid::succeed(input_schema),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
};
Valid::from(JsonSchema::try_from(input_type))
.zip(Valid::from(JsonSchema::try_from(output_type)))
.and_then(|(input_schema, output_schema)| {
let fields = &field_schema.field;
let args = &field_schema.args;

let output_type = match JsonSchema::try_from(output_type) {
Ok(output_type) => Valid::succeed(output_type),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
};
// Treat repeated message types as optional in input schema
let normalized_input_schema = normalize_repeated_types(&input_schema);

input_type
.zip(output_type)
.and_then(|(_input_schema, sub_type)| {
// TODO: add validation for input schema - should compare result grpc.body to
// schema
let super_type = field_schema.field;
// TODO: all of the fields in protobuf are optional actually
// and if we want to mark some fields as required in GraphQL
// JsonSchema won't match and the validation will fail
match sub_type.is_a(&super_type, name).to_result() {
Ok(res) => Valid::succeed(res),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
}
// Validate input schema against args
args.compare(&normalized_input_schema, &format!("Input validation failed for {}", name))?;

// Validate output schema against fields
fields.compare(&output_schema, &format!("Output validation failed for {}", name))
})
}

fn normalize_repeated_types(schema: &JsonSchema) -> JsonSchema {
match schema {
JsonSchema::Arr(inner_schema) => {
// Treat repeated types (arrays) as optional
JsonSchema::Optional(Box::new(inner_schema.clone()))
}
JsonSchema::Object(fields) => {
let normalized_fields = fields
.iter()
.map(|(key, value)| (key.clone(), normalize_repeated_types(value)))
.collect();
JsonSchema::Object(normalized_fields)
}
_ => schema.clone(),
}
}
fn validate_group_by(
field_schema: &FieldSchema,
operation: &ProtobufOperation,
group_by: Vec<String>,
) -> Valid<(), BlueprintError> {
) -> Valid<(), String> {
let input_type = &operation.input_type;
let output_type = &operation.output_type;
let mut field_descriptor: Result<FieldDescriptor, ValidationError<BlueprintError>> = None
.ok_or(ValidationError::new(BlueprintError::FieldNotFound(
group_by[0].clone(),
)));
for item in group_by.iter().take(&group_by.len() - 1) {
field_descriptor =
output_type
.get_field_by_json_name(item.as_str())
.ok_or(ValidationError::new(BlueprintError::FieldNotFound(
item.clone(),
)));
}
let output_type = field_descriptor
.and_then(|f| JsonSchema::try_from(&f).map_err(BlueprintError::from_validation_string));

let json_schema = match JsonSchema::try_from(input_type) {
Ok(schema) => Valid::succeed(schema),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
};
let input_schema = JsonSchema::try_from(input_type)?;
let output_schema = JsonSchema::try_from(output_type)?;

json_schema
.zip(Valid::from(output_type))
.and_then(|(_input_schema, output_schema)| {
// TODO: add validation for input schema - should compare result grpc.body to
// schema considering repeated message type
let fields = &field_schema.field;
// we're treating List types for gRPC as optional.
let fields = JsonSchema::Opt(Box::new(JsonSchema::Arr(Box::new(fields.to_owned()))));
match fields
.is_a(&output_schema, group_by[0].as_str())
.to_result()
{
Ok(res) => Valid::succeed(res),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
}
})
let normalized_input_schema = normalize_repeated_types(&input_schema);

let fields = JsonSchema::Arr(Box::new(field_schema.field.to_owned()));
let args = JsonSchema::Arr(Box::new(field_schema.args.to_owned()));

args.compare(
&normalized_input_schema,
&format!("Input validation failed for group_by {:?}", group_by),
)?;
fields.compare(
&output_schema,
&format!("Output validation failed for group_by {:?}", group_by),
)
}


pub struct CompileGrpc<'a> {
pub config_module: &'a ConfigModule,
pub operation_type: &'a GraphQLOperationType,
Expand Down Expand Up @@ -187,63 +175,35 @@ pub fn compile_grpc(inputs: CompileGrpc) -> Valid<IR, BlueprintError> {
let validate_with_schema = inputs.validate_with_schema;
let dedupe = grpc.dedupe.unwrap_or_default();

Valid::from(GrpcMethod::try_from(grpc.method.as_str()))
.and_then(|method| {
let file_descriptor_set = config_module.extensions().get_file_descriptor_set();
Valid::from(GrpcMethod::try_from(grpc.method.as_str()))
.and_then(|method| {
let file_descriptor_set = config_module.extensions().get_file_descriptor_set();

if file_descriptor_set.file.is_empty() {
return Valid::fail(BlueprintError::ProtobufFilesNotSpecifiedInConfig);
}
if file_descriptor_set.file.is_empty() {
return Valid::fail("Protobuf files were not specified in the config".to_string());
}

match to_operation(&method, file_descriptor_set)
.fuse(to_url(grpc, &method))
.fuse(helpers::headers::to_mustache_headers(&grpc.headers))
.fuse(helpers::body::to_body(grpc.body.as_ref()))
.to_result()
{
Ok(data) => Valid::succeed(data),
Err(e) => Valid::from_validation_err(BlueprintError::from_validation_string(e)),
}
})
.and_then(|(operation, url, headers, body)| {
let validation = if validate_with_schema {
let field_schema = json_schema_from_field(config_module, field);
if grpc.batch_key.is_empty() {
validate_schema(field_schema, &operation, field.type_of.name()).unit()
} else {
validate_group_by(&field_schema, &operation, grpc.batch_key.clone()).unit()
}
} else {
Valid::succeed(())
};
validation.map(|_| (url, headers, operation, body))
})
.map(|(url, headers, operation, body)| {
let req_template = RequestTemplate {
url,
headers,
operation,
body,
operation_type: operation_type.clone(),
};
let on_response = grpc.on_response_body.clone();
let hook = WorkerHooks::try_new(None, on_response).ok();

let io = if !grpc.batch_key.is_empty() {
IR::IO(IO::Grpc {
req_template,
group_by: Some(GroupBy::new(grpc.batch_key.clone(), None)),
dl_id: None,
dedupe,
hook,
})
to_operation(&method, file_descriptor_set)
.fuse(to_url(grpc, &method, config_module))
.fuse(helpers::headers::to_mustache_headers(&grpc.headers))
.fuse(helpers::body::to_body(grpc.body.as_ref()))
.into()
})
.and_then(|(operation, url, headers, body)| {
let validation = if validate_with_schema {
let field_schema = json_schema_from_field(config_module, field);
if grpc.batch_key.is_empty() {
// Add input validation with repeated type normalization
validate_schema(field_schema, &operation, field.name()).unit()
} else {
IR::IO(IO::Grpc { req_template, group_by: None, dl_id: None, dedupe, hook })
};
validate_group_by(&field_schema, &operation, grpc.batch_key.clone()).unit()
}
} else {
Valid::succeed(())
};
validation.map(|_| (url, headers, operation, body))
})

(io, &grpc.select)
})
.and_then(apply_select)
}

#[cfg(test)]
Expand All @@ -254,6 +214,22 @@ mod tests {

use super::GrpcMethod;
use crate::core::blueprint::BlueprintError;
#[test]
fn validate_repeated_types_as_optional() {
let operation = ProtobufOperation {
input_type: "RepeatedInputType".to_string(),
output_type: "ValidOutputType".to_string(),
};

let field_schema = FieldSchema {
args: JsonSchema::Arr(Box::new(JsonSchema::String)),
field: JsonSchema::Object(HashMap::new()),
};

let result = validate_schema(field_schema, &operation, "test_operation");
assert!(result.is_ok());
}


#[test]
fn try_from_grpc_method() {
Expand All @@ -268,6 +244,32 @@ mod tests {
assert_eq!(method1.service, "ServiceName");
assert_eq!(method1.name, "MethodName");
}
#[test]
fn grpc_repeated_types_validation_integration() {
let config_module = MockConfigModule::new();
let operation_type = GraphQLOperationType::Query;
let field = Field::new("test_field", "RepeatedInputType");

let grpc = Grpc {
method: "package.Service.Method".to_string(),
base_url: Some("http://localhost:5000".to_string()),
headers: None,
body: Some(vec!["repeated_field"]),
batch_key: vec![],
};

let compile_inputs = CompileGrpc {
config_module: &config_module,
operation_type: &operation_type,
field: &field,
grpc: &grpc,
validate_with_schema: true,
};

let result = compile_grpc(compile_inputs);
assert!(result.is_ok());
}


#[test]
fn try_from_grpc_method_invalid() {
Expand Down

0 comments on commit 5978b97

Please sign in to comment.