Skip to content

Commit

Permalink
Add temporal filtration method (#756)
Browse files Browse the repository at this point in the history
* add test

* add implementation

* add docstrings

* remove consolidate use from codebase, deprecate, and fix a bug

* add MR suggestions
  • Loading branch information
zigaLuksic authored Oct 3, 2023
1 parent 901458c commit 9c556e1
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 38 deletions.
38 changes: 38 additions & 0 deletions eolearn/core/eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ def merge(
self, *eopatches, features=features, time_dependent_op=time_dependent_op, timeless_op=timeless_op
)

@deprecated_function(EODeprecationWarning, "Please use the method `temporal_subset` instead.")
def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.datetime]:
"""Removes all frames from the EOPatch with a date not found in the provided timestamps list.
Expand All @@ -750,6 +751,43 @@ def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.dateti

return remove_from_patch

def temporal_subset(
    self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
) -> EOPatch:
    """Returns an EOPatch that only contains data for the temporal subset corresponding to `timestamps`.

    For array-based data appropriate temporal slices are extracted. For vector data a filtration is performed.
    Timeless and meta features are assigned to the new patch as they are.

    :param timestamps: Parameter that defines the temporal subset. Can be a collection of timestamps, a
        collection of timestamp indices, or a callable that returns whether a timestamp should be kept.
    :return: A new EOPatch with the same bounding box whose timestamps and temporal features are restricted
        to the selected frames. When indices are given, the frames appear in the order of the indices.
    :raises IndexError: If an index does not refer to an existing timestamp.
    """
    timestamp_indices = self._parse_temporal_subset_input(timestamps)

    # Timestamps are looked up with exactly the indices used to slice the temporal arrays below. Filtering by
    # membership instead would silently desynchronize timestamps and data for unordered, repeated, or negative
    # indices (and costs a linear scan of the index list per timestamp).
    all_timestamps = self.get_timestamps()
    new_timestamps = [all_timestamps[idx] for idx in timestamp_indices]
    new_patch = EOPatch(bbox=self.bbox, timestamps=new_timestamps)

    for ftype, fname in self.get_features():
        if ftype.is_timeless() or ftype.is_meta():
            # No temporal dimension to cut, the value is handed over unchanged.
            new_patch[ftype, fname] = self[ftype, fname]
        elif ftype.is_vector():
            # Vector data carries its timestamps in a column, so rows are filtered rather than sliced.
            gdf: gpd.GeoDataFrame = self[ftype, fname]
            new_patch[ftype, fname] = gdf[gdf[TIMESTAMP_COLUMN].isin(new_timestamps)]
        else:
            # Remaining temporal features are indexed along their first axis with the list of indices.
            new_patch[ftype, fname] = self[ftype, fname][timestamp_indices]

    return new_patch

def _parse_temporal_subset_input(
    self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
) -> list[int]:
    """Parses input into a list of timestamp indices. Also adds implicit support for strings via `parse_time`.

    :param timestamps: Either a callable evaluated on every timestamp of the patch, an iterable of integer
        indices (built-in or NumPy integers), or an iterable of timestamps / values accepted by `parse_time`.
    :return: Indices of the selected timestamps. Integer input is returned in the given order as plain `int`
        values; for the other input kinds the indices follow the order of the timestamps of the patch.
    """
    # Imported locally so the block does not depend on a change to the module-level imports.
    from numbers import Integral

    if callable(timestamps):
        return [idx for idx, ts in enumerate(self.get_timestamps()) if timestamps(ts)]

    ts_or_idx = list(timestamps)
    # `Integral` also matches NumPy integer scalars (e.g. from `np.arange` or `np.where`), which are not
    # instances of `int` and would otherwise be wrongly routed to `parse_time`.
    if all(isinstance(item, Integral) for item in ts_or_idx):
        return [int(item) for item in ts_or_idx]

    parsed_timestamps = {parse_time(ts, force_datetime=True) for ts in ts_or_idx}  # type: ignore[call-overload]
    return [idx for idx, ts in enumerate(self.get_timestamps()) if ts in parsed_timestamps]

def plot(
self,
feature: Feature,
Expand Down
88 changes: 59 additions & 29 deletions tests/core/test_eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
from __future__ import annotations

import datetime
import datetime as dt
import warnings
from typing import Any

Expand All @@ -17,6 +17,7 @@
from sentinelhub import CRS, BBox

from eolearn.core import EOPatch, FeatureType
from eolearn.core.constants import TIMESTAMP_COLUMN
from eolearn.core.eodata_io import FeatureIO
from eolearn.core.exceptions import EODeprecationWarning, TemporalDimensionWarning
from eolearn.core.types import Feature, FeaturesSpecification
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_bbox_feature_type(invalid_bbox: Any) -> None:


@pytest.mark.parametrize(
"valid_entry", [["2018-01-01", "15.2.1992"], (datetime.datetime(2017, 1, 1, 10, 4, 7), datetime.date(2017, 1, 11))]
"valid_entry", [["2018-01-01", "15.2.1992"], (dt.datetime(2017, 1, 1, 10, 4, 7), dt.date(2017, 1, 11))]
)
def test_timestamp_valid_feature_type(valid_entry: Any) -> None:
eop = EOPatch(bbox=DUMMY_BBOX, timestamps=valid_entry)
Expand All @@ -106,9 +107,9 @@ def test_timestamp_valid_feature_type(valid_entry: Any) -> None:
@pytest.mark.parametrize(
"invalid_timestamps",
[
[datetime.datetime(2017, 1, 1, 10, 4, 7), None, datetime.datetime(2017, 1, 11, 10, 3, 51)],
[dt.datetime(2017, 1, 1, 10, 4, 7), None, dt.datetime(2017, 1, 11, 10, 3, 51)],
"something",
datetime.datetime(2017, 1, 1, 10, 4, 7),
dt.datetime(2017, 1, 1, 10, 4, 7),
],
)
def test_timestamps_invalid_feature_type(invalid_timestamps: Any) -> None:
Expand Down Expand Up @@ -398,19 +399,20 @@ def test_get_features(patch: EOPatch, expected_features: list[Feature]) -> None:
assert patch.get_features() == expected_features


@pytest.mark.filterwarnings("ignore::eolearn.core.exceptions.EODeprecationWarning")
def test_timestamp_consolidation() -> None:
# 10 frames
timestamps = [
datetime.datetime(2017, 1, 1, 10, 4, 7),
datetime.datetime(2017, 1, 4, 10, 14, 5),
datetime.datetime(2017, 1, 11, 10, 3, 51),
datetime.datetime(2017, 1, 14, 10, 13, 46),
datetime.datetime(2017, 1, 24, 10, 14, 7),
datetime.datetime(2017, 2, 10, 10, 1, 32),
datetime.datetime(2017, 2, 20, 10, 6, 35),
datetime.datetime(2017, 3, 2, 10, 0, 20),
datetime.datetime(2017, 3, 12, 10, 7, 6),
datetime.datetime(2017, 3, 15, 10, 12, 14),
dt.datetime(2017, 1, 1, 10, 4, 7),
dt.datetime(2017, 1, 4, 10, 14, 5),
dt.datetime(2017, 1, 11, 10, 3, 51),
dt.datetime(2017, 1, 14, 10, 13, 46),
dt.datetime(2017, 1, 24, 10, 14, 7),
dt.datetime(2017, 2, 10, 10, 1, 32),
dt.datetime(2017, 2, 20, 10, 6, 35),
dt.datetime(2017, 3, 2, 10, 0, 20),
dt.datetime(2017, 3, 12, 10, 7, 6),
dt.datetime(2017, 3, 15, 10, 12, 14),
]

data = np.random.rand(10, 100, 100, 3)
Expand All @@ -430,7 +432,7 @@ def test_timestamp_consolidation() -> None:
good_timestamps = timestamps.copy()
del good_timestamps[0]
del good_timestamps[-1]
good_timestamps.append(datetime.datetime(2017, 12, 1))
good_timestamps.append(dt.datetime(2017, 12, 1))

removed_frames = eop.consolidate_timestamps(good_timestamps)

Expand All @@ -444,20 +446,48 @@ def test_timestamp_consolidation() -> None:
assert np.array_equal(mask_timeless, eop.mask_timeless["MASK_TIMELESS"])


def test_timestamps_deprecation():
eop = EOPatch(bbox=DUMMY_BBOX, timestamps=[datetime.datetime(1234, 5, 6)])

with pytest.warns(EODeprecationWarning):
assert eop.timestamp == [datetime.datetime(1234, 5, 6)]

with pytest.warns(EODeprecationWarning):
eop.timestamp = [datetime.datetime(4321, 5, 6)]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=EODeprecationWarning)
# so the warnings get ignored in pytest summary
assert eop.timestamp == [datetime.datetime(4321, 5, 6)]
assert eop.timestamp == eop.timestamps
@pytest.mark.parametrize(
    "method_input",
    [
        ["2017-04-08", "2017-09-17"],
        [1, 2],
        lambda x: dt.datetime(2017, 4, 1) < x < dt.datetime(2017, 10, 10),
    ],
)
def test_temporal_subset(method_input):
    """Every supported input kind (timestamp strings, indices, predicate) must select frames 1 and 2."""
    # The frames every parametrization is expected to keep.
    kept_frames = slice(1, 3)

    eop = generate_eopatch(
        {
            FeatureType.DATA: ["data1", "data2"],
            FeatureType.MASK_TIMELESS: ["mask_timeless"],
            FeatureType.SCALAR_TIMELESS: ["scalar_timeless"],
            FeatureType.MASK: ["mask"],
        },
        timestamps=[
            dt.datetime(2017, 1, 5),
            dt.datetime(2017, 4, 8),
            dt.datetime(2017, 9, 17),
            dt.datetime(2018, 1, 5),
            dt.datetime(2018, 12, 1),
        ],
    )

    # One vector row per timestamp, each with a distinct geometry.
    geometries = [eop.bbox.geometry.buffer(distance) for distance in range(5)]
    vector_data = GeoDataFrame({TIMESTAMP_COLUMN: eop.get_timestamps()}, geometry=geometries, crs=32633)
    eop.vector["vector"] = vector_data
    expected_timestamps = eop.timestamps[kept_frames]

    subset_eop = eop.temporal_subset(method_input)

    assert subset_eop.timestamps == expected_timestamps
    for ftype, fname in eop.get_features():
        if ftype.is_timeless():
            assert_feature_data_equal(eop[ftype, fname], subset_eop[ftype, fname])
        elif ftype.is_array():
            assert_feature_data_equal(eop[ftype, fname][kept_frames, ...], subset_eop[ftype, fname])

    assert_feature_data_equal(subset_eop.vector["vector"], vector_data[kept_frames])


def test_bbox_none_deprecation():
Expand Down
6 changes: 2 additions & 4 deletions tests/core/test_eodata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,7 @@ def test_partial_temporal_saving_into_existing(eopatch: EOPatch, temporal_select
io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES")
eopatch.save(**io_kwargs, use_zarr=True)

partial_patch = eopatch.copy(deep=True)
partial_patch.consolidate_timestamps(np.array(partial_patch.timestamps)[temporal_selection or ...])
partial_patch = eopatch.copy(deep=True).temporal_subset(np.array(eopatch.timestamps)[temporal_selection or ...])

partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2)
partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection=temporal_selection)
Expand Down Expand Up @@ -668,8 +667,7 @@ def test_partial_temporal_saving_infer(eopatch: EOPatch):
io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES")
eopatch.save(**io_kwargs, use_zarr=True)

partial_patch = eopatch.copy(deep=True)
partial_patch.consolidate_timestamps(eopatch.timestamps[1:2] + eopatch.timestamps[3:5])
partial_patch = eopatch.copy(deep=True).temporal_subset([1, 3, 4])

partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2)
partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection="infer")
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_eodata_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ def test_lazy_loading(test_eopatch_path):

def test_temporally_independent_merge(test_eopatch_path):
full_patch = EOPatch.load(test_eopatch_path)
part1, part2 = full_patch.copy(deep=True), full_patch.copy(deep=True)
part1.consolidate_timestamps(full_patch.get_timestamps()[:10])
part2.consolidate_timestamps(full_patch.get_timestamps()[10:])
part1 = full_patch.copy(deep=True).temporal_subset(range(10))
part2 = full_patch.copy(deep=True).temporal_subset(range(10, len(full_patch.timestamps)))

assert full_patch == merge_eopatches(part1, part2, time_dependent_op="concatenate")
3 changes: 1 addition & 2 deletions tests/features/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,4 @@ def small_ndvi_eopatch_fixture(example_eopatch: EOPatch):
ndvi = example_eopatch.data["NDVI"][:, :20, :20]
ndvi[np.isnan(ndvi)] = 0
example_eopatch.data["NDVI"] = ndvi
example_eopatch.consolidate_timestamps(example_eopatch.get_timestamps()[:10])
return example_eopatch
return example_eopatch.temporal_subset(range(10))

0 comments on commit 9c556e1

Please sign in to comment.