Add RunInference.with_errors() API #14
base: master

Conversation
tfx_bsl/beam/run_inference.py (outdated)

```python
      return beam.pvalue.TaggedOutput(_get_operation_type(batch[0]), batch)
    else:
      try:
        return beam.pvalue.TaggedOutput(_get_operation_type(batch[0]), batch)
```
Will this throw an error? It seems `_get_operation_type` will return str or unicode.
`_get_operation_type` has return value `Text` (str/unicode), and `beam.pvalue.TaggedOutput` accepts either str or unicode for a tag. See: https://github.com/apache/beam/blob/master/sdks/python/apache_beam/pvalue.py#L328-L343
Benchmarks showed that `TagByOperation` was a performance bottleneck, as it requires disk access per query batch: on a dataset with 1M examples, it took ~25% of the total wall time. To mitigate this I implemented operation caching inside the DoFn, which reduced this to ~2%. For readability, I also renamed the operation to "SplitByOperation", as that more accurately describes its purpose.
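For illustration, a minimal sketch of the caching idea (the injected `get_operation_type` and `spec_key` callables are illustrative stand-ins, not the actual tfx_bsl internals):

```python
import apache_beam as beam


class _SplitByOperationDoFn(beam.DoFn):
  """Sketch: tags each batch with its operation type, caching lookups."""

  def __init__(self, get_operation_type, spec_key):
    # Illustrative stand-ins: get_operation_type is the disk-touching
    # signature lookup; spec_key derives a hashable cache key from a query.
    self._get_operation_type = get_operation_type
    self._spec_key = spec_key
    self._cache = {}

  def process(self, batch):
    key = self._spec_key(batch[0])
    if key not in self._cache:
      # Expensive path: now runs once per model spec instead of per batch.
      self._cache[key] = self._get_operation_type(batch[0])
    yield beam.pvalue.TaggedOutput(self._cache[key], batch)
```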
`with_errors()` API

This PR introduces the `RunInference(...).with_errors()` API, which allows users to catch runtime errors as a separate PCollection stream.

By default, runtime errors (for example, invalid model specs or invalid examples) are raised, which can crash a pipeline. If this is not desirable, users can call `.with_errors()` to catch runtime errors instead. The error output stream has the type `Tuple[Exception, Any]` and contains both the original error and whatever object is relevant to the error.

Note: when runtime errors are allowed to be raised, they are raised from their original location (e.g. inside a nested PTransform), which makes debugging easier.
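For illustration, a minimal usage sketch (the import path, `examples`, and `inference_spec_type` are assumptions; the result-dict access mirrors the `{'predictions': ..., 'errors': ...}` shape described under "Internal details" below):

```python
import apache_beam as beam
from tfx_bsl.beam import run_inference  # import path is an assumption

with beam.Pipeline() as p:
  results = (
      p
      | beam.Create(examples)  # assumed: examples to run inference on
      | run_inference.RunInference(inference_spec_type).with_errors())

  predictions = results['predictions']
  # Each error element is a Tuple[Exception, Any]: the original exception
  # plus whatever object is relevant to it.
  errors = results['errors']
```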
Internal details

- `RunInferenceImpl` is now a class (rather than a function with `@beam.ptransform_fn`). This enables us to add a `with_errors` method that can take effect in the `expand` method (see the first sketch after this list).
- `RunInferenceCore` returns a dict containing `{'predictions': ..., 'errors': ...}` and takes an additional `catch_errors: bool = False` parameter, which indicates whether to catch or raise runtime errors.
- Added the `_ParDoExceptionWrapper` utility, which runs `beam.ParDo` on a provided `beam.DoFn` and optionally catches exceptions raised in the `process()` method (see the second sketch after this list).
- Operation wrapper transforms (e.g. `_Classify`, `_Regress`, ...) accept an additional `catch_errors: bool = False` parameter and return a dict containing `{'predictions': ..., 'errors': ...}`.
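A minimal sketch of the class-based pattern (the constructor shape is an assumption; `make_core_transform` is an injected stand-in for `RunInferenceCore`):

```python
import apache_beam as beam


class RunInferenceImpl(beam.PTransform):
  """Sketch only: builder state that a @beam.ptransform_fn cannot carry."""

  def __init__(self, make_core_transform):
    # make_core_transform(catch_errors=...) stands in for RunInferenceCore,
    # which per this PR returns {'predictions': ..., 'errors': ...}.
    self._make_core_transform = make_core_transform
    self._catch_errors = False

  def with_errors(self):
    # Flip the flag now; it takes effect when Beam later calls expand().
    self._catch_errors = True
    return self

  def expand(self, examples):
    results = examples | self._make_core_transform(
        catch_errors=self._catch_errors)
    # Without with_errors(), only predictions are returned and runtime
    # errors raise from their original location.
    return results if self._catch_errors else results['predictions']
```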
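And a minimal sketch of the exception-wrapping utility (the tag names and exact output shape are assumptions based on the description above; DoFn lifecycle delegation is omitted for brevity):

```python
import apache_beam as beam


class _CatchingDoFn(beam.DoFn):
  """Delegates to a wrapped DoFn, rerouting exceptions from process()."""

  def __init__(self, do_fn):
    self._do_fn = do_fn

  def process(self, element):
    try:
      for output in self._do_fn.process(element) or ():
        yield output
    except Exception as e:  # pylint: disable=broad-except
      # Pair the original exception with the offending element, matching
      # the Tuple[Exception, Any] error-stream type described above.
      yield beam.pvalue.TaggedOutput('errors', (e, element))


class _ParDoExceptionWrapper(beam.PTransform):
  """Sketch: runs a DoFn and optionally catches process() exceptions."""

  def __init__(self, do_fn, catch_errors=False):
    self._do_fn = do_fn
    self._catch_errors = catch_errors

  def expand(self, pcoll):
    if self._catch_errors:
      results = (
          pcoll
          | beam.ParDo(_CatchingDoFn(self._do_fn)).with_outputs(
              'errors', main='predictions'))
      return {'predictions': results.predictions, 'errors': results.errors}
    # catch_errors=False: exceptions propagate from their original
    # location; the errors stream is present but empty.
    return {
        'predictions': pcoll | beam.ParDo(self._do_fn),
        'errors': pcoll | 'EmptyErrors' >> beam.FlatMap(lambda _: ()),
    }
```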
Dependencies
This PR depends on #13