diff --git a/xl2times/__main__.py b/xl2times/__main__.py index b8e5976..c24675c 100644 --- a/xl2times/__main__.py +++ b/xl2times/__main__.py @@ -454,6 +454,8 @@ def run(args: argparse.Namespace) -> str | None: for path in Path(args.input[0]).rglob("*") if path.suffix in [".xlsx", ".xlsm"] and not path.name.startswith("~") ] + if utils.is_veda_based(input_files): + input_files = utils.filter_veda_filename_patterns(input_files) logger.info(f"Loading {len(input_files)} files from {args.input[0]}") else: input_files = args.input diff --git a/xl2times/utils.py b/xl2times/utils.py index cce77b9..4da5269 100644 --- a/xl2times/utils.py +++ b/xl2times/utils.py @@ -12,7 +12,7 @@ from collections.abc import Iterable from dataclasses import replace from math import floor, log10 -from pathlib import Path +from pathlib import Path, PurePath import loguru import numpy @@ -315,6 +315,57 @@ def round_sig(x, sig_figs): default_log_name = "log" if default_log_name == "" else default_log_name +def _case_insensitive_match(path: str, pattern: str) -> bool: + """Do case-insensitive path match. Convert to lowercase first, because + case_sensitive parameter in match is not available before Python 3.12. + """ + return PurePath(path.lower()).match(pattern.lower()) + + +def is_veda_based(files: list[str]) -> bool: + """Determine whether the model follows Veda file structure. + This function does not verify file extensions. + """ + marker = "SysSettings.*" + + matches = [file for file in files if _case_insensitive_match(file, marker)] + + if len(matches) == 1: + return True + elif len(matches) > 1: + raise ValueError(f"Only one {marker} expected. Multiple detected: {matches}") + else: + return False + + +def filter_veda_filename_patterns(files: list[str]) -> list[str]: + """Filter files by patterns recognised by Veda. + This function does not verify file extensions. + """ + legal_paths = ( + "BY_Trans.*", + "LMA*.*", + "Set*.*", + "SysSettings.*", + "VT_*.*", + "SubRES_TMPL/SubRES_*.*", + "SuppXLS/Demands/Dem_Alloc+Series.*", + "SuppXLS/Demands/ScenDem_*.*", + "SuppXLS/ParScenFiles/Scen_Par-*.*", + "SuppXLS/Scen_*.*", + "SuppXLS/Trades/ScenTrade_*.*", + ) + # Generate a set of fiels that match the patterns + filtered = { + file + for file in files + for legal_path in legal_paths + if _case_insensitive_match(file, legal_path) + } + # Return as a list + return list(filtered) + + def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.Logger: """Return a configured loguru logger.