diff --git a/toolkit/compare_case_groups/api.py b/toolkit/compare_case_groups/api.py index e0908b3d..4b578eb3 100644 --- a/toolkit/compare_case_groups/api.py +++ b/toolkit/compare_case_groups/api.py @@ -44,30 +44,19 @@ def get_dataset_proportion(self) -> int: ) def _select_columns_ranked_df(self, ranked_df: pl.DataFrame) -> None: - lower_groups = [g.lower() for g in self.groups] - - if self.temporal != "": - columns = [ - *lower_groups, - "group_count", - "group_rank", - "attribute_value", - "attribute_count", - "attribute_rank", - f"{self.temporal}_window", - f"{self.temporal}_window_count", - f"{self.temporal}_window_rank", - f"{self.temporal}_window_delta", - ] - else: - columns = [ - *lower_groups, - "group_count", - "group_rank", - "attribute_value", - "attribute_count", - "attribute_rank", - ] + columns = [g.lower() for g in self.groups] + default_columns = [ + "group_count", + "group_rank", + "attribute_value", + "attrubite_count", + "attribute_rank", + ] + + columns.extend(default_columns) + + if self.temporal: + columns.extend([f"{self.temporal}_window", ...]) self.model_df = ranked_df.select(columns)