diff --git a/eolearn/core/eodata.py b/eolearn/core/eodata.py index 5f0dda5a..942db7f4 100644 --- a/eolearn/core/eodata.py +++ b/eolearn/core/eodata.py @@ -726,6 +726,7 @@ def merge( self, *eopatches, features=features, time_dependent_op=time_dependent_op, timeless_op=timeless_op ) + @deprecated_function(EODeprecationWarning, "Please use the method `temporal_subset` instead.") def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.datetime]: """Removes all frames from the EOPatch with a date not found in the provided timestamps list. @@ -750,6 +751,43 @@ def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.dateti return remove_from_patch + def temporal_subset( + self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool] + ) -> EOPatch: + """Returns an EOPatch that only contains data for the temporal subset corresponding to `timestamps`. + + For array-based data appropriate temporal slices are extracted. For vector data a filtration is performed. + + :param timestamps: Parameter that defines the temporal subset. Can be a collection of timestamps, a + collection of timestamp indices, or a callable that returns whether a timestamp should be kept. + """ + timestamp_indices = self._parse_temporal_subset_input(timestamps) + new_timestamps = [ts for i, ts in enumerate(self.get_timestamps()) if i in timestamp_indices] + new_patch = EOPatch(bbox=self.bbox, timestamps=new_timestamps) + + for ftype, fname in self.get_features(): + if ftype.is_timeless() or ftype.is_meta(): + new_patch[ftype, fname] = self[ftype, fname] + elif ftype.is_vector(): + gdf: gpd.GeoDataFrame = self[ftype, fname] + new_patch[ftype, fname] = gdf[gdf[TIMESTAMP_COLUMN].isin(new_timestamps)] + else: + new_patch[ftype, fname] = self[ftype, fname][timestamp_indices] + + return new_patch + + def _parse_temporal_subset_input( + self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool] + ) -> list[int]: + """Parses input into a list of timestamp indices. Also adds implicit support for strings via `parse_time`.""" + if callable(timestamps): + return [i for i, ts in enumerate(self.get_timestamps()) if timestamps(ts)] + ts_or_idx = list(timestamps) + if all(isinstance(ts, int) for ts in ts_or_idx): + return ts_or_idx # type: ignore[return-value] + parsed_timestamps = {parse_time(ts, force_datetime=True) for ts in ts_or_idx} # type: ignore[call-overload] + return [i for i, ts in enumerate(self.get_timestamps()) if ts in parsed_timestamps] + def plot( self, feature: Feature, diff --git a/tests/core/test_eodata.py b/tests/core/test_eodata.py index 68ccae5f..cfe1366b 100644 --- a/tests/core/test_eodata.py +++ b/tests/core/test_eodata.py @@ -6,7 +6,7 @@ """ from __future__ import annotations -import datetime +import datetime as dt import warnings from typing import Any @@ -17,6 +17,7 @@ from sentinelhub import CRS, BBox from eolearn.core import EOPatch, FeatureType +from eolearn.core.constants import TIMESTAMP_COLUMN from eolearn.core.eodata_io import FeatureIO from eolearn.core.exceptions import EODeprecationWarning, TemporalDimensionWarning from eolearn.core.types import Feature, FeaturesSpecification @@ -96,7 +97,7 @@ def test_bbox_feature_type(invalid_bbox: Any) -> None: @pytest.mark.parametrize( - "valid_entry", [["2018-01-01", "15.2.1992"], (datetime.datetime(2017, 1, 1, 10, 4, 7), datetime.date(2017, 1, 11))] + "valid_entry", [["2018-01-01", "15.2.1992"], (dt.datetime(2017, 1, 1, 10, 4, 7), dt.date(2017, 1, 11))] ) def test_timestamp_valid_feature_type(valid_entry: Any) -> None: eop = EOPatch(bbox=DUMMY_BBOX, timestamps=valid_entry) @@ -106,9 +107,9 @@ def test_timestamp_valid_feature_type(valid_entry: Any) -> None: @pytest.mark.parametrize( "invalid_timestamps", [ - [datetime.datetime(2017, 1, 1, 10, 4, 7), None, datetime.datetime(2017, 1, 11, 10, 3, 51)], + [dt.datetime(2017, 1, 1, 10, 4, 7), None, dt.datetime(2017, 1, 11, 10, 3, 51)], "something", - datetime.datetime(2017, 1, 1, 10, 4, 7), + dt.datetime(2017, 1, 1, 10, 4, 7), ], ) def test_timestamps_invalid_feature_type(invalid_timestamps: Any) -> None: @@ -398,19 +399,20 @@ def test_get_features(patch: EOPatch, expected_features: list[Feature]) -> None: assert patch.get_features() == expected_features +@pytest.mark.filterwarnings("ignore::eolearn.core.exceptions.EODeprecationWarning") def test_timestamp_consolidation() -> None: # 10 frames timestamps = [ - datetime.datetime(2017, 1, 1, 10, 4, 7), - datetime.datetime(2017, 1, 4, 10, 14, 5), - datetime.datetime(2017, 1, 11, 10, 3, 51), - datetime.datetime(2017, 1, 14, 10, 13, 46), - datetime.datetime(2017, 1, 24, 10, 14, 7), - datetime.datetime(2017, 2, 10, 10, 1, 32), - datetime.datetime(2017, 2, 20, 10, 6, 35), - datetime.datetime(2017, 3, 2, 10, 0, 20), - datetime.datetime(2017, 3, 12, 10, 7, 6), - datetime.datetime(2017, 3, 15, 10, 12, 14), + dt.datetime(2017, 1, 1, 10, 4, 7), + dt.datetime(2017, 1, 4, 10, 14, 5), + dt.datetime(2017, 1, 11, 10, 3, 51), + dt.datetime(2017, 1, 14, 10, 13, 46), + dt.datetime(2017, 1, 24, 10, 14, 7), + dt.datetime(2017, 2, 10, 10, 1, 32), + dt.datetime(2017, 2, 20, 10, 6, 35), + dt.datetime(2017, 3, 2, 10, 0, 20), + dt.datetime(2017, 3, 12, 10, 7, 6), + dt.datetime(2017, 3, 15, 10, 12, 14), ] data = np.random.rand(10, 100, 100, 3) @@ -430,7 +432,7 @@ def test_timestamp_consolidation() -> None: good_timestamps = timestamps.copy() del good_timestamps[0] del good_timestamps[-1] - good_timestamps.append(datetime.datetime(2017, 12, 1)) + good_timestamps.append(dt.datetime(2017, 12, 1)) removed_frames = eop.consolidate_timestamps(good_timestamps) @@ -444,20 +446,48 @@ def test_timestamp_consolidation() -> None: assert np.array_equal(mask_timeless, eop.mask_timeless["MASK_TIMELESS"]) -def test_timestamps_deprecation(): - eop = EOPatch(bbox=DUMMY_BBOX, timestamps=[datetime.datetime(1234, 5, 6)]) - - with pytest.warns(EODeprecationWarning): - assert eop.timestamp == [datetime.datetime(1234, 5, 6)] - - with pytest.warns(EODeprecationWarning): - eop.timestamp = [datetime.datetime(4321, 5, 6)] - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=EODeprecationWarning) - # so the warnings get ignored in pytest summary - assert eop.timestamp == [datetime.datetime(4321, 5, 6)] - assert eop.timestamp == eop.timestamps +@pytest.mark.parametrize( + "method_input", + [ + ["2017-04-08", "2017-09-17"], + [1, 2], + lambda x: (x > dt.datetime(2017, 4, 1) and x < dt.datetime(2017, 10, 10)), + ], +) +def test_temporal_subset(method_input): + eop = generate_eopatch( + { + FeatureType.DATA: ["data1", "data2"], + FeatureType.MASK_TIMELESS: ["mask_timeless"], + FeatureType.SCALAR_TIMELESS: ["scalar_timeless"], + FeatureType.MASK: ["mask"], + }, + timestamps=[ + dt.datetime(2017, 1, 5), + dt.datetime(2017, 4, 8), + dt.datetime(2017, 9, 17), + dt.datetime(2018, 1, 5), + dt.datetime(2018, 12, 1), + ], + ) + vector_data = GeoDataFrame( + {TIMESTAMP_COLUMN: eop.get_timestamps()}, geometry=[eop.bbox.geometry.buffer(i) for i in range(5)], crs=32633 + ) + eop.vector["vector"] = vector_data + subset_timestamps = eop.timestamps[1:3] + + subset_eop = eop.temporal_subset(method_input) + assert subset_eop.timestamps == subset_timestamps + for feature in eop.get_features(): + if feature[0].is_timeless(): + assert_feature_data_equal(eop[feature], subset_eop[feature]) + elif feature[0].is_array(): + assert_feature_data_equal(eop[feature][1:3, ...], subset_eop[feature]) + + assert_feature_data_equal( + subset_eop.vector["vector"], + vector_data[1:3], + ) def test_bbox_none_deprecation(): diff --git a/tests/core/test_eodata_io.py b/tests/core/test_eodata_io.py index e9692cd9..bfb56bfe 100644 --- a/tests/core/test_eodata_io.py +++ b/tests/core/test_eodata_io.py @@ -623,8 +623,7 @@ def test_partial_temporal_saving_into_existing(eopatch: EOPatch, temporal_select io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES") eopatch.save(**io_kwargs, use_zarr=True) - partial_patch = eopatch.copy(deep=True) - partial_patch.consolidate_timestamps(np.array(partial_patch.timestamps)[temporal_selection or ...]) + partial_patch = eopatch.copy(deep=True).temporal_subset(np.array(eopatch.timestamps)[temporal_selection or ...]) partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2) partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection=temporal_selection) @@ -668,8 +667,7 @@ def test_partial_temporal_saving_infer(eopatch: EOPatch): io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES") eopatch.save(**io_kwargs, use_zarr=True) - partial_patch = eopatch.copy(deep=True) - partial_patch.consolidate_timestamps(eopatch.timestamps[1:2] + eopatch.timestamps[3:5]) + partial_patch = eopatch.copy(deep=True).temporal_subset([1, 3, 4]) partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2) partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection="infer") diff --git a/tests/core/test_eodata_merge.py b/tests/core/test_eodata_merge.py index 131e8706..4ce7bd3d 100644 --- a/tests/core/test_eodata_merge.py +++ b/tests/core/test_eodata_merge.py @@ -206,8 +206,7 @@ def test_lazy_loading(test_eopatch_path): def test_temporally_independent_merge(test_eopatch_path): full_patch = EOPatch.load(test_eopatch_path) - part1, part2 = full_patch.copy(deep=True), full_patch.copy(deep=True) - part1.consolidate_timestamps(full_patch.get_timestamps()[:10]) - part2.consolidate_timestamps(full_patch.get_timestamps()[10:]) + part1 = full_patch.copy(deep=True).temporal_subset(range(10)) + part2 = full_patch.copy(deep=True).temporal_subset(range(10, len(full_patch.timestamps))) assert full_patch == merge_eopatches(part1, part2, time_dependent_op="concatenate") diff --git a/tests/features/conftest.py b/tests/features/conftest.py index 77c533c4..8959c6b1 100644 --- a/tests/features/conftest.py +++ b/tests/features/conftest.py @@ -34,5 +34,4 @@ def small_ndvi_eopatch_fixture(example_eopatch: EOPatch): ndvi = example_eopatch.data["NDVI"][:, :20, :20] ndvi[np.isnan(ndvi)] = 0 example_eopatch.data["NDVI"] = ndvi - example_eopatch.consolidate_timestamps(example_eopatch.get_timestamps()[:10]) - return example_eopatch + return example_eopatch.temporal_subset(range(10))