diff --git a/xl2times/transforms.py b/xl2times/transforms.py
index bb91c8c..3557683 100644
--- a/xl2times/transforms.py
+++ b/xl2times/transforms.py
@@ -372,9 +372,15 @@ def merge_tables(
missing_cols = [concat_cols - set(t.dataframe.columns) for t in group]
if any([len(m) for m in missing_cols]):
- err = f"WARNING: Possible merge error for table: '{key}'! Merged table has more columns than individual table(s), see details below:"
+ err = (
+ f"WARNING: Possible merge error for table: '{key}'! Merged table has more columns than individual "
+ f"table(s), see details below:"
+ )
for table in group:
- err += f"\n\tColumns: {list(table.dataframe.columns)} from {table.range}, {table.sheetname}, {table.filename}"
+ err += (
+ f"\n\tColumns: {list(table.dataframe.columns)} from {table.range}, {table.sheetname}, "
+ f"{table.filename}"
+ )
logger.warning(err)
match key:
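For reference, the check above compares each table's columns with the concatenated result; a non-empty difference for any table triggers the warning. A minimal, self-contained illustration of that set-difference logic (plain DataFrames standing in for the EmbeddedXlTable objects):

```python
import pandas as pd

# Hypothetical stand-ins for the dataframes of one merge group.
group = [
    pd.DataFrame({"region": ["REG1"], "process": ["P1"]}),
    pd.DataFrame({"region": ["REG1"], "value": [1.0]}),
]

# Columns of the merged (concatenated) frame.
concat_cols = set(pd.concat(group).columns)  # {"region", "process", "value"}

# For each table: the columns it lacks relative to the merged frame.
missing_cols = [concat_cols - set(df.columns) for df in group]
print(missing_cols)  # [{'value'}, {'process'}] -> non-empty, so the warning fires
```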
@@ -1264,7 +1270,8 @@ def _process_comm_groups_vectorised(
comm_groups: 'Process' DataFrame with columns ["region", "process", "io", "csets", "commoditygroup"]
csets_ordered_for_pcg: List of csets in the order they should be considered for default pcg
Returns:
- Processed DataFrame with a new column "DefaultVedaPCG" set to True for the default pcg in eachregion/process/io combination.
+        Processed DataFrame with a new column "DefaultVedaPCG" set to True for the default pcg in each
+        region/process/io combination.
"""
def _set_default_veda_pcg(group):
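The body of `_set_default_veda_pcg` is outside this hunk, so the following is only a rough sketch of what the docstring describes (column names taken from the docstring; the selection rule is an assumption, not the project's implementation): per region/process/io group, flag the rows whose cset appears earliest in `csets_ordered_for_pcg`.

```python
import pandas as pd

def _set_default_veda_pcg_sketch(
    group: pd.DataFrame, csets_ordered_for_pcg: list[str]
) -> pd.DataFrame:
    # Mark every row as non-default, then flag the rows whose cset comes first
    # in the priority order as the default PCG for this region/process/io group.
    group = group.copy()
    group["DefaultVedaPCG"] = False
    for cset in csets_ordered_for_pcg:
        hits = group["csets"] == cset
        if hits.any():
            group.loc[hits, "DefaultVedaPCG"] = True
            break
    return group

# Hypothetical call, mirroring the docstring's column layout:
# comm_groups.groupby(["region", "process", "io"], group_keys=False).apply(
#     _set_default_veda_pcg_sketch, csets_ordered_for_pcg
# )
```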
@@ -2017,11 +2024,22 @@ def process_transform_availability(
return result
-def filter_by_pattern(df: pd.DataFrame, pattern: str) -> pd.DataFrame:
+def filter_by_pattern(df: pd.DataFrame, pattern: str, combined: bool) -> pd.DataFrame:
+ """
+    Filter the dataframe index by a regex pattern. The combined parameter indicates whether commas
+    should be treated as pattern separators or as part of the pattern itself.
+ """
# Duplicates can be created when a process has multiple commodities that match the pattern
- df = df.filter(regex=utils.create_regexp(pattern), axis="index").drop_duplicates()
- exclude = df.filter(regex=utils.create_negative_regexp(pattern), axis="index").index
- return df.drop(exclude)
+ df = df.filter(
+ regex=utils.create_regexp(pattern, combined), axis="index"
+ ).drop_duplicates()
+ if combined:
+ exclude = df.filter(
+ regex=utils.create_negative_regexp(pattern), axis="index"
+ ).index
+ return df.drop(exclude)
+ else:
+ return df
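A quick usage sketch of the new `combined` flag (hypothetical process set; relies on the updated `create_regexp`/`create_negative_regexp` in utils.py further down):

```python
import pandas as pd
from xl2times.transforms import filter_by_pattern

# Hypothetical process set, indexed by process name as in the topology dicts.
proc_set = pd.DataFrame(
    {"process": ["COAPLT", "COAIMP", "GASPLT"]},
    index=["COAPLT", "COAIMP", "GASPLT"],
)

# combined=True: commas separate alternatives and "-" entries are exclusions,
# so "COA*,-COAIMP" keeps COAPLT only.
print(filter_by_pattern(proc_set, "COA*,-COAIMP", combined=True).index.tolist())

# combined=False: the whole string is one pattern, comma included,
# so nothing matches here and the frame comes back empty.
print(filter_by_pattern(proc_set, "COA*,-COAIMP", combined=False).index.tolist())
```

Passing `col != "pset_pd"` (and `col != "cset_cd"` below) presumably reflects that description fields can contain literal commas, so those columns are matched as a single pattern rather than split into alternatives.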
def intersect(acc, df):
@@ -2030,13 +2048,15 @@ def intersect(acc, df):
return acc.merge(df)
-def get_matching_processes(row: pd.Series, topology: dict[str, DataFrame]) -> pd.Series:
+def get_matching_processes(
+ row: pd.Series, topology: dict[str, DataFrame]
+) -> pd.Series | None:
matching_processes = None
for col, key in process_map.items():
if col in row.index and row[col] is not None:
proc_set = topology[key]
pattern = row[col].upper()
- filtered = filter_by_pattern(proc_set, pattern)
+ filtered = filter_by_pattern(proc_set, pattern, col != "pset_pd")
matching_processes = intersect(matching_processes, filtered)
if matching_processes is not None and any(matching_processes.duplicated()):
@@ -2051,7 +2071,7 @@ def get_matching_commodities(row: pd.Series, topology: dict[str, DataFrame]):
if col in row.index and row[col] is not None:
matching_commodities = intersect(
matching_commodities,
- filter_by_pattern(topology[key], row[col].upper()),
+ filter_by_pattern(topology[key], row[col].upper(), col != "cset_cd"),
)
return matching_commodities
@@ -2187,7 +2207,8 @@ def _match_wildcards(
explode: Whether to explode the results_col ('process'/'commodities') column into a long-format table.
Returns:
- The table with the wildcard columns removed and the results of the wildcard matches added as a column named `results_col`
+ The table with the wildcard columns removed and the results of the wildcard matches added as
+ a column named `results_col`
"""
wild_cols = list(col_map.keys())
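The `explode` behaviour mentioned in the docstring turns one row holding a list of matches into one row per match. A minimal pandas illustration with made-up data (not the real call site):

```python
import pandas as pd

df = pd.DataFrame(
    {
        "region": ["REG1", "REG1"],
        "process": [["COAPLT", "GASPLT"], ["NUCPLT"]],  # wildcard match results
        "value": [1.0, 2.0],
    }
)

# Long format: one row per matched process, the other columns are repeated.
print(df.explode("process", ignore_index=True))
```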
diff --git a/xl2times/utils.py b/xl2times/utils.py
index 2f21a96..2dfc801 100644
--- a/xl2times/utils.py
+++ b/xl2times/utils.py
@@ -193,40 +193,51 @@ def get_scalar(table_tag: str, tables: list[datatypes.EmbeddedXlTable]):
return table.dataframe["value"].values[0]
-def has_negative_patterns(pattern):
+def has_negative_patterns(pattern: str) -> bool:
if len(pattern) == 0:
return False
return pattern[0] == "-" or ",-" in pattern
-def remove_negative_patterns(pattern):
+def remove_negative_patterns(pattern: str) -> str:
if len(pattern) == 0:
return pattern
return ",".join([word for word in pattern.split(",") if word[0] != "-"])
-def remove_positive_patterns(pattern):
+def remove_positive_patterns(pattern: str) -> str:
if len(pattern) == 0:
return pattern
return ",".join([word[1:] for word in pattern.split(",") if word[0] == "-"])
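For reference, how the three helpers treat a mixed pattern (outputs traced from the code above):

```python
from xl2times.utils import (
    has_negative_patterns,
    remove_negative_patterns,
    remove_positive_patterns,
)

pattern = "COA*,-COAIMP"
print(has_negative_patterns(pattern))     # True: a "-" prefixed entry is present
print(remove_negative_patterns(pattern))  # "COA*": only the positive entries remain
print(remove_positive_patterns(pattern))  # "COAIMP": the negative entry with "-" stripped
```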
@functools.lru_cache(maxsize=int(1e6))
-def create_regexp(pattern):
- # exclude negative patterns
- if has_negative_patterns(pattern):
- pattern = remove_negative_patterns(pattern)
+def create_regexp(pattern: str, combined: bool = True) -> str:
+    # Distinguish a comma-separated list of patterns from a single pattern containing commas
+ if combined:
+        # Remove whitespace
+ pattern = pattern.replace(" ", "")
+ # Exclude negative patterns
+ if has_negative_patterns(pattern):
+ pattern = remove_negative_patterns(pattern)
+ # Handle comma-separated values
+ pattern = pattern.replace(",", r"$|^")
if len(pattern) == 0:
- return re.compile(pattern) # matches everything
- # Handle VEDA wildcards
- pattern = pattern.replace("*", ".*").replace("?", ".").replace(",", r"$|^")
+ return r".*" # matches everything
+    # Substitute VEDA wildcards with regex patterns
+    for substitution in (("*", ".*"), ("?", ".")):
+        old, new = substitution
+ pattern = pattern.replace(old, new)
# Do not match substrings
pattern = rf"^{pattern}$"
- return re.compile(pattern)
+ return pattern
@functools.lru_cache(maxsize=int(1e6))
-def create_negative_regexp(pattern):
+def create_negative_regexp(pattern: str) -> str:
+    # Remove whitespace
+ pattern = pattern.replace(" ", "")
+ # Exclude positive patterns
pattern = remove_positive_patterns(pattern)
if len(pattern) == 0:
pattern = r"^$" # matches nothing
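Since both functions now return regex strings instead of compiled patterns (handed straight to `DataFrame.filter(regex=...)`), their behaviour can be traced from the code above; the tail of `create_negative_regexp` is outside this hunk, so its output below is an assumption:

```python
from xl2times.utils import create_regexp, create_negative_regexp

# combined=True (the default): commas separate alternatives, negatives are dropped.
print(create_regexp("COA*, GAS?"))    # ^COA.*$|^GAS.$
print(create_regexp("COA*,-COAIMP"))  # ^COA.*$
print(create_regexp(""))              # .*  (matches everything)

# combined=False: the comma stays part of the pattern itself.
print(create_regexp("OIL,HEAVY*", combined=False))  # ^OIL,HEAVY.*$

# The negative counterpart keeps only the "-" entries (presumably ^COAIMP$ here).
print(create_negative_regexp("COA*,-COAIMP"))
```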
@@ -250,7 +261,8 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L
Call this once from entrypoints to set up a new logger.
In non-entrypoint modules, just use `from loguru import logger` directly.
- To set the log level, use the `LOGURU_LEVEL` environment variable before or during runtime. E.g. `os.environ["LOGURU_LEVEL"] = "INFO"`
+ To set the log level, use the `LOGURU_LEVEL` environment variable before or during runtime.
+ E.g. `os.environ["LOGURU_LEVEL"] = "INFO"`
Available levels are `TRACE`, `DEBUG`, `INFO`, `SUCCESS`, `WARNING`, `ERROR`, and `CRITICAL`. Default is `INFO`.
Log file will be written to `f"{log_dir}/{log_name}.log"`
@@ -272,7 +284,8 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L
{
"sink": sys.stdout,
"diagnose": True,
- "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} : {message} ({name}:{"
+ "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} : {"
+ "message} ({name}:{"
'thread.name}:pid-{process} "{'
'file.path}:{line}")',
},