From 1efcb4286d3ec4fceec3bd88b3648a38034b3541 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 4 Dec 2024 13:31:17 -0800 Subject: [PATCH] Adds method to retain the N most recently observed values from MapData (#3112) Summary: * Provide a new method `latest` for `MapData` to retrieve the *n* most recently observed values for each (arm, metric) group, determined by the `map_key` values, where higher implies more recent. * Update `observations_from_data` to optionally utilize `latest` and retain only the most recently observed *n* values (the new option, if specified alongside the existing subsampling options, will now take precedence). * Modify the "upcast" `MapData.df` property to leverage `latest`, which is a special case with *n=1*. * Revise the docstring to reflect changes in the pertinent methods, as well as update related methods like `subsample` to ensure uniform and consistent writing. Differential Revision: D66434621 --- ax/core/map_data.py | 56 +++++++++++++++++++++---- ax/core/observation.py | 44 ++++++++++++-------- ax/core/tests/test_map_data.py | 75 +++++++++++++++++++++++++++++++++- 3 files changed, 149 insertions(+), 26 deletions(-) diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 06da05e4a71..664bc011867 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -278,15 +278,15 @@ def from_multiple_data( def df(self) -> pd.DataFrame: """Returns a Data shaped DataFrame""" - # If map_keys is empty just return the df if self._memo_df is not None: return self._memo_df + # If map_keys is empty just return the df if len(self.map_keys) == 0: return self.map_df - self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates( - MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + self._memo_df = _tail( + map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True ) return self._memo_df @@ -340,6 +340,32 @@ def clone(self) -> MapData: description=self.description, ) + def latest( + self, + map_keys: list[str] | None = None, + rows_per_group: int = 1, + ) -> MapData: + """Return a new MapData with the most recently observed `rows_per_group` + rows for each (arm, metric) group, determined by the `map_key` values, + where higher implies more recent. + + This function considers only the relative ordering of the `map_key` values, + making it most suitable when these values are equally spaced. + + If `rows_per_group` is greater than the number of rows in a given + (arm, metric) group, then all rows are returned. + """ + if map_keys is None: + map_keys = self.map_keys + + return MapData( + df=_tail( + map_df=self.map_df, map_keys=map_keys, n=rows_per_group, sort=True + ), + map_key_infos=self.map_key_infos, + description=self.description, + ) + def subsample( self, map_key: str | None = None, @@ -348,11 +374,13 @@ def subsample( limit_rows_per_metric: int | None = None, include_first_last: bool = True, ) -> MapData: - """Subsample the `map_key` column in an equally-spaced manner (if there is - a `self.map_keys` is length one, then `map_key` can be set to None). The - values of the `map_key` column are not taken into account, so this function - is most reasonable when those values are equally-spaced. There are three - ways that this can be done: + """Return a new MapData that subsamples the `map_key` column in an + equally-spaced manner. If `self.map_keys` has a length of one, `map_key` + can be set to None. This function considers only the relative ordering + of the `map_key` values, making it most suitable when these values are + equally spaced. + + There are three ways that this can be done: 1. If `keep_every = k` is set, then every kth row of the DataFrame in the `map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`. In other words, every kth step of each (arm, metric) will be kept. @@ -456,6 +484,18 @@ def _subsample_rate( ) +def _tail( + map_df: pd.DataFrame, + map_keys: list[str], + n: int = 1, + sort: bool = True, +) -> pd.DataFrame: + df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n) + if sort: + df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True) + return df + + def _subsample_one_metric( map_df: pd.DataFrame, map_key: str | None = None, diff --git a/ax/core/observation.py b/ax/core/observation.py index a6284405837..94ee013d7a5 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -452,6 +452,7 @@ def observations_from_data( statuses_to_include: set[TrialStatus] | None = None, statuses_to_include_map_metric: set[TrialStatus] | None = None, map_keys_as_parameters: bool = False, + latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, ) -> list[Observation]: @@ -472,17 +473,21 @@ def observations_from_data( trials with statuses in this set. Defaults to all statuses except abandoned. map_keys_as_parameters: Whether map_keys should be returned as part of the parameters of the Observation objects. - limit_rows_per_metric: If specified, and if data is an instance of MapData, - uses MapData.subsample() with - `limit_rows_per_metric` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. This is - useful in, e.g., cases where learning curves are frequently - updated, leading to an intractable number of Observation objects - created. - limit_rows_per_group: If specified, and if data is an instance of MapData, - uses MapData.subsample() with - `limit_rows_per_group` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. Overrides `limit_rows_per_metric` + and `limit_rows_per_group`. + limit_rows_per_metric: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_metric` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Useful for + managing the number of Observation objects when learning curves are + frequently updated. Ignored if `latest_rows_per_group` is specified. + limit_rows_per_group: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_group` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if + `latest_rows_per_group` is specified. Returns: List of Observation objects. @@ -499,13 +504,18 @@ def observations_from_data( if is_map_data: data = checked_cast(MapData, data) - if limit_rows_per_metric is not None or limit_rows_per_group is not None: - data = data.subsample( - map_key=data.map_keys[0], - limit_rows_per_metric=limit_rows_per_metric, - limit_rows_per_group=limit_rows_per_group, - include_first_last=True, + if latest_rows_per_group is not None: + data = data.latest( + map_keys=data.map_keys, rows_per_group=latest_rows_per_group ) + else: + if limit_rows_per_metric is not None or limit_rows_per_group is not None: + data = data.subsample( + map_key=data.map_keys[0], + limit_rows_per_metric=limit_rows_per_metric, + limit_rows_per_group=limit_rows_per_group, + include_first_last=True, + ) map_keys.extend(data.map_keys) obs_cols = obs_cols.union(data.map_keys) diff --git a/ax/core/tests/test_map_data.py b/ax/core/tests/test_map_data.py index ce0576e295c..0b4f1f5fd22 100644 --- a/ax/core/tests/test_map_data.py +++ b/ax/core/tests/test_map_data.py @@ -6,6 +6,7 @@ # pyre-strict +import numpy as np import pandas as pd from ax.core.data import Data from ax.core.map_data import MapData, MapKeyInfo @@ -236,7 +237,17 @@ def test_upcast(self) -> None: self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call - def test_subsample(self) -> None: + self.assertTrue( + fresh.df.equals( + fresh.map_df.sort_values(fresh.map_keys).drop_duplicates( + MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + ) + ) + ) + + def test_latest(self) -> None: + seed = 8888 + arm_names = ["0_0", "1_0", "2_0", "3_0"] max_epochs = [25, 50, 75, 100] metric_names = ["a", "b"] @@ -259,6 +270,68 @@ def test_subsample(self) -> None: ) large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) + shuffled_large_map_df = large_map_data.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).sample(frac=1, random_state=seed) + shuffled_large_map_data = MapData( + df=shuffled_large_map_df, map_key_infos=self.map_key_infos + ) + + for rows_per_group in [1, 40]: + large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group) + + if rows_per_group == 1: + self.assertTrue( + large_map_data_latest.map_df.groupby("metric_name") + .epoch.transform(lambda col: set(col) == set(max_epochs)) + .all() + ) + + # when rows_per_group is larger than the number of rows + # actually observed in a group + actual_rows_per_group = large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).size() + expected_rows_per_group = np.minimum( + large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).epoch.max(), + rows_per_group, + ) + self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group)) + + # behavior should be consistent even if map_keys are not in ascending order + shuffled_large_map_data_latest = shuffled_large_map_data.latest( + rows_per_group=rows_per_group + ) + self.assertTrue( + shuffled_large_map_data_latest.map_df.equals( + large_map_data_latest.map_df + ) + ) + + def test_subsample(self) -> None: + arm_names = ["0_0", "1_0", "2_0", "3_0"] + max_epochs = [25, 50, 75, 100] + metric_names = ["a", "b"] + large_map_df = pd.DataFrame( + [ + { + "arm_name": arm_name, + "epoch": epoch + 1, + "mean": epoch * 0.1, + "sem": 0.1, + "trial_index": trial_index, + "metric_name": metric_name, + } + for metric_name in metric_names + for trial_index, (arm_name, max_epoch) in enumerate( + zip(arm_names, max_epochs) + ) + for epoch in range(max_epoch) + ] + ) + large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) large_map_df_sparse_metric = pd.DataFrame( [ {