Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporal subset task #757

Merged
merged 2 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions eolearn/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
RemoveFeatureTask,
RenameFeatureTask,
SaveTask,
TemporalSubsetTask,
ZipFeatureTask,
)
from .eodata import EOPatch
Expand Down
24 changes: 24 additions & 0 deletions eolearn/core/core_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,30 @@ def execute(self, src_eopatch: EOPatch, dst_eopatch: EOPatch) -> EOPatch:
return dst_eopatch


class TemporalSubsetTask(EOTask):
"""Extracts a temporal subset of the EOPatch."""

def __init__(
self, timestamps: None | list[dt.datetime] | list[int] | Callable[[list[dt.datetime]], Iterable[bool]] = None
):
"""
:param timestamps: Input for the `temporal_subset` method of EOPatch. Can also be provided in execution
arguments. Value in execution arguments takes precedence.
"""
self.timestamps = timestamps

def execute(
self,
eopatch: EOPatch,
*,
timestamps: None | list[dt.datetime] | list[int] | Callable[[list[dt.datetime]], Iterable[bool]] = None,
) -> EOPatch:
timestamps = timestamps if timestamps is not None else self.timestamps
if timestamps is None:
raise ValueError("Value for `timestamps` must be provided on initialization or as an execution argument.")
return eopatch.temporal_subset(timestamps)


class MapFeatureTask(EOTask):
"""Applies a function to each feature in input_features of a patch and stores the results in a set of
output_features.
Expand Down
10 changes: 6 additions & 4 deletions eolearn/core/eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,14 +752,15 @@ def consolidate_timestamps(self, timestamps: list[dt.datetime]) -> set[dt.dateti
return remove_from_patch

def temporal_subset(
self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[list[dt.datetime]], Iterable[bool]]
) -> EOPatch:
"""Returns an EOPatch that only contains data for the temporal subset corresponding to `timestamps`.

For array-based data appropriate temporal slices are extracted. For vector data a filtration is performed.

:param timestamps: Parameter that defines the temporal subset. Can be a collection of timestamps, a
collection of timestamp indices, or a callable that returns whether a timestamp should be kept.
collection of timestamp indices. It is possible to also provide a callable that maps a list of timestamps
to a sequence of booleans, which determine if a given timestamp is included in the subset or not.
"""
timestamp_indices = self._parse_temporal_subset_input(timestamps)
new_timestamps = [ts for i, ts in enumerate(self.get_timestamps()) if i in timestamp_indices]
Expand All @@ -777,11 +778,12 @@ def temporal_subset(
return new_patch

def _parse_temporal_subset_input(
self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[dt.datetime], bool]
self, timestamps: Iterable[dt.datetime] | Iterable[int] | Callable[[list[dt.datetime]], Iterable[bool]]
) -> list[int]:
"""Parses input into a list of timestamp indices. Also adds implicit support for strings via `parse_time`."""
if callable(timestamps):
return [i for i, ts in enumerate(self.get_timestamps()) if timestamps(ts)]
accepted_timestamps = timestamps(self.get_timestamps())
return [i for i, accepted in enumerate(accepted_timestamps) if accepted]
ts_or_idx = list(timestamps)
if all(isinstance(ts, int) for ts in ts_or_idx):
return ts_or_idx # type: ignore[return-value]
Expand Down
22 changes: 22 additions & 0 deletions tests/core/test_core_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
RemoveFeatureTask,
RenameFeatureTask,
SaveTask,
TemporalSubsetTask,
ZipFeatureTask,
)
from eolearn.core.core_tasks import ExplodeBandsTask
Expand Down Expand Up @@ -277,6 +278,27 @@ def test_merge_features(axis: int, features_to_merge: list[Feature], feature: Fe
assert_array_equal(patch[feature], expected)


@pytest.mark.parametrize(
"timestamps",
[
[1, 2, 4],
[datetime(2019, 4, 2), datetime(2019, 7, 2), datetime(2019, 12, 31)],
lambda _: [False, True, True, False, True],
],
)
def test_temporal_subset_task(patch: EOPatch, timestamps):
"""The correctness is tested in the method test, so we focus on testing that parameters are passed correctly."""
task_init = TemporalSubsetTask(timestamps)
result_init = task_init.execute(patch)

task_exec = TemporalSubsetTask()
result_exec = task_exec.execute(patch, timestamps=timestamps)

assert result_init == result_exec
assert len(result_exec.get_timestamps()) == 3
assert_array_equal(result_exec.data["bands"], patch.data["bands"][[1, 2, 4]])


@pytest.mark.parametrize(
("input_features", "output_feature", "zip_function", "kwargs"),
[
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_eodata.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def test_timestamp_consolidation() -> None:
[
["2017-04-08", "2017-09-17"],
[1, 2],
lambda x: (x > dt.datetime(2017, 4, 1) and x < dt.datetime(2017, 10, 10)),
lambda dates: (dt.datetime(2017, 4, 1) < x < dt.datetime(2017, 10, 10) for x in dates),
],
)
def test_temporal_subset(method_input):
Expand Down