Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Latest observations from MapData #3112

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 100 additions & 25 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from __future__ import annotations

from bisect import bisect_right
from collections.abc import Iterable, Sequence
from copy import deepcopy
from logging import Logger
from typing import Any, Generic, TypeVar

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.data import Data
from ax.core.types import TMapTrialEvaluation
Expand Down Expand Up @@ -276,15 +278,15 @@ def from_multiple_data(
def df(self) -> pd.DataFrame:
    """Returns a Data shaped DataFrame.

    The result is memoized in `self._memo_df`. When `self.map_keys` is
    empty, `self.map_df` is returned as-is (and is not memoized). Otherwise
    the frame is reduced with `_tail(..., n=1, sort=True)`, i.e. one row is
    kept per group after ordering by the map keys.
    """
    # Serve the cached frame if it was already computed.
    if self._memo_df is not None:
        return self._memo_df

    # If map_keys is empty just return the df
    if len(self.map_keys) == 0:
        return self.map_df

    self._memo_df = _tail(
        map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True
    )

    return self._memo_df
Expand Down Expand Up @@ -338,6 +340,32 @@ def clone(self) -> MapData:
description=self.description,
)

def latest(
    self,
    map_keys: list[str] | None = None,
    rows_per_group: int = 1,
) -> MapData:
    """Build a new MapData holding only the newest rows of every group.

    For each (arm, metric) group, the `rows_per_group` most recently
    observed rows are kept. Recency is decided by the `map_key` values,
    where a higher value implies a more recent observation.

    Only the relative ordering of the `map_key` values is used, so this is
    most suitable when those values are equally spaced.

    Groups with fewer than `rows_per_group` rows are returned in full.

    Args:
        map_keys: Columns used to order the rows; defaults to
            `self.map_keys` when None.
        rows_per_group: Number of trailing rows to keep per group.

    Returns:
        A new MapData that reuses this object's `map_key_infos` and
        `description`.
    """
    ordering_keys = self.map_keys if map_keys is None else map_keys
    latest_df = _tail(
        map_df=self.map_df,
        map_keys=ordering_keys,
        n=rows_per_group,
        sort=True,
    )
    return MapData(
        df=latest_df,
        map_key_infos=self.map_key_infos,
        description=self.description,
    )

def subsample(
self,
map_key: str | None = None,
Expand All @@ -346,11 +374,13 @@ def subsample(
limit_rows_per_metric: int | None = None,
include_first_last: bool = True,
) -> MapData:
"""Subsample the `map_key` column in an equally-spaced manner (if there is
a `self.map_keys` is length one, then `map_key` can be set to None). The
values of the `map_key` column are not taken into account, so this function
is most reasonable when those values are equally-spaced. There are three
ways that this can be done:
"""Return a new MapData that subsamples the `map_key` column in an
equally-spaced manner. If `self.map_keys` has a length of one, `map_key`
can be set to None. This function considers only the relative ordering
of the `map_key` values, making it most suitable when these values are
equally spaced.

There are three ways that this can be done:
1. If `keep_every = k` is set, then every kth row of the DataFrame in the
`map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`.
In other words, every kth step of each (arm, metric) will be kept.
Expand Down Expand Up @@ -412,6 +442,60 @@ def subsample(
)


def _ceil_divide(
a: int | np.int_ | npt.NDArray[np.int_], b: int | np.int_ | npt.NDArray[np.int_]
) -> np.int_ | npt.NDArray[np.int_]:
return -np.floor_divide(-a, b)


def _subsample_rate(
map_df: pd.DataFrame,
keep_every: int | None = None,
limit_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
) -> int:
if keep_every is not None:
return keep_every

grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
group_sizes = grouped_map_df.size()
max_rows = group_sizes.max()

if limit_rows_per_group is not None:
return _ceil_divide(max_rows, limit_rows_per_group).item()

if limit_rows_per_metric is not None:
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
ks = np.arange(max_rows, 0, -1)
# total sizes in ascending order
total_sizes = np.sum(
_ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1
)
# binary search
i = bisect_right(total_sizes, limit_rows_per_metric)
# if no such `k` is found, then `derived_keep_every` stays as 1.
if i > 0:
return ks[i - 1].item()

raise ValueError(
"at least one of `keep_every`, `limit_rows_per_group`, "
"or `limit_rows_per_metric` must be specified."
)


def _tail(
    map_df: pd.DataFrame,
    map_keys: list[str],
    n: int = 1,
    sort: bool = True,
) -> pd.DataFrame:
    """Keep the last `n` rows of every group after ordering by `map_keys`.

    Rows are first sorted by `map_keys`, then grouped by
    `MapData.DEDUPLICATE_BY_COLUMNS`, and the trailing `n` rows of each
    group are retained. When `sort` is True the result is additionally
    ordered by the grouping columns.
    """
    ordered_by_keys = map_df.sort_values(map_keys)
    per_group = ordered_by_keys.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
    trailing_rows = per_group.tail(n)
    if not sort:
        return trailing_rows
    return trailing_rows.sort_values(MapData.DEDUPLICATE_BY_COLUMNS)


def _subsample_one_metric(
map_df: pd.DataFrame,
map_key: str | None = None,
Expand All @@ -421,30 +505,21 @@ def _subsample_one_metric(
include_first_last: bool = True,
) -> pd.DataFrame:
"""Helper function to subsample a dataframe that holds a single metric."""
derived_keep_every = 1
if keep_every is not None:
derived_keep_every = keep_every
elif limit_rows_per_group is not None:
max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
elif limit_rows_per_metric is not None:
group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
for k in range(1, group_sizes.max() + 1):
if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
derived_keep_every = k
break
# if no such `k` is found, then `derived_keep_every` stays as 1.

grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)

derived_keep_every = _subsample_rate(
map_df, keep_every, limit_rows_per_group, limit_rows_per_metric
)

if derived_keep_every <= 1:
filtered_map_df = map_df
else:
filtered_dfs = []
for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
for _, df_g in grouped_map_df:
df_g = df_g.sort_values(map_key)
if include_first_last:
rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
idcs = np.round(linspace_idcs).astype(int)
filtered_df = df_g.iloc[idcs]
Expand Down
Loading
Loading