Skip to content

Commit

Permalink
Latest observations from MapData (#3112)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #3112

Differential Revision: D66434621
  • Loading branch information
ltiao authored and facebook-github-bot committed Nov 27, 2024
1 parent 7d32e88 commit fd76ff0
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 27 deletions.
54 changes: 45 additions & 9 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,16 +278,14 @@ def from_multiple_data(
def df(self) -> pd.DataFrame:
"""Returns a Data shaped DataFrame"""

# If map_keys is empty just return the df
if self._memo_df is not None:
return self._memo_df

# If map_keys is empty just return the df
if len(self.map_keys) == 0:
return self.map_df

self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates(
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
)
self._memo_df = _tail(self.map_df, self.map_keys, n=1, sort=True)

return self._memo_df

Expand Down Expand Up @@ -340,6 +338,30 @@ def clone(self) -> MapData:
description=self.description,
)

def latest(
self,
map_keys: list[str] | None = None,
rows_per_group: int = 1,
) -> MapData:
"""Return a new MapData with the most recently observed `rows_per_group`
rows for each (arm, metric) group, determined by the `map_key` values,
where higher implies more recent.
This function considers only the relative ordering of the `map_key` values,
making it most suitable when these values are equally spaced.
If `rows_per_group` is greater than the number of rows in a given
(arm, metric) group, then all rows are returned.
"""
if map_keys is None:
map_keys = self.map_keys

return MapData(
df=_tail(self.map_df, map_keys, n=rows_per_group, sort=True),
map_key_infos=self.map_key_infos,
description=self.description,
)

def subsample(
self,
map_key: str | None = None,
Expand All @@ -348,11 +370,13 @@ def subsample(
limit_rows_per_metric: int | None = None,
include_first_last: bool = True,
) -> MapData:
"""Subsample the `map_key` column in an equally-spaced manner (if there is
a `self.map_keys` is length one, then `map_key` can be set to None). The
values of the `map_key` column are not taken into account, so this function
is most reasonable when those values are equally-spaced. There are three
ways that this can be done:
"""Return a new MapData that subsamples the `map_key` column in an
equally-spaced manner. If `self.map_keys` has a length of one, `map_key`
can be set to None. This function considers only the relative ordering
of the `map_key` values, making it most suitable when these values are
equally spaced.
There are three ways that this can be done:
1. If `keep_every = k` is set, then every kth row of the DataFrame in the
`map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`.
In other words, every kth step of each (arm, metric) will be kept.
Expand Down Expand Up @@ -454,6 +478,18 @@ def _subsample_rate(
)


def _tail(
map_df: pd.DataFrame,
map_keys: list[str],
n: int = 1,
sort: bool = True,
) -> pd.DataFrame:
df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n)
if sort:
df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True)
return df


def _subsample_one_metric(
map_df: pd.DataFrame,
map_key: str | None = None,
Expand Down
44 changes: 27 additions & 17 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def observations_from_data(
statuses_to_include: set[TrialStatus] | None = None,
statuses_to_include_map_metric: set[TrialStatus] | None = None,
map_keys_as_parameters: bool = False,
latest_rows_per_group: int | None = None,
limit_rows_per_metric: int | None = None,
limit_rows_per_group: int | None = None,
) -> list[Observation]:
Expand All @@ -472,17 +473,21 @@ def observations_from_data(
trials with statuses in this set. Defaults to all statuses except abandoned.
map_keys_as_parameters: Whether map_keys should be returned as part of
the parameters of the Observation objects.
limit_rows_per_metric: If specified, and if data is an instance of MapData,
uses MapData.subsample() with
`limit_rows_per_metric` equal to the specified value on the first
map_key (map_data.map_keys[0]) to subsample the MapData. This is
useful in, e.g., cases where learning curves are frequently
updated, leading to an intractable number of Observation objects
created.
limit_rows_per_group: If specified, and if data is an instance of MapData,
uses MapData.subsample() with
`limit_rows_per_group` equal to the specified value on the first
map_key (map_data.map_keys[0]) to subsample the MapData.
latest_rows_per_group: If specified and data is an instance of MapData,
uses MapData.latest() with `rows_per_group=latest_rows_per_group` to
retrieve the most recent rows for each group. Useful in cases where
learning curves are frequently updated, preventing an excessive
number of Observation objects. Overrides `limit_rows_per_metric`
and `limit_rows_per_group`.
limit_rows_per_metric: If specified and data is an instance of MapData,
uses MapData.subsample() with `limit_rows_per_metric` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Useful for
managing the number of Observation objects when learning curves are
frequently updated. Ignored if `latest_rows_per_group` is specified.
limit_rows_per_group: If specified and data is an instance of MapData,
uses MapData.subsample() with `limit_rows_per_group` on the first
map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if
`latest_rows_per_group` is specified.
Returns:
List of Observation objects.
Expand All @@ -499,13 +504,18 @@ def observations_from_data(
if is_map_data:
data = checked_cast(MapData, data)

if limit_rows_per_metric is not None or limit_rows_per_group is not None:
data = data.subsample(
map_key=data.map_keys[0],
limit_rows_per_metric=limit_rows_per_metric,
limit_rows_per_group=limit_rows_per_group,
include_first_last=True,
if latest_rows_per_group is not None:
data = data.latest(
map_keys=data.map_keys, rows_per_group=latest_rows_per_group
)
else:
if limit_rows_per_metric is not None or limit_rows_per_group is not None:
data = data.subsample(
map_key=data.map_keys[0],
limit_rows_per_metric=limit_rows_per_metric,
limit_rows_per_group=limit_rows_per_group,
include_first_last=True,
)

map_keys.extend(data.map_keys)
obs_cols = obs_cols.union(data.map_keys)
Expand Down
75 changes: 74 additions & 1 deletion ax/core/tests/test_map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# pyre-strict


import numpy as np
import pandas as pd
from ax.core.data import Data
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -236,7 +237,17 @@ def test_upcast(self) -> None:

self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call

def test_subsample(self) -> None:
self.assertTrue(
fresh.df.equals(
fresh.map_df.sort_values(fresh.map_keys).drop_duplicates(
MapData.DEDUPLICATE_BY_COLUMNS, keep="last"
)
)
)

def test_latest(self) -> None:
seed = 8888

arm_names = ["0_0", "1_0", "2_0", "3_0"]
max_epochs = [25, 50, 75, 100]
metric_names = ["a", "b"]
Expand All @@ -259,6 +270,68 @@ def test_subsample(self) -> None:
)
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)

shuffled_large_map_df = large_map_data.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).sample(frac=1, random_state=seed)
shuffled_large_map_data = MapData(
df=shuffled_large_map_df, map_key_infos=self.map_key_infos
)

for rows_per_group in [1, 40]:
large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group)

if rows_per_group == 1:
self.assertTrue(
large_map_data_latest.map_df.groupby("metric_name")
.epoch.transform(lambda col: set(col) == set(max_epochs))
.all()
)

# when rows_per_group is larger than the number of rows
# actually observed in a group
actual_rows_per_group = large_map_data_latest.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).size()
expected_rows_per_group = np.minimum(
large_map_data_latest.map_df.groupby(
MapData.DEDUPLICATE_BY_COLUMNS
).epoch.max(),
rows_per_group,
)
self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group))

# behavior should be consistent even if map_keys are not in ascending order
shuffled_large_map_data_latest = shuffled_large_map_data.latest(
rows_per_group=rows_per_group
)
self.assertTrue(
shuffled_large_map_data_latest.map_df.equals(
large_map_data_latest.map_df
)
)

def test_subsample(self) -> None:
arm_names = ["0_0", "1_0", "2_0", "3_0"]
max_epochs = [25, 50, 75, 100]
metric_names = ["a", "b"]
large_map_df = pd.DataFrame(
[
{
"arm_name": arm_name,
"epoch": epoch + 1,
"mean": epoch * 0.1,
"sem": 0.1,
"trial_index": trial_index,
"metric_name": metric_name,
}
for metric_name in metric_names
for trial_index, (arm_name, max_epoch) in enumerate(
zip(arm_names, max_epochs)
)
for epoch in range(max_epoch)
]
)
large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos)
large_map_df_sparse_metric = pd.DataFrame(
[
{
Expand Down

0 comments on commit fd76ff0

Please sign in to comment.