diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py index 82c652ff..c76f7e73 100644 --- a/xl2times/datatypes.py +++ b/xl2times/datatypes.py @@ -7,6 +7,7 @@ from functools import cached_property from importlib import resources from itertools import chain +from pathlib import PurePath from loguru import logger from pandas.core.frame import DataFrame @@ -71,6 +72,63 @@ def has_tag(cls, tag): return tag in cls._value2member_map_ +class DataModule(str, Enum): + """Categorise data into modules based on the file they are coming from.""" + + base = "VT_*.*, BY_Trans.*" + syssettings = "SysSettings.*" + subres = "SubRES_TMPL/SubRES_*.*" + sets = "Set*.*" + lma = "LMA*.*" + demand = "SuppXLS/Demands/Dem_Alloc+Series.*, SuppXLS/Demands/ScenDem_*.*" + scen = "SuppXLS/Scen_*.*" + trade = "SuppXLS/Trades/ScenTrade_*.*" + + @classmethod + def determine_type(cls, path: str) -> "DataModule | None": + for data_module in cls: + if any( + PurePath(path.lower()).match(pattern.lower().strip()) + for pattern in data_module.value.split(",") + ): + return data_module + return None + + @classmethod + def module_type(cls, path: str) -> str | None: + module_type = cls.determine_type(path) + if module_type: + return module_type.name + else: + return None + + @classmethod + def submodule(cls, path: str) -> str | None: + match cls.determine_type(path): + case DataModule.base | DataModule.subres: + if PurePath(path.lower()).match("*_trans.*"): + return "trans" + else: + return "main" + case None: + return None + case _: + return "main" + + @classmethod + def module_name(cls, path: str) -> str | None: + module_type = cls.determine_type(path) + match module_type: + case DataModule.base | DataModule.sets | DataModule.lma | DataModule.demand | DataModule.trade | DataModule.syssettings: + return module_type.name + case DataModule.subres: + return re.sub("_trans$", "", PurePath(path).stem.lower()) + case None: + return None + case _: + return PurePath(path).stem + + @dataclass class 
def include_table_source(table: EmbeddedXlTable):
    """Return a copy of *table* whose dataframe carries provenance columns.

    The dataframe is copied, then four columns are set from the table's
    filename: ``source_filename`` (the file stem) and the module type,
    submodule and module name reported by ``DataModule``. The input table
    is not modified.
    """
    origin = table.filename
    annotated = table.dataframe.copy()
    # Insertion order is preserved, so the columns are added in this order.
    provenance = {
        "source_filename": Path(origin).stem,
        "data_module_type": DataModule.module_type(origin),
        "data_submodule": DataModule.submodule(origin),
        "data_module_name": DataModule.module_name(origin),
    }
    for column, value in provenance.items():
        annotated[column] = value
    return replace(table, dataframe=annotated)