diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py
index 13f2d31..d26a9f2 100644
--- a/xl2times/datatypes.py
+++ b/xl2times/datatypes.py
@@ -207,6 +207,7 @@ def __init__(
             self.column_aliases,
             self.row_comment_chars,
             self.discard_if_empty,
+            self.query_columns,
             self.known_columns,
         ) = Config._read_veda_tags_info(veda_tags_file)
         self.veda_attr_defaults, self.attr_aliases = Config._read_veda_attr_defaults(
@@ -354,6 +355,7 @@ def _read_veda_tags_info(
         Dict[Tag, Dict[str, list]],
         Iterable[Tag],
         Dict[Tag, Set[str]],
+        Dict[Tag, Set[str]],
     ]:
         def to_tag(s: str) -> Tag:
             # The file stores the tag name in lowercase, and without the ~
@@ -374,6 +376,7 @@ def to_tag(s: str) -> Tag:
         valid_column_names = {}
         row_comment_chars = {}
         discard_if_empty = []
+        query_cols = defaultdict(set)
         known_cols = defaultdict(set)
 
         for tag_info in veda_tags_info:
@@ -395,6 +398,8 @@ def to_tag(s: str) -> Tag:
                     else:
                         field_name = valid_field["name"]
 
+                    if valid_field["query_field"]:
+                        query_cols[tag_name].add(field_name)
                     known_cols[tag_name].add(field_name)
 
                     for valid_field_name in valid_field_names:
@@ -411,10 +416,18 @@ def to_tag(s: str) -> Tag:
                     discard_if_empty.append(tag_name)
                 if base_tag in row_comment_chars:
                     row_comment_chars[tag_name] = row_comment_chars[base_tag]
+                if base_tag in query_cols:
+                    query_cols[tag_name] = query_cols[base_tag]
                 if base_tag in known_cols:
                     known_cols[tag_name] = known_cols[base_tag]
 
-        return valid_column_names, row_comment_chars, discard_if_empty, known_cols
+        return (
+            valid_column_names,
+            row_comment_chars,
+            discard_if_empty,
+            query_cols,
+            known_cols,
+        )
 
     @staticmethod
     def _read_veda_attr_defaults(
diff --git a/xl2times/transforms.py b/xl2times/transforms.py
index 901b826..837737b 100644
--- a/xl2times/transforms.py
+++ b/xl2times/transforms.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 import pandas as pd
 from dataclasses import replace
-from typing import Dict, List
+from typing import Dict, List, Set
 from more_itertools import locate, one
 from itertools import groupby
 import re
@@ -17,17 +17,6 @@
 from . import datatypes
 from . import utils
 
-query_columns = {
-    "pset_set",
-    "pset_pn",
-    "pset_pd",
-    "pset_ci",
-    "pset_co",
-    "cset_set",
-    "cset_cn",
-    "cset_cd",
-}
-
 csets_ordered_for_pcg = ["DEM", "MAT", "NRG", "ENV", "FIN"]
 default_pcg_suffixes = [
     cset + io for cset in csets_ordered_for_pcg for io in ["I", "O"]
@@ -740,13 +729,16 @@ def fill_in_missing_values_table(table):
     return result
 
 
-def expand_rows(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedXlTable:
+def expand_rows(
+    query_columns: Set[str], table: datatypes.EmbeddedXlTable
+) -> datatypes.EmbeddedXlTable:
     """
     Expand entries with commas into separate entries in the same column. Do this
     for all tables except transformation update tables.
 
-    :param table: Table in EmbeddedXlTable format.
-    :return: Table in EmbeddedXlTable format with expanded comma entries.
+    :param query_columns: Set of query column names.
+    :param table: Table in EmbeddedXlTable format.
+    :return: Table in EmbeddedXlTable format with expanded comma entries.
     """
 
     def has_comma(s):
@@ -1720,6 +1712,7 @@ def is_year(col_name):
     if table.tag == datatypes.Tag.tfm_ins_ts:
         # ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column:
         df = table.dataframe
+        query_columns = config.query_columns[datatypes.Tag(table.tag)]
         if "year" in df.columns:
             raise ValueError(f"TFM_INS-TS table already has Year column: {table}")
         # TODO: can we remove this hacky shortcut? Or should it be also applied to the AT variant?
@@ -1807,7 +1800,9 @@ def process_transform_tables(
         df = table.dataframe.copy()
 
         # Standardize column names
-        known_columns = config.known_columns[table.tag] | query_columns
+        known_columns = (
+            config.known_columns[table.tag] | config.query_columns[table.tag]
+        )
 
         # Handle Regions:
         if set(df.columns).isdisjoint(
@@ -2032,6 +2027,7 @@ def make_str(df):
     if tag in tables:
         start_time = time.time()
         df = tables[tag]
+        query_columns = config.query_columns[tag]
         dictionary = generate_topology_dictionary(tables, model)
 
         df["process"] = df.apply(
@@ -2044,6 +2040,7 @@ def make_str(df):
         cols_to_drop = [col for col in df.columns if col in query_columns]
 
         df = expand_rows(
+            query_columns,
             datatypes.EmbeddedXlTable(
                 tag="",
                 uc_sets={},
@@ -2051,7 +2048,7 @@ def make_str(df):
                 range="",
                 filename="",
                 dataframe=df.drop(columns=cols_to_drop),
-            )
+            ),
         ).dataframe
 
         tables[tag] = df
@@ -2520,5 +2517,11 @@ def expand_rows_parallel(
     tables: List[datatypes.EmbeddedXlTable],
     model: datatypes.TimesModel,
 ) -> List[datatypes.EmbeddedXlTable]:
+    query_columns_lists = [
+        config.query_columns[datatypes.Tag(table.tag)]
+        if datatypes.Tag.has_tag(table.tag)
+        else set()
+        for table in tables
+    ]
     with ProcessPoolExecutor(max_workers) as executor:
-        return list(executor.map(expand_rows, tables))
+        return list(executor.map(expand_rows, query_columns_lists, tables))