diff --git a/metaflow/runner/click_api.py b/metaflow/runner/click_api.py index b89ddae6831..b6c8a641cb2 100644 --- a/metaflow/runner/click_api.py +++ b/metaflow/runner/click_api.py @@ -188,28 +188,46 @@ def get_inspect_param_obj(p: Union[click.Argument, click.Option], kind: str): def extract_flow_class_from_file(flow_file: str) -> FlowSpec: if not os.path.exists(flow_file): raise FileNotFoundError("Flow file not present at '%s'" % flow_file) - # Check if the module has already been loaded - if flow_file in loaded_modules: - module = loaded_modules[flow_file] - else: - # Load the module if it's not already loaded - spec = importlib.util.spec_from_file_location("module", flow_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - # Cache the loaded module - loaded_modules[flow_file] = module - classes = inspect.getmembers(module, inspect.isclass) - - flow_cls = None - for _, kls in classes: - if kls != FlowSpec and issubclass(kls, FlowSpec): - if flow_cls is not None: - raise MetaflowException( - "Multiple FlowSpec classes found in %s" % flow_file - ) - flow_cls = kls - - return flow_cls + + flow_dir = os.path.dirname(os.path.abspath(flow_file)) + path_was_added = False + + # Only add to path if it's not already there + if flow_dir not in sys.path: + sys.path.insert(0, flow_dir) + path_was_added = True + + try: + # Check if the module has already been loaded + if flow_file in loaded_modules: + module = loaded_modules[flow_file] + else: + # Load the module if it's not already loaded + spec = importlib.util.spec_from_file_location("module", flow_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + # Cache the loaded module + loaded_modules[flow_file] = module + classes = inspect.getmembers(module, inspect.isclass) + + flow_cls = None + for _, kls in classes: + if kls != FlowSpec and issubclass(kls, FlowSpec): + if flow_cls is not None: + raise MetaflowException( + "Multiple FlowSpec classes found in %s" % flow_file + ) + flow_cls = kls + + return flow_cls + finally: + # Only remove from path if we added it + if path_was_added: + try: + sys.path.remove(flow_dir) + except ValueError: + # User's code might have removed it already + pass class MetaflowAPI(object):