Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporal filtration method #756

Merged
merged 5 commits into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions eolearn/core/eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ def merge(
self, *eopatches, features=features, time_dependent_op=time_dependent_op, timeless_op=timeless_op
)

@deprecated_function(EODeprecationWarning, "Please use the method `temporal_subset` instead.")
def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.datetime]:
"""Removes all frames from the EOPatch with a date not found in the provided timestamps list.

Expand All @@ -750,6 +751,43 @@ def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.dateti

return remove_from_patch

def temporal_subset(
    self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
) -> EOPatch:
    """Returns an EOPatch that only contains data for the temporal subset corresponding to `timestamps`.

    For array-based data appropriate temporal slices are extracted. For vector data a filtration is performed.

    The timestamps of the returned patch are selected by position, in exactly the same way as the temporal
    arrays are indexed. Timestamps and data frames therefore stay aligned even if the indices are unsorted,
    repeated, or negative.

    :param timestamps: Parameter that defines the temporal subset. Can be a collection of timestamps, a
        collection of timestamp indices, or a callable that returns whether a timestamp should be kept.
    :return: A new EOPatch with the same bounding box, restricted to the selected timestamps.
    :raises IndexError: If an index does not refer to an existing timestamp.
    """
    timestamp_indices = self._parse_temporal_subset_input(timestamps)
    all_timestamps = self.get_timestamps()
    # Pick timestamps positionally (not via `i in timestamp_indices`) so that they always match the frames
    # extracted by `[timestamp_indices]` below, and so that the selection is linear in the number of indices.
    new_timestamps = [all_timestamps[idx] for idx in timestamp_indices]
    new_patch = EOPatch(bbox=self.bbox, timestamps=new_timestamps)

    for ftype, fname in self.get_features():
        if ftype.is_timeless() or ftype.is_meta():
            # Features without a temporal dimension are carried over unchanged (same object, not a copy).
            new_patch[ftype, fname] = self[ftype, fname]
        elif ftype.is_vector():
            # Temporal vector data is filtered by the value of its timestamp column.
            gdf: gpd.GeoDataFrame = self[ftype, fname]
            new_patch[ftype, fname] = gdf[gdf[TIMESTAMP_COLUMN].isin(new_timestamps)]
        else:
            # Remaining features are indexed along their first axis with the list of indices
            # (presumably numpy arrays with time as the leading dimension -- confirm against FeatureType).
            new_patch[ftype, fname] = self[ftype, fname][timestamp_indices]

    return new_patch

def _parse_temporal_subset_input(
    self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
) -> list[int]:
    """Parses input into a list of timestamp indices. Also adds implicit support for strings via `parse_time`.

    Three kinds of input are recognised, checked in this order:

    1. a callable, which is evaluated on every timestamp of the patch; indices of accepted timestamps are kept,
    2. an iterable consisting only of integers (including NumPy integer scalars), returned as plain `int`
       indices in the given order. Note that `bool` is an integer type in Python, so booleans are treated as
       the indices 0 and 1,
    3. any other iterable, whose items are parsed with `parse_time` and matched against the patch timestamps.
       Items that match no timestamp of the patch are ignored.

    :param timestamps: A callable, a collection of timestamp indices, or a collection of timestamps.
    :return: Indices into the timestamps of this patch.
    """
    # Function-scope import keeps this block self-contained; `Integral` also covers NumPy integer scalars,
    # which are not instances of the builtin `int`.
    from numbers import Integral

    if callable(timestamps):
        return [i for i, ts in enumerate(self.get_timestamps()) if timestamps(ts)]

    # Materialise once: the input may be a one-shot iterator and is inspected more than once.
    ts_or_idx = list(timestamps)
    if all(isinstance(item, Integral) for item in ts_or_idx):
        return [int(item) for item in ts_or_idx]

    # A set gives constant-time membership checks when scanning the timestamps of the patch.
    parsed_timestamps = {parse_time(ts, force_datetime=True) for ts in ts_or_idx}  # type: ignore[call-overload]
    return [i for i, ts in enumerate(self.get_timestamps()) if ts in parsed_timestamps]

def plot(
self,
feature: Feature,
Expand Down
88 changes: 59 additions & 29 deletions tests/core/test_eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
from __future__ import annotations

import datetime
import datetime as dt
import warnings
from typing import Any

Expand All @@ -17,6 +17,7 @@
from sentinelhub import CRS, BBox

from eolearn.core import EOPatch, FeatureType
from eolearn.core.constants import TIMESTAMP_COLUMN
from eolearn.core.eodata_io import FeatureIO
from eolearn.core.exceptions import EODeprecationWarning, TemporalDimensionWarning
from eolearn.core.types import Feature, FeaturesSpecification
Expand Down Expand Up @@ -96,7 +97,7 @@ def test_bbox_feature_type(invalid_bbox: Any) -> None:


@pytest.mark.parametrize(
"valid_entry", [["2018-01-01", "15.2.1992"], (datetime.datetime(2017, 1, 1, 10, 4, 7), datetime.date(2017, 1, 11))]
"valid_entry", [["2018-01-01", "15.2.1992"], (dt.datetime(2017, 1, 1, 10, 4, 7), dt.date(2017, 1, 11))]
)
def test_timestamp_valid_feature_type(valid_entry: Any) -> None:
eop = EOPatch(bbox=DUMMY_BBOX, timestamps=valid_entry)
Expand All @@ -106,9 +107,9 @@ def test_timestamp_valid_feature_type(valid_entry: Any) -> None:
@pytest.mark.parametrize(
"invalid_timestamps",
[
[datetime.datetime(2017, 1, 1, 10, 4, 7), None, datetime.datetime(2017, 1, 11, 10, 3, 51)],
[dt.datetime(2017, 1, 1, 10, 4, 7), None, dt.datetime(2017, 1, 11, 10, 3, 51)],
"something",
datetime.datetime(2017, 1, 1, 10, 4, 7),
dt.datetime(2017, 1, 1, 10, 4, 7),
],
)
def test_timestamps_invalid_feature_type(invalid_timestamps: Any) -> None:
Expand Down Expand Up @@ -398,19 +399,20 @@ def test_get_features(patch: EOPatch, expected_features: list[Feature]) -> None:
assert patch.get_features() == expected_features


@pytest.mark.filterwarnings("ignore::eolearn.core.exceptions.EODeprecationWarning")
def test_timestamp_consolidation() -> None:
# 10 frames
timestamps = [
datetime.datetime(2017, 1, 1, 10, 4, 7),
datetime.datetime(2017, 1, 4, 10, 14, 5),
datetime.datetime(2017, 1, 11, 10, 3, 51),
datetime.datetime(2017, 1, 14, 10, 13, 46),
datetime.datetime(2017, 1, 24, 10, 14, 7),
datetime.datetime(2017, 2, 10, 10, 1, 32),
datetime.datetime(2017, 2, 20, 10, 6, 35),
datetime.datetime(2017, 3, 2, 10, 0, 20),
datetime.datetime(2017, 3, 12, 10, 7, 6),
datetime.datetime(2017, 3, 15, 10, 12, 14),
dt.datetime(2017, 1, 1, 10, 4, 7),
dt.datetime(2017, 1, 4, 10, 14, 5),
dt.datetime(2017, 1, 11, 10, 3, 51),
dt.datetime(2017, 1, 14, 10, 13, 46),
dt.datetime(2017, 1, 24, 10, 14, 7),
dt.datetime(2017, 2, 10, 10, 1, 32),
dt.datetime(2017, 2, 20, 10, 6, 35),
dt.datetime(2017, 3, 2, 10, 0, 20),
dt.datetime(2017, 3, 12, 10, 7, 6),
dt.datetime(2017, 3, 15, 10, 12, 14),
]

data = np.random.rand(10, 100, 100, 3)
Expand All @@ -430,7 +432,7 @@ def test_timestamp_consolidation() -> None:
good_timestamps = timestamps.copy()
del good_timestamps[0]
del good_timestamps[-1]
good_timestamps.append(datetime.datetime(2017, 12, 1))
good_timestamps.append(dt.datetime(2017, 12, 1))

removed_frames = eop.consolidate_timestamps(good_timestamps)

Expand All @@ -444,20 +446,48 @@ def test_timestamp_consolidation() -> None:
assert np.array_equal(mask_timeless, eop.mask_timeless["MASK_TIMELESS"])


def test_timestamps_deprecation():
    """Reading or assigning the `timestamp` attribute must emit an `EODeprecationWarning`."""
    initial_date = datetime.datetime(1234, 5, 6)
    replaced_date = datetime.datetime(4321, 5, 6)
    eop = EOPatch(bbox=DUMMY_BBOX, timestamps=[initial_date])

    # Reading through the deprecated attribute warns and returns the stored timestamps.
    with pytest.warns(EODeprecationWarning):
        assert eop.timestamp == [initial_date]

    # Assigning through the deprecated attribute warns as well.
    with pytest.warns(EODeprecationWarning):
        eop.timestamp = [replaced_date]

    # Silence the warning here so it does not show up in the pytest summary.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=EODeprecationWarning)
        assert eop.timestamp == [replaced_date]
        assert eop.timestamp == eop.timestamps
@pytest.mark.parametrize(
    "method_input",
    [
        ["2017-04-08", "2017-09-17"],
        [1, 2],
        lambda ts: dt.datetime(2017, 4, 1) < ts < dt.datetime(2017, 10, 10),
    ],
)
def test_temporal_subset(method_input):
    """Every supported input kind (date strings, indices, predicate) must select the 2nd and 3rd frame."""
    feature_spec = {
        FeatureType.DATA: ["data1", "data2"],
        FeatureType.MASK_TIMELESS: ["mask_timeless"],
        FeatureType.SCALAR_TIMELESS: ["scalar_timeless"],
        FeatureType.MASK: ["mask"],
    }
    all_timestamps = [
        dt.datetime(2017, 1, 5),
        dt.datetime(2017, 4, 8),
        dt.datetime(2017, 9, 17),
        dt.datetime(2018, 1, 5),
        dt.datetime(2018, 12, 1),
    ]
    eop = generate_eopatch(feature_spec, timestamps=all_timestamps)

    # One geometry per timestamp, so the vector feature can be filtered by its timestamp column.
    geometries = [eop.bbox.geometry.buffer(i) for i in range(5)]
    vector_data = GeoDataFrame({TIMESTAMP_COLUMN: eop.get_timestamps()}, geometry=geometries, crs=32633)
    eop.vector["vector"] = vector_data
    subset_timestamps = eop.timestamps[1:3]

    subset_eop = eop.temporal_subset(method_input)

    assert subset_eop.timestamps == subset_timestamps
    for feature in eop.get_features():
        ftype = feature[0]
        if ftype.is_timeless():
            # Timeless features must be carried over untouched.
            assert_feature_data_equal(eop[feature], subset_eop[feature])
        elif ftype.is_array():
            # Temporal arrays must contain exactly the selected frames.
            assert_feature_data_equal(eop[feature][1:3, ...], subset_eop[feature])

    assert_feature_data_equal(subset_eop.vector["vector"], vector_data[1:3])


def test_bbox_none_deprecation():
Expand Down
6 changes: 2 additions & 4 deletions tests/core/test_eodata_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,7 @@ def test_partial_temporal_saving_into_existing(eopatch: EOPatch, temporal_select
io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES")
eopatch.save(**io_kwargs, use_zarr=True)

partial_patch = eopatch.copy(deep=True)
partial_patch.consolidate_timestamps(np.array(partial_patch.timestamps)[temporal_selection or ...])
partial_patch = eopatch.copy(deep=True).temporal_subset(np.array(eopatch.timestamps)[temporal_selection or ...])

partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2)
partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection=temporal_selection)
Expand Down Expand Up @@ -668,8 +667,7 @@ def test_partial_temporal_saving_infer(eopatch: EOPatch):
io_kwargs = dict(path="patch-folder", filesystem=temp_fs, overwrite_permission="OVERWRITE_FEATURES")
eopatch.save(**io_kwargs, use_zarr=True)

partial_patch = eopatch.copy(deep=True)
partial_patch.consolidate_timestamps(eopatch.timestamps[1:2] + eopatch.timestamps[3:5])
partial_patch = eopatch.copy(deep=True).temporal_subset([1, 3, 4])

partial_patch.data["data"] = np.full_like(partial_patch.data["data"], 2)
partial_patch.save(**io_kwargs, use_zarr=True, temporal_selection="infer")
Expand Down
5 changes: 2 additions & 3 deletions tests/core/test_eodata_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ def test_lazy_loading(test_eopatch_path):

def test_temporally_independent_merge(test_eopatch_path):
    """Splitting a patch into two temporal parts and concatenating them must reproduce the original patch."""
    full_patch = EOPatch.load(test_eopatch_path)
    # Build each part once via `temporal_subset`; the deprecated `consolidate_timestamps` calls only produced
    # intermediate patches that were immediately overwritten.
    part1 = full_patch.copy(deep=True).temporal_subset(range(10))
    part2 = full_patch.copy(deep=True).temporal_subset(range(10, len(full_patch.timestamps)))

    assert full_patch == merge_eopatches(part1, part2, time_dependent_op="concatenate")
3 changes: 1 addition & 2 deletions tests/features/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,4 @@ def small_ndvi_eopatch_fixture(example_eopatch: EOPatch):
ndvi = example_eopatch.data["NDVI"][:, :20, :20]
ndvi[np.isnan(ndvi)] = 0
example_eopatch.data["NDVI"] = ndvi
example_eopatch.consolidate_timestamps(example_eopatch.get_timestamps()[:10])
return example_eopatch
return example_eopatch.temporal_subset(range(10))