diff --git a/xl2times/transforms.py b/xl2times/transforms.py
index bb91c8c..3557683 100644
--- a/xl2times/transforms.py
+++ b/xl2times/transforms.py
@@ -372,9 +372,15 @@ def merge_tables(
         missing_cols = [concat_cols - set(t.dataframe.columns) for t in group]

         if any([len(m) for m in missing_cols]):
-            err = f"WARNING: Possible merge error for table: '{key}'! Merged table has more columns than individual table(s), see details below:"
+            err = (
+                f"WARNING: Possible merge error for table: '{key}'! Merged table has more columns than individual "
+                f"table(s), see details below:"
+            )
             for table in group:
-                err += f"\n\tColumns: {list(table.dataframe.columns)} from {table.range}, {table.sheetname}, {table.filename}"
+                err += (
+                    f"\n\tColumns: {list(table.dataframe.columns)} from {table.range}, {table.sheetname}, "
+                    f"{table.filename}"
+                )
             logger.warning(err)

         match key:
@@ -1264,7 +1270,8 @@ def _process_comm_groups_vectorised(
         comm_groups: 'Process' DataFrame with columns ["region", "process", "io", "csets", "commoditygroup"]
         csets_ordered_for_pcg: List of csets in the order they should be considered for default pcg
     Returns:
-        Processed DataFrame with a new column "DefaultVedaPCG" set to True for the default pcg in eachregion/process/io combination.
+        Processed DataFrame with a new column "DefaultVedaPCG" set to True for the default pcg in each region/process/io
+        combination.
     """

     def _set_default_veda_pcg(group):
@@ -2017,11 +2024,22 @@ def process_transform_availability(
     return result


-def filter_by_pattern(df: pd.DataFrame, pattern: str) -> pd.DataFrame:
+def filter_by_pattern(df: pd.DataFrame, pattern: str, combined: bool) -> pd.DataFrame:
+    """
+    Filter dataframe index by a regex pattern. The combined parameter indicates whether commas
+    should be treated as pattern separators or as part of the pattern.
+    """
     # Duplicates can be created when a process has multiple commodities that match the pattern
-    df = df.filter(regex=utils.create_regexp(pattern), axis="index").drop_duplicates()
-    exclude = df.filter(regex=utils.create_negative_regexp(pattern), axis="index").index
-    return df.drop(exclude)
+    df = df.filter(
+        regex=utils.create_regexp(pattern, combined), axis="index"
+    ).drop_duplicates()
+    if combined:
+        exclude = df.filter(
+            regex=utils.create_negative_regexp(pattern), axis="index"
+        ).index
+        return df.drop(exclude)
+    else:
+        return df


 def intersect(acc, df):
@@ -2030,13 +2048,15 @@ def intersect(acc, df):
     return acc.merge(df)


-def get_matching_processes(row: pd.Series, topology: dict[str, DataFrame]) -> pd.Series:
+def get_matching_processes(
+    row: pd.Series, topology: dict[str, DataFrame]
+) -> pd.Series | None:
     matching_processes = None
     for col, key in process_map.items():
         if col in row.index and row[col] is not None:
             proc_set = topology[key]
             pattern = row[col].upper()
-            filtered = filter_by_pattern(proc_set, pattern)
+            filtered = filter_by_pattern(proc_set, pattern, col != "pset_pd")
             matching_processes = intersect(matching_processes, filtered)

     if matching_processes is not None and any(matching_processes.duplicated()):
@@ -2051,7 +2071,7 @@ def get_matching_commodities(row: pd.Series, topology: dict[str, DataFrame]):
         if col in row.index and row[col] is not None:
             matching_commodities = intersect(
                 matching_commodities,
-                filter_by_pattern(topology[key], row[col].upper()),
+                filter_by_pattern(topology[key], row[col].upper(), col != "cset_cd"),
             )

     return matching_commodities
@@ -2187,7 +2207,8 @@ def _match_wildcards(
         explode: Whether to explode the results_col ('process'/'commodities') column into a long-format table.

     Returns:
-        The table with the wildcard columns removed and the results of the wildcard matches added as a column named `results_col`
+        The table with the wildcard columns removed and the results of the wildcard matches added as
+        a column named `results_col`
     """

     wild_cols = list(col_map.keys())
diff --git a/xl2times/utils.py b/xl2times/utils.py
index 2f21a96..2dfc801 100644
--- a/xl2times/utils.py
+++ b/xl2times/utils.py
@@ -193,40 +193,51 @@ def get_scalar(table_tag: str, tables: list[datatypes.EmbeddedXlTable]):
     return table.dataframe["value"].values[0]


-def has_negative_patterns(pattern):
+def has_negative_patterns(pattern: str) -> bool:
     if len(pattern) == 0:
         return False
     return pattern[0] == "-" or ",-" in pattern


-def remove_negative_patterns(pattern):
+def remove_negative_patterns(pattern: str) -> str:
     if len(pattern) == 0:
         return pattern
     return ",".join([word for word in pattern.split(",") if word[0] != "-"])


-def remove_positive_patterns(pattern):
+def remove_positive_patterns(pattern: str) -> str:
     if len(pattern) == 0:
         return pattern
     return ",".join([word[1:] for word in pattern.split(",") if word[0] == "-"])


 @functools.lru_cache(maxsize=int(1e6))
-def create_regexp(pattern):
-    # exclude negative patterns
-    if has_negative_patterns(pattern):
-        pattern = remove_negative_patterns(pattern)
+def create_regexp(pattern: str, combined: bool = True) -> str:
+    # Distinguish a comma-separated list of patterns from a single pattern that contains commas
+    if combined:
+        # Remove whitespace
+        pattern = pattern.replace(" ", "")
+        # Exclude negative patterns
+        if has_negative_patterns(pattern):
+            pattern = remove_negative_patterns(pattern)
+        # Handle comma-separated values
+        pattern = pattern.replace(",", r"$|^")
     if len(pattern) == 0:
-        return re.compile(pattern)  # matches everything
-    # Handle VEDA wildcards
-    pattern = pattern.replace("*", ".*").replace("?", ".").replace(",", r"$|^")
+        return r".*"  # matches everything
+    # Substitute VEDA wildcards with regex patterns
+    for substitution in (("*", ".*"), ("?", ".")):
+        old, new = substitution
+        pattern = pattern.replace(old, new)
     # Do not match substrings
     pattern = rf"^{pattern}$"
-    return re.compile(pattern)
+    return pattern


 @functools.lru_cache(maxsize=int(1e6))
-def create_negative_regexp(pattern):
+def create_negative_regexp(pattern: str) -> str:
+    # Remove whitespace
+    pattern = pattern.replace(" ", "")
+    # Exclude positive patterns
     pattern = remove_positive_patterns(pattern)
     if len(pattern) == 0:
         pattern = r"^$"  # matches nothing
@@ -250,7 +261,8 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L
     Call this once from entrypoints to set up a new logger.
     In non-entrypoint modules, just use `from loguru import logger` directly.

-    To set the log level, use the `LOGURU_LEVEL` environment variable before or during runtime. E.g. `os.environ["LOGURU_LEVEL"] = "INFO"`
+    To set the log level, use the `LOGURU_LEVEL` environment variable before or during runtime.
+    E.g. `os.environ["LOGURU_LEVEL"] = "INFO"`
     Available levels are `TRACE`, `DEBUG`, `INFO`, `SUCCESS`, `WARNING`, `ERROR`, and `CRITICAL`. Default is `INFO`.

     Log file will be written to `f"{log_dir}/{log_name}.log"`
@@ -272,7 +284,8 @@ def get_logger(log_name: str = default_log_name, log_dir: str = ".") -> loguru.L
             {
                 "sink": sys.stdout,
                 "diagnose": True,
-                "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} : {message} ({name}:{"
+                "format": "{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} : {"
+                "message} ({name}:{"
                 'thread.name}:pid-{process} "{'
                 'file.path}:{line}")',
             },
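Note (not part of the patch): a minimal sketch of how the new combined flag is expected to behave, using made-up pattern values such as COAL* and COALSTH. It relies only on the patched create_regexp/create_negative_regexp above, which now return regex strings rather than compiled patterns.

import re

from xl2times import utils

# combined=True (default): commas separate alternative patterns and a leading "-"
# marks a negative pattern, so "COAL*,-COALSTH" means "starts with COAL, except COALSTH".
positive = utils.create_regexp("COAL*,-COALSTH", combined=True)  # "^COAL.*$"
negative = utils.create_negative_regexp("COAL*,-COALSTH")        # matches only COALSTH
assert re.match(positive, "COALNTH") and not re.match(negative, "COALNTH")
assert re.match(negative, "COALSTH")

# combined=False (used for the pset_pd/cset_cd description columns): the comma stays
# part of the pattern itself, so nothing is split into alternatives or excluded.
by_description = utils.create_regexp("COAL, HARD*", combined=False)  # "^COAL, HARD.*$"
assert re.match(by_description, "COAL, HARD (DOMESTIC)")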