Skip to content

Commit

Permalink
Introduce config.query_columns
Browse files Browse the repository at this point in the history
  • Loading branch information
olejandro committed Feb 16, 2024
1 parent 4ab1db7 commit 58a06d3
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 19 deletions.
15 changes: 14 additions & 1 deletion xl2times/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def __init__(
self.column_aliases,
self.row_comment_chars,
self.discard_if_empty,
self.query_columns,
self.known_columns,
) = Config._read_veda_tags_info(veda_tags_file)
self.veda_attr_defaults, self.attr_aliases = Config._read_veda_attr_defaults(
Expand Down Expand Up @@ -354,6 +355,7 @@ def _read_veda_tags_info(
Dict[Tag, Dict[str, list]],
Iterable[Tag],
Dict[Tag, Set[str]],
Dict[Tag, Set[str]],
]:
def to_tag(s: str) -> Tag:
# The file stores the tag name in lowercase, and without the ~
Expand All @@ -374,6 +376,7 @@ def to_tag(s: str) -> Tag:
valid_column_names = {}
row_comment_chars = {}
discard_if_empty = []
query_cols = defaultdict(set)
known_cols = defaultdict(set)

for tag_info in veda_tags_info:
Expand All @@ -395,6 +398,8 @@ def to_tag(s: str) -> Tag:
else:
field_name = valid_field["name"]

if valid_field["query_field"]:
query_cols[tag_name].add(field_name)
known_cols[tag_name].add(field_name)

for valid_field_name in valid_field_names:
Expand All @@ -411,10 +416,18 @@ def to_tag(s: str) -> Tag:
discard_if_empty.append(tag_name)
if base_tag in row_comment_chars:
row_comment_chars[tag_name] = row_comment_chars[base_tag]
if base_tag in query_cols:
query_cols[tag_name] = query_cols[base_tag]
if base_tag in known_cols:
known_cols[tag_name] = known_cols[base_tag]

return valid_column_names, row_comment_chars, discard_if_empty, known_cols
return (
valid_column_names,
row_comment_chars,
discard_if_empty,
query_cols,
known_cols,
)

@staticmethod
def _read_veda_attr_defaults(
Expand Down
39 changes: 21 additions & 18 deletions xl2times/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
import pandas as pd
from dataclasses import replace
from typing import Dict, List
from typing import Dict, List, Set
from more_itertools import locate, one
from itertools import groupby
import re
Expand All @@ -17,17 +17,6 @@
from . import datatypes
from . import utils

query_columns = {
"pset_set",
"pset_pn",
"pset_pd",
"pset_ci",
"pset_co",
"cset_set",
"cset_cn",
"cset_cd",
}

csets_ordered_for_pcg = ["DEM", "MAT", "NRG", "ENV", "FIN"]
default_pcg_suffixes = [
cset + io for cset in csets_ordered_for_pcg for io in ["I", "O"]
Expand Down Expand Up @@ -740,13 +729,16 @@ def fill_in_missing_values_table(table):
return result


def expand_rows(table: datatypes.EmbeddedXlTable) -> datatypes.EmbeddedXlTable:
def expand_rows(
query_columns: Set[str], table: datatypes.EmbeddedXlTable
) -> datatypes.EmbeddedXlTable:
"""
Expand entries with commas into separate entries in the same column. Do this
for all tables except transformation update tables.
:param table: Table in EmbeddedXlTable format.
:return: Table in EmbeddedXlTable format with expanded comma entries.
:param query_columns: Set of query column names.
:param table: Table in EmbeddedXlTable format.
:return: Table in EmbeddedXlTable format with expanded comma entries.
"""

def has_comma(s):
Expand Down Expand Up @@ -1720,6 +1712,7 @@ def is_year(col_name):
if table.tag == datatypes.Tag.tfm_ins_ts:
# ~TFM_INS-TS: Gather columns whose names are years into a single "Year" column:
df = table.dataframe
query_columns = config.query_columns[datatypes.Tag(table.tag)]
if "year" in df.columns:
raise ValueError(f"TFM_INS-TS table already has Year column: {table}")
# TODO: can we remove this hacky shortcut? Or should it be also applied to the AT variant?
Expand Down Expand Up @@ -1807,7 +1800,9 @@ def process_transform_tables(
df = table.dataframe.copy()

# Standardize column names
known_columns = config.known_columns[table.tag] | query_columns
known_columns = (
config.known_columns[table.tag] | config.query_columns[table.tag]
)

# Handle Regions:
if set(df.columns).isdisjoint(
Expand Down Expand Up @@ -2032,6 +2027,7 @@ def make_str(df):
if tag in tables:
start_time = time.time()
df = tables[tag]
query_columns = config.query_columns[tag]
dictionary = generate_topology_dictionary(tables, model)

df["process"] = df.apply(
Expand All @@ -2044,14 +2040,15 @@ def make_str(df):
cols_to_drop = [col for col in df.columns if col in query_columns]

df = expand_rows(
query_columns,
datatypes.EmbeddedXlTable(
tag="",
uc_sets={},
sheetname="",
range="",
filename="",
dataframe=df.drop(columns=cols_to_drop),
)
),
).dataframe

tables[tag] = df
Expand Down Expand Up @@ -2520,5 +2517,11 @@ def expand_rows_parallel(
tables: List[datatypes.EmbeddedXlTable],
model: datatypes.TimesModel,
) -> List[datatypes.EmbeddedXlTable]:
query_columns_lists = [
config.query_columns[datatypes.Tag(table.tag)]
if datatypes.Tag.has_tag(table.tag)
else set()
for table in tables
]
with ProcessPoolExecutor(max_workers) as executor:
return list(executor.map(expand_rows, tables))
return list(executor.map(expand_rows, query_columns_lists, tables))

0 comments on commit 58a06d3

Please sign in to comment.