From eed82f012e96ed66c1d17449493dd39a792d8284 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 1 Nov 2024 17:09:20 -0700 Subject: [PATCH] Add a simple validation transform to yaml. (#32956) --- sdks/python/apache_beam/yaml/json_utils.py | 61 +++++++++++++++++++ sdks/python/apache_beam/yaml/yaml_mapping.py | 39 +++++++++++- .../apache_beam/yaml/yaml_mapping_test.py | 37 +++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/yaml/json_utils.py b/sdks/python/apache_beam/yaml/json_utils.py index 40e515ee6946..76cc80bc2036 100644 --- a/sdks/python/apache_beam/yaml/json_utils.py +++ b/sdks/python/apache_beam/yaml/json_utils.py @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType: raise ValueError(f'Unable to convert {json_type} to a Beam schema.') +def beam_schema_to_json_schema( + beam_schema: schema_pb2.Schema) -> Dict[str, Any]: + return { + 'type': 'object', + 'properties': { + field.name: beam_type_to_json_type(field.type) + for field in beam_schema.fields + }, + 'additionalProperties': False + } + + def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]: type_info = beam_type.WhichOneof("type_info") if type_info == "atomic_type": @@ -267,3 +279,52 @@ def json_formater( convert = row_to_json( schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8') + + +def _validate_compatible(weak_schema, strong_schema): + if not weak_schema: + return + if weak_schema['type'] != strong_schema['type']: + raise ValueError( + 'Incompatible types: %r vs %r' % + (weak_schema['type'] != strong_schema['type'])) + if weak_schema['type'] == 'array': + _validate_compatible(weak_schema['items'], strong_schema['items']) + elif weak_schema == 'object': + for required in strong_schema.get('required', []): + if required not in weak_schema['properties']: + raise ValueError('Missing or 
unkown property %r' % required) + for name, spec in weak_schema.get('properties', {}): + if name in strong_schema['properties']: + try: + _validate_compatible(spec, strong_schema['properties'][name]) + except Exception as exn: + raise ValueError('Incompatible schema for %r' % name) from exn + elif not strong_schema.get('additionalProperties'): + raise ValueError( + 'Prohibited property: {property}; ' + 'perhaps additionalProperties: False is missing?') + + +def row_validator(beam_schema: schema_pb2.Schema, + json_schema: Dict[str, Any]) -> Callable[[Any], Any]: + """Returns a callable that will fail on elements not respecting json_schema. + """ + if not json_schema: + return lambda x: None + + # Validate that this compiles, but avoid pickling the validator itself. + _ = jsonschema.validators.validator_for(json_schema)(json_schema) + _validate_compatible(beam_schema_to_json_schema(beam_schema), json_schema) + validator = None + + convert = row_to_json( + schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema))) + + def validate(row): + nonlocal validator + if validator is None: + validator = jsonschema.validators.validator_for(json_schema)(json_schema) + validator.validate(convert(row)) + + return validate diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py b/sdks/python/apache_beam/yaml/yaml_mapping.py index 377bcac0e31a..960fcdeecf30 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping.py @@ -43,6 +43,7 @@ from apache_beam.typehints.native_type_compatibility import convert_to_beam_type from apache_beam.typehints.row_type import RowTypeConstraint from apache_beam.typehints.schemas import named_fields_from_element_type +from apache_beam.typehints.schemas import schema_from_element_type from apache_beam.utils import python_callable from apache_beam.yaml import json_utils from apache_beam.yaml import options @@ -435,7 +436,8 @@ def _map_errors_to_standard_format(input_type): # 
class _Validate(beam.PTransform):
  """Checks every element of a PCollection against a json schema.

  Elements that pass validation are emitted unchanged.

  Args:
    schema: The json schema each element is validated against.
    error_handling: Whether and how to handle elements that fail validation.
      When unset, an invalid element fails the pipeline; otherwise invalid
      elements are routed to the specified error output along with
      information about how the schema was violated.
  """
  def __init__(
      self,
      schema: Dict[str, Any],
      error_handling: Optional[Mapping[str, Any]] = None):
    self._exception_handling_args = exception_handling_args(error_handling)
    self._schema = schema

  @maybe_with_exception_handling
  def expand(self, pcoll):
    # Build the validator once, at pipeline construction time, from the
    # schema of the input elements and the user-provided json schema.
    element_schema = schema_from_element_type(pcoll.element_type)
    check_row = json_utils.row_validator(element_schema, self._schema)

    def invoke_validator(x):
      # Only the validator raising matters; the element itself is passed
      # through untouched.
      check_row(x)
      return x

    return pcoll | beam.Map(invoke_validator)

  def with_exception_handling(self, **kwargs):
    # Validation failures surface as exceptions; record how to handle them.
    self._exception_handling_args = kwargs
    return self
def test_validate(self):
  """Checks that ValidateWithSchema splits valid and invalid rows.

  Three rows are validated against a json schema with error handling
  enabled; the valid row must appear on the 'good' output and the two
  invalid rows on the 'bad' error output.
  """
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    elements = p | beam.Create([
        # Satisfies both constraints of the schema below.
        beam.Row(key='good', small=[5], nested=beam.Row(big=100)),
        # 500 exceeds the maximum of 10 for items of `small`.
        beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)),
        # 1 is below the minimum of 10 for `nested.big`.
        beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)),
    ])
    # NOTE(review): the indentation of this YAML literal was reconstructed
    # from a whitespace-collapsed patch -- confirm against the original.
    # The schema does not mention `key`, so this also relies on properties
    # that are not listed in the schema being accepted.
    result = elements | YamlTransform(
        '''
        type: ValidateWithSchema
        config:
          schema:
            type: object
            properties:
              small:
                type: array
                items:
                  type: integer
                  maximum: 10
              nested:
                type: object
                properties:
                  big:
                    type: integer
                    minimum: 10
          error_handling:
            output: bad
        ''')

    # Valid rows come through as-is, so `key` is read directly.
    assert_that(
        result['good'] | beam.Map(lambda x: x.key), equal_to(['good']))
    # Error records carry the offending row in their `element` field; the
    # explicit label keeps this assertion distinct from the one above.
    assert_that(
        result['bad'] | beam.Map(lambda x: x.element.key),
        equal_to(['bad1', 'bad2']),
        label='Errors')