Skip to content

Commit

Permalink
Add a simple validation transform to yaml. (#32956)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertwb authored Nov 2, 2024
1 parent 90c1ee9 commit eed82f0
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
61 changes: 61 additions & 0 deletions sdks/python/apache_beam/yaml/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType:
raise ValueError(f'Unable to convert {json_type} to a Beam schema.')


def beam_schema_to_json_schema(
    beam_schema: schema_pb2.Schema) -> Dict[str, Any]:
  """Describes a Beam schema as a closed JSON ``object`` schema.

  Each field of ``beam_schema`` becomes an entry under ``properties``, keyed
  by the field name and holding the JSON type produced by
  ``beam_type_to_json_type`` for the field's type. No properties beyond the
  declared fields are permitted (``additionalProperties`` is ``False``).

  Args:
    beam_schema: The Beam schema whose fields should be described.

  Returns:
    A dict with the keys ``type``, ``properties`` and
    ``additionalProperties``.
  """
  properties = {}
  for field in beam_schema.fields:
    properties[field.name] = beam_type_to_json_type(field.type)
  return {
      'type': 'object',
      'properties': properties,
      'additionalProperties': False,
  }


def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
Expand Down Expand Up @@ -267,3 +279,52 @@ def json_formater(
convert = row_to_json(
schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=beam_schema)))
return lambda row: json.dumps(convert(row), sort_keys=True).encode('utf-8')


def _validate_compatible(weak_schema, strong_schema):
if not weak_schema:
return
if weak_schema['type'] != strong_schema['type']:
raise ValueError(
'Incompatible types: %r vs %r' %
(weak_schema['type'] != strong_schema['type']))
if weak_schema['type'] == 'array':
_validate_compatible(weak_schema['items'], strong_schema['items'])
elif weak_schema == 'object':
for required in strong_schema.get('required', []):
if required not in weak_schema['properties']:
raise ValueError('Missing or unkown property %r' % required)
for name, spec in weak_schema.get('properties', {}):
if name in strong_schema['properties']:
try:
_validate_compatible(spec, strong_schema['properties'][name])
except Exception as exn:
raise ValueError('Incompatible schema for %r' % name) from exn
elif not strong_schema.get('additionalProperties'):
raise ValueError(
'Prohibited property: {property}; '
'perhaps additionalProperties: False is missing?')


def row_validator(beam_schema: schema_pb2.Schema,
                  json_schema: Dict[str, Any]) -> Callable[[Any], Any]:
  """Builds a checker that raises for rows which do not satisfy json_schema.

  A falsy ``json_schema`` yields a checker that accepts everything. Otherwise
  the schema is checked for structural compatibility with ``beam_schema`` up
  front, and the returned callable converts each row to its JSON form and
  validates it against ``json_schema``.

  Args:
    beam_schema: The Beam schema of the rows that will be checked.
    json_schema: The JSON schema each row must respect.

  Returns:
    A callable taking a single row. It returns None and raises whatever the
    jsonschema validator raises for a row that does not validate.
  """
  if not json_schema:
    return lambda x: None

  # Instantiate a validator eagerly so that problems surface at construction
  # time; the instance is discarded so it is not captured (and pickled) by
  # the closure below.
  jsonschema.validators.validator_for(json_schema)(json_schema)
  structural_schema = beam_schema_to_json_schema(beam_schema)
  _validate_compatible(structural_schema, json_schema)

  # Created on first use so the closure holds only None when it is pickled.
  cached_validator = None

  row_type = schema_pb2.RowType(schema=beam_schema)
  to_json = row_to_json(schema_pb2.FieldType(row_type=row_type))

  def validate(row):
    nonlocal cached_validator
    if cached_validator is None:
      validator_cls = jsonschema.validators.validator_for(json_schema)
      cached_validator = validator_cls(json_schema)
    cached_validator.validate(to_json(row))

  return validate
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.typehints.schemas import named_fields_from_element_type
from apache_beam.typehints.schemas import schema_from_element_type
from apache_beam.utils import python_callable
from apache_beam.yaml import json_utils
from apache_beam.yaml import options
Expand Down Expand Up @@ -435,7 +436,8 @@ def _map_errors_to_standard_format(input_type):
# TODO(https://github.com/apache/beam/issues/24755): Switch to MapTuple.

return beam.Map(
lambda x: beam.Row(element=x[0], msg=str(x[1][1]), stack=str(x[1][2]))
lambda x: beam.Row(
element=x[0], msg=str(x[1][1]), stack=''.join(x[1][2]))
).with_output_types(
RowTypeConstraint.from_fields([("element", input_type), ("msg", str),
("stack", str)]))
Expand Down Expand Up @@ -475,6 +477,40 @@ def expand(pcoll, error_handling=None, **kwargs):
return expand


class _Validate(beam.PTransform):
  """Validates each element of a PCollection against a json schema.

  Args:
    schema: A json schema against which to validate each element.
    error_handling: Whether and how to handle errors during validation.
      If this is not set, invalid elements will fail the pipeline, otherwise
      invalid elements will be passed to the specified error output along
      with information about how the schema was violated.
  """
  def __init__(
      self,
      schema: Dict[str, Any],
      error_handling: Optional[Mapping[str, Any]] = None):
    self._schema = schema
    # Stored in the form produced by exception_handling_args; presumably read
    # by the maybe_with_exception_handling decorator on expand -- confirm.
    self._exception_handling_args = exception_handling_args(error_handling)

  @maybe_with_exception_handling
  def expand(self, pcoll):
    # The validator is built once, at pipeline construction time, from the
    # schema of the input element type and the user-supplied json schema.
    validator = json_utils.row_validator(
        schema_from_element_type(pcoll.element_type), self._schema)

    def invoke_validator(x):
      # The validator's return value is ignored; an element that does not
      # make it raise is passed through unchanged.
      validator(x)
      return x

    return pcoll | beam.Map(invoke_validator)

  def with_exception_handling(self, **kwargs):
    # Replaces any error handling arguments set at construction with the
    # given keyword arguments, returning self so calls can be chained.
    self._exception_handling_args = kwargs
    return self


class _Explode(beam.PTransform):
"""Explodes (aka unnest/flatten) one or more fields producing multiple rows.
Expand Down Expand Up @@ -797,6 +833,7 @@ def create_mapping_providers():
'Partition-python': _Partition,
'Partition-javascript': _Partition,
'Partition-generic': _Partition,
'ValidateWithSchema': _Validate,
}),
yaml_provider.SqlBackedProvider({
'Filter-sql': _SqlFilterTransform,
Expand Down
37 changes: 37 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_mapping_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ def test_explode(self):
beam.Row(a=3, b='y', c=.125, range=2),
]))

def test_validate(self):
  """Checks that ValidateWithSchema routes invalid rows to the error output.

  Three rows are created: 'good' satisfies the schema, 'bad1' has a `small`
  entry (500) above the declared maximum of 10, and 'bad2' has `nested.big`
  (1) below the declared minimum of 10. With error_handling output 'bad',
  only 'good' is expected on the 'good' output and both bad rows on 'bad'.
  """
  # NOTE(review): cloudpickle is selected explicitly here; presumably needed
  # to pickle the validation closure -- confirm.
  with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
      pickle_library='cloudpickle')) as p:
    elements = p | beam.Create([
        beam.Row(key='good', small=[5], nested=beam.Row(big=100)),
        beam.Row(key='bad1', small=[500], nested=beam.Row(big=100)),
        beam.Row(key='bad2', small=[5], nested=beam.Row(big=1)),
    ])
    # NOTE(review): the YAML below is reproduced exactly as it appears in
    # this view, where leading whitespace looks to have been stripped; the
    # nesting indentation must match the original file -- confirm.
    result = elements | YamlTransform(
        '''
type: ValidateWithSchema
config:
schema:
type: object
properties:
small:
type: array
items:
type: integer
maximum: 10
nested:
type: object
properties:
big:
type: integer
minimum: 10
error_handling:
output: bad
''')

    # Rows that validate come out unchanged on the 'good' output.
    assert_that(
        result['good'] | beam.Map(lambda x: x.key), equal_to(['good']))
    # Failed rows are wrapped; the original row is under `.element`.
    assert_that(
        result['bad'] | beam.Map(lambda x: x.element.key),
        equal_to(['bad1', 'bad2']),
        label='Errors')

def test_validate_explicit_types(self):
with self.assertRaisesRegex(TypeError, r'.*violates schema.*'):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
Expand Down

0 comments on commit eed82f0

Please sign in to comment.