Skip to content

Commit

Permalink
Simplified and optimized logic for calculating per-metric subsampling…
Browse files Browse the repository at this point in the history
… rate for MapData (#3106)

Summary:

This refines the logic for calculating per-metric subsampling rates in `MapData.subsample` and incorporates a (probably premature) performance optimization, achieved by utilizing binary search on a sorted list instead of linear search.

Reviewed By: Balandat

Differential Revision: D66366076
  • Loading branch information
ltiao authored and facebook-github-bot committed Nov 27, 2024
1 parent 9cfead2 commit 7d32e88
Showing 1 changed file with 50 additions and 17 deletions.
67 changes: 50 additions & 17 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from __future__ import annotations

from bisect import bisect_right
from collections.abc import Iterable, Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, Generic, TypeVar

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.data import Data
from ax.core.types import TMapTrialEvaluation
Expand Down Expand Up @@ -412,6 +414,46 @@ def subsample(
)


def _ceil_divide(a: int | npt.NDArray, b: int | npt.NDArray) -> int | npt.NDArray:
return -np.floor_divide(-a, b)


def _subsample_rate(
map_df: pd.DataFrame,
keep_every: int | None = None,
limit_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
) -> int:
if keep_every is not None:
return keep_every

grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
group_sizes = grouped_map_df.size()
max_rows = group_sizes.max()

if limit_rows_per_group is not None:
return checked_cast(int, _ceil_divide(max_rows, limit_rows_per_group))

if limit_rows_per_metric is not None:
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
ks = np.arange(max_rows, 0, -1)
# total sizes in ascending order
total_sizes = np.sum(
_ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1
)
# binary search
i = bisect_right(total_sizes, limit_rows_per_metric)
# if no such `k` is found, then `derived_keep_every` stays as 1.
if i > 0:
return ks[i - 1]

raise ValueError(
"at least one of `keep_every`, `limit_rows_per_group`, "
"or `limit_rows_per_metric` must be specified."
)


def _subsample_one_metric(
map_df: pd.DataFrame,
map_key: str | None = None,
Expand All @@ -421,30 +463,21 @@ def _subsample_one_metric(
include_first_last: bool = True,
) -> pd.DataFrame:
"""Helper function to subsample a dataframe that holds a single metric."""
derived_keep_every = 1
if keep_every is not None:
derived_keep_every = keep_every
elif limit_rows_per_group is not None:
max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
elif limit_rows_per_metric is not None:
group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
for k in range(1, group_sizes.max() + 1):
if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
derived_keep_every = k
break
# if no such `k` is found, then `derived_keep_every` stays as 1.

grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)

derived_keep_every = _subsample_rate(
map_df, keep_every, limit_rows_per_group, limit_rows_per_metric
)

if derived_keep_every <= 1:
filtered_map_df = map_df
else:
filtered_dfs = []
for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
for _, df_g in grouped_map_df:
df_g = df_g.sort_values(map_key)
if include_first_last:
rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
idcs = np.round(linspace_idcs).astype(int)
filtered_df = df_g.iloc[idcs]
Expand Down

0 comments on commit 7d32e88

Please sign in to comment.