diff --git a/python/src/main/python/yaml-template/main.py b/python/src/main/python/yaml-template/main.py index 4898b67ff6..3c2d434517 100644 --- a/python/src/main/python/yaml-template/main.py +++ b/python/src/main/python/yaml-template/main.py @@ -31,6 +31,44 @@ LogicalType.register_logical_type(MillisInstant) +def _preparse_jinja_flags(argv): + """Promotes any flags to --jinja_variables based on --jinja_variable_flags. + This is to facilitate tools (such as dataflow templates) that must pass + options as un-nested flags. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + '--jinja_variable_flags', + default=[], + type=lambda s: s.split(','), + help='A list of flag names that should be used as jinja variables.') + parser.add_argument( + '--jinja_variables', + default={}, + type=json.loads, + help='A json dict of variables used when invoking the jinja preprocessor ' + 'on the provided yaml pipeline.') + jinja_args, other_args = parser.parse_known_args(argv) + if not jinja_args.jinja_variable_flags: + return argv + + jinja_variable_parser = argparse.ArgumentParser() + for flag_name in jinja_args.jinja_variable_flags: + jinja_variable_parser.add_argument('--' + flag_name) + jinja_flag_variables, pipeline_args = jinja_variable_parser.parse_known_args( + other_args) + jinja_args.jinja_variables.update( + ** + {k: v + for (k, v) in vars(jinja_flag_variables).items() if v is not None}) + if jinja_args.jinja_variables: + pipeline_args = pipeline_args + [ + '--jinja_variables=' + json.dumps(jinja_args.jinja_variables) + ] + + return pipeline_args + + def _configure_parser(argv): parser = argparse.ArgumentParser() parser.add_argument( @@ -79,6 +117,7 @@ def get_source(self, environment, path): def run(argv=None): + argv = _preparse_jinja_flags(argv) known_args, pipeline_args = _configure_parser(argv) pipeline_yaml = ( # keep formatting jinja2.Environment(