diff --git a/xl2times/__main__.py b/xl2times/__main__.py
index 1153f66..b342bf7 100644
--- a/xl2times/__main__.py
+++ b/xl2times/__main__.py
@@ -95,6 +95,7 @@ def convert_xl_to_times(
         transforms.convert_aliases,
         transforms.rename_cgs,
         transforms.fix_topology,
+        transforms.final_cleanup,
         transforms.convert_to_string,
         lambda config, tables: dump_tables(
             tables, os.path.join(output_dir, "merged_tables.txt")
@@ -382,6 +383,12 @@ def main():
         nargs="*",
         help="Either an input directory, or a list of input xlsx files to process",
     )
+    args_parser.add_argument(
+        "--regions",
+        type=str,
+        default="",
+        help="Comma-separated list of regions to include in the model",
+    )
     args_parser.add_argument(
         "--output_dir", type=str, default="output", help="Output directory"
     )
@@ -410,6 +417,7 @@ def main():
         "times-info.json",
         "veda-tags.json",
         "veda-attr-defaults.json",
+        args.regions,
     )
 
     if not isinstance(args.input, list) or len(args.input) < 1:
diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py
index db2238b..483cde3 100644
--- a/xl2times/datatypes.py
+++ b/xl2times/datatypes.py
@@ -156,6 +156,8 @@ class Config:
     veda_attr_defaults: Dict[str, Dict[str, list]]
     # Known columns for each tag
     known_columns: Dict[Tag, Set[str]]
+    # Names of regions to include in the model; if empty, all regions are included.
+    filter_regions: Set[str]
 
     def __init__(
         self,
@@ -163,6 +165,7 @@ def __init__(
         times_info_file: str,
         veda_tags_file: str,
         veda_attr_defaults_file: str,
+        regions: str,
     ):
         self.times_xl_maps = Config._read_mappings(mapping_file)
         (
@@ -184,6 +187,7 @@ def __init__(
         for m in param_mappings:
             name_to_map[m.times_name] = m
         self.times_xl_maps = list(name_to_map.values())
+        self.filter_regions = Config._read_regions_filter(regions)
 
     @staticmethod
     def _process_times_info(
@@ -416,3 +420,14 @@ def _read_veda_attr_defaults(
                 veda_attr_defaults["tslvl"][tslvl].append(attr)
 
         return veda_attr_defaults, attr_aliases
+
+    @staticmethod
+    def _read_regions_filter(regions_list: str) -> Set[str]:
+        """Parse the comma-separated ``--regions`` value into a set of names.
+
+        Every entry is stripped of surrounding whitespace and upper-cased;
+        empty entries (e.g. from a trailing comma) are dropped. An empty
+        result means that no region filter is applied.
+        """
+        names = (name.strip().upper() for name in regions_list.split(","))
+        return {name for name in names if name}
diff --git a/xl2times/transforms.py b/xl2times/transforms.py
index 90fec31..7b56974 100644
--- a/xl2times/transforms.py
+++ b/xl2times/transforms.py
@@ -2160,6 +2160,29 @@ def apply_more_fixups(
     return tables
 
 
+def final_cleanup(
+    config: datatypes.Config, tables: Dict[str, DataFrame]
+) -> Dict[str, DataFrame]:
+    """Apply final clean-up, e.g. discard data that is not relevant.
+
+    When ``config.filter_regions`` is non-empty, every table that has a
+    "region" column keeps only the rows whose region is in that set.
+    """
+
+    # Apply regions filter
+    # TODO: Apply regions filtering earlier (incl. populating default regions)
+    # TODO: Warn if invalid filter entries?
+    # TODO: Do not filter if no valid filter entries?
+    if config.filter_regions:
+        for k, v in tables.items():
+            if "region" in v.columns:
+                # Rebinding an existing key does not resize the dict, so it
+                # is safe while iterating over tables.items().
+                tables[k] = v[v["region"].isin(config.filter_regions)]
+
+    return tables
+
+
 def expand_rows_parallel(
     config: datatypes.Config,
     tables: List[datatypes.EmbeddedXlTable],