Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix hindcasts #49

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/anemoi/utils/hindcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def __init__(self, reference_dates, years=20):

self.reference_dates = reference_dates

if isinstance(years, list):
self.years = years
else:
self.years = range(1, years + 1)
assert isinstance(years, int), f"years must be an integer, got {years}"
assert years > 0, f"years must be greater than 0, got {years}"
self.years = years

def __iter__(self):
for reference_date in self.reference_dates:
for year in self.years:
if reference_date.month == 2 and reference_date.day == 29:
date = datetime.datetime(reference_date.year - year, 2, 28)
else:
date = datetime.datetime(reference_date.year - year, reference_date.month, reference_date.day)
year, month, day = reference_date.year, reference_date.month, reference_date.day
if (month, day) == (2, 29):
day = 28

for i in range(1, self.years + 1):
date = datetime.datetime(year - i, month, day)
yield (date, reference_date)
25 changes: 21 additions & 4 deletions src/anemoi/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def __call__(self, factory):
return factory


_BY_KIND = {}


class Registry:
"""A registry of factories"""

Expand All @@ -39,6 +42,11 @@ def __init__(self, package, key="_type"):
self.registered = {}
self.kind = package.split(".")[-1]
self.key = key
_BY_KIND[self.kind] = self

@classmethod
def lookup_kind(cls, kind: str):
return _BY_KIND.get(kind)

def register(self, name: str, factory: callable = None):

Expand All @@ -47,14 +55,19 @@ def register(self, name: str, factory: callable = None):

self.registered[name] = factory

# def registered(self, name: str):
# return name in self.registered

def _load(self, file):
name, _ = os.path.splitext(file)
try:
importlib.import_module(f".{name}", package=self.package)
except Exception:
LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)

def lookup(self, name: str) -> callable:
def lookup(self, name: str, *, return_none=False) -> callable:

# print('✅✅✅✅✅✅✅✅✅✅✅✅✅', name, self.registered)
if name in self.registered:
return self.registered[name]

Expand Down Expand Up @@ -87,8 +100,12 @@ def lookup(self, name: str) -> callable:
self.registered[name] = entry_point.load()

if name not in self.registered:
if return_none:
return None

for e in self.registered:
LOG.info(f"Registered: {e}")

raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]
Expand All @@ -97,8 +114,8 @@ def create(self, name: str, *args, **kwargs):
factory = self.lookup(name)
return factory(*args, **kwargs)

def __call__(self, name: str, *args, **kwargs):
return self.create(name, *args, **kwargs)
# def __call__(self, name: str, *args, **kwargs):
# return self.create(name, *args, **kwargs)

def from_config(self, config, *args, **kwargs):
if isinstance(config, str):
Expand All @@ -125,5 +142,5 @@ def from_config(self, config, *args, **kwargs):
return self.create(key, *args, value, **kwargs)

raise ValueError(
f"Entry '{config}' must either be a string, a dictionray with a single entry, or a dictionary with a '{self.key}' key"
f"Entry '{config}' must either be a string, a dictionary with a single entry, or a dictionary with a '{self.key}' key"
)
30 changes: 0 additions & 30 deletions tests/test_dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,36 +83,6 @@ def test_date_hindcast_1():
assert len(list(d)) == 60


def test_date_hindcast_2():
d = _(
"""
- name: hindcast
reference_dates:
start: 2023-01-01
end: 2023-01-03
frequency: 24
years: [2018, 2019, 2020, 2021]
"""
)
assert len(list(d)) == 12


def test_date_hindcast_3():
d = _(
"""
- name: hindcast
reference_dates:
start: 2022-12-25 00:00:00
end: 2022-12-31 12:00:00
frequency: 12h
day_of_week: tuesday
years: [2018, 2019, 2020, 2021]
"""
)
print(list(d))
assert len(list(d)) == 8


if __name__ == "__main__":
for name, obj in list(globals().items()):
if name.startswith("test_") and callable(obj):
Expand Down
Loading