diff --git a/openhexa/toolbox/era5/aggregate.py b/openhexa/toolbox/era5/aggregate.py index 888f9c2..0f683ca 100644 --- a/openhexa/toolbox/era5/aggregate.py +++ b/openhexa/toolbox/era5/aggregate.py @@ -238,7 +238,9 @@ def aggregate(ds: xr.Dataset, var: str, masks: np.ndarray, boundaries_id: list[s return df -def aggregate_per_week(daily: pl.DataFrame, column_uid: str, use_epidemiological_weeks: bool = False) -> pl.DataFrame: +def aggregate_per_week( + daily: pl.DataFrame, column_uid: str, use_epidemiological_weeks: bool = False, sum_aggregation: bool = False +) -> pl.DataFrame: """Aggregate daily data per week. Parameters @@ -248,6 +250,10 @@ def aggregate_per_week(daily: pl.DataFrame, column_uid: str, use_epidemiological Length of the dataframe should be (n_boundaries * n_days). column_uid : str Column containing the boundary ID. + use_epidemiological_weeks : bool, optional + Use epidemiological weeks instead of ISO weeks. + sum_aggregation : bool, optional + If True, sum the mean, min and max columns (e.g. for total precipitation) instead of using mean/min/max. 
Returns ------- @@ -261,15 +267,22 @@ def aggregate_per_week(daily: pl.DataFrame, column_uid: str, use_epidemiological df = daily.select([column_uid, pl.col(week_column).alias("week"), "mean", "min", "max"]) - # by default, we aggregate min & max values by respectively using the min() and max() - # functions, however it might makes sense to use mean() for some use cases - df = df.group_by([column_uid, "week"]).agg( - [ - pl.col("mean").mean().alias("mean"), - pl.col("min").min().alias("min"), - pl.col("max").max().alias("max"), - ] - ) + if sum_aggregation: + df = df.group_by([column_uid, "week"]).agg( + [ + pl.col("mean").sum().alias("mean"), + pl.col("min").sum().alias("min"), + pl.col("max").sum().alias("max"), + ] + ) + else: + df = df.group_by([column_uid, "week"]).agg( + [ + pl.col("mean").mean().alias("mean"), + pl.col("min").min().alias("min"), + pl.col("max").max().alias("max"), + ] + ) # sort per date since dhis2 period format is "2012W9", we need to extract year and week number # from the period string and cast them to int before sorting, else "2012W9" will be superior to @@ -285,7 +298,7 @@ def aggregate_per_week(daily: pl.DataFrame, column_uid: str, use_epidemiological return df -def aggregate_per_month(daily: pl.DataFrame, column_uid: str) -> pl.DataFrame: +def aggregate_per_month(daily: pl.DataFrame, column_uid: str, sum_aggregation: bool = False) -> pl.DataFrame: """Aggregate daily data per month. Parameters @@ -295,6 +308,8 @@ def aggregate_per_month(daily: pl.DataFrame, column_uid: str) -> pl.DataFrame: Length of the dataframe should be (n_boundaries * n_days). column_uid : str Column containing the boundary ID. + sum_aggregation : bool, optional + If True, sum the mean, min and max columns (e.g. for total precipitation) instead of using mean/min/max. 
Returns ------- @@ -303,13 +318,22 @@ def aggregate_per_month(daily: pl.DataFrame, column_uid: str) -> pl.DataFrame: """ df = daily.select([column_uid, "month", "mean", "min", "max"]) - df = df.group_by([column_uid, "month"]).agg( - [ - pl.col("mean").mean().alias("mean"), - pl.col("min").min().alias("min"), - pl.col("max").max().alias("max"), - ] - ) + if sum_aggregation: + df = df.group_by([column_uid, "month"]).agg( + [ + pl.col("mean").sum().alias("mean"), + pl.col("min").sum().alias("min"), + pl.col("max").sum().alias("max"), + ] + ) + else: + df = df.group_by([column_uid, "month"]).agg( + [ + pl.col("mean").mean().alias("mean"), + pl.col("min").min().alias("min"), + pl.col("max").max().alias("max"), + ] + ) df = df.sort( by=[