diff --git a/.gitignore b/.gitignore index ceafcba..c88d75a 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ docs/api/ /out.txt *.log /profile.* +.cache/ diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 694c31f..995a0d2 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -22,7 +22,7 @@ "display.max_columns", 20, "display.width", - 300, + 150, "display.max_colwidth", 75, "display.precision", @@ -30,6 +30,40 @@ ) +def test_merge_duplicate_columns(): + """ + Tests that this: + >>> df + ... a a a b b c + ... 0 NaN 4.0 7.0 1 4.0 1 + ... 1 2.0 NaN 8.0 2 5.0 2 + ... 2 3.0 6.0 NaN 3 NaN 3 + + Gets transformed into this: + >>> transforms._merge_duplicate_named_columns(df.copy()) + ... df2 + ... a b c + ... 0 7.0 4.0 1 + ... 1 8.0 5.0 2 + ... 2 6.0 3.0 3 + """ + df = pd.DataFrame( + { + "a": [None, 2, 3], + "a2": [4, None, 6], + "a3": [7, 8, None], + "b": [1, 2, 3], + "b2": [4, 5, None], + "c": [1, 2, 3], + } + ).rename(columns={"a2": "a", "a3": "a", "b2": "b"}) + df2 = transforms._merge_duplicate_named_columns(df.copy()) + assert df2.columns.tolist() == ["a", "b", "c"] + assert df2["a"].tolist() == [7, 8, 6] + assert df2["b"].tolist() == [4, 5, 3] + assert df2["c"].tolist() == [1, 2, 3] + + def _match_uc_wildcards_old( df: pd.DataFrame, dictionary: dict[str, pd.DataFrame] ) -> pd.DataFrame: diff --git a/utils/run_benchmarks.py b/utils/run_benchmarks.py index 060a0fc..c3693a9 100644 --- a/utils/run_benchmarks.py +++ b/utils/run_benchmarks.py @@ -23,6 +23,19 @@ logger = utils.get_logger() +pd.set_option( + "display.max_rows", + 20, + "display.max_columns", + 20, + "display.width", + 150, + "display.max_colwidth", + 75, + "display.precision", + 3, +) + def parse_result(output: str) -> Tuple[float, int, int]: # find pattern in multiline string diff --git a/xl2times/transforms.py b/xl2times/transforms.py index 211191c..c33cc85 100644 --- a/xl2times/transforms.py +++ b/xl2times/transforms.py @@ -205,23 +205,52 @@ def 
discard(table): result = [] for table in tables: + if not datatypes.Tag.has_tag(table.tag.split(":")[0]): logger.warning(f"Dropping table with unrecognized tag {table.tag}") continue + if discard(table): continue + # Check for duplicate columns: - seen = set() - dupes = [x for x in table.dataframe.columns if x in seen or seen.add(x)] + df = table.dataframe + dupes = df.columns[table.dataframe.columns.duplicated()] + if len(dupes) > 0: logger.warning( - f"Duplicate columns in {table.range}, {table.sheetname}," - f" {table.filename}: {','.join(dupes)}" + f"Merging duplicate columns in {table.range}, {table.sheetname}," + f" {table.filename}: {dupes.to_list()}" ) + table.dataframe = _merge_duplicate_named_columns(df) + result.append(table) return result +def _merge_duplicate_named_columns(df_in: DataFrame) -> DataFrame: + """Merges values in duplicate columns into a single column. + This is implemented as a forward-fill of missing values in the left-to-right direction, to match VEDA's behaviour. + So any missing values in the right-most of each set of duplicate-named columns are filled with the first non-missing value to the left. 
+ + Parameters + df_in : DataFrame to be processed (not modified) + Returns + DataFrame with duplicate columns merged + """ + if not df_in.columns.duplicated().any(): + return df_in + + df = df_in.copy() + dupes = pd.unique(df.columns[df.columns.duplicated(keep="first")]) + for dup_col in dupes: + df[dup_col] = df[dup_col].ffill(axis=1) + + # only keep the right-most duplicate column from each duplicate set + df = df.iloc[:, ~df.columns.duplicated(keep="last")] + return df + + def normalize_tags_columns( config: datatypes.Config, tables: List[datatypes.EmbeddedXlTable], @@ -588,7 +617,8 @@ def process_user_constraint_table( # TODO: apply table.uc_sets # Fill in UC_N blank cells with value from above - df["uc_n"] = df["uc_n"].ffill() + if "uc_n" in df.columns: + df["uc_n"] = df["uc_n"].ffill() data_columns = [ x for x in df.columns if x not in config.known_columns[datatypes.Tag.uc_t] diff --git a/xl2times/utils.py b/xl2times/utils.py index fcc10eb..48c516b 100644 --- a/xl2times/utils.py +++ b/xl2times/utils.py @@ -60,6 +60,12 @@ def explode(df, data_columns): :return: Tuple with the exploded dataframe and a Series of the original column name for each value in each new row. """ + if df.columns.duplicated().any(): + raise ValueError( + f"Dataframe has duplicated columns: {df.columns[df.columns.duplicated()]}" + ) + + dfo = df.copy() data = df[data_columns].values.tolist() other_columns = [ colname for colname in df.columns.values if colname not in data_columns @@ -75,6 +81,7 @@ def explode(df, data_columns): index = df[value_column].notna() df = df[index] names = names[index] + return df, names