Commit: Update datastore
Marco Mancini committed Jan 26, 2024
1 parent 81654a4 commit 95bcb82
Showing 4 changed files with 54 additions and 53 deletions.
58 changes: 38 additions & 20 deletions datastore/datastore/datastore.py
@@ -50,7 +50,7 @@ def __init__(self) -> None:

@log_execution_time(_LOG)
def get_cached_product_or_read(
self, dataset_id: str, product_id: str, query: GeoQuery | None = None
self, dataset_id: str, product_id: str
) -> DataCube | Dataset:
"""Get product from the cache instead of loading files indicated in
the catalog if `metadata_caching` set to `True`.
@@ -81,21 +81,19 @@ def get_cached_product_or_read(
)
return self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][
product_id
].get(geoquery=query, compute=False).read_chunked()
].read_chunked()
return self.cache[dataset_id][product_id]

@log_execution_time(_LOG)
def _load_cache(self, datasets: list[str] | None = None):
if self.cache is None or datasets is None:
def _load_cache(self):
if self.cache is None:
self.cache = {}
datasets = self.dataset_list()

for i, dataset_id in enumerate(datasets):
for i, dataset_id in enumerate(self.dataset_list()):
self._LOG.info(
"loading cache for `%s` (%d/%d)",
dataset_id,
i + 1,
len(datasets),
len(self.dataset_list()),
)
self.cache[dataset_id] = {}
for product_id in self.product_list(dataset_id):
@@ -119,7 +117,7 @@ def _load_cache(self, datasets: list[str] | None = None):
dataset_id,
product_id,
exc_info=True,
)
)

@log_execution_time(_LOG)
def dataset_list(self) -> list:
@@ -358,9 +356,9 @@ def query(
self._LOG.debug("loading product...")
kube = self.catalog(CACHE_DIR=self.cache_dir)[dataset_id][
product_id
].get(geoquery=geoquery, compute=compute).process_with_query()
].read_chunked()
self._LOG.debug("original kube len: %s", len(kube))
return kube
return Datastore._process_query(kube, geoquery, compute)

@log_execution_time(_LOG)
def estimate(
@@ -391,8 +389,7 @@ def estimate(
# NOTE: we always use catalog directly and single product cache
self._LOG.debug("loading product...")
# NOTE: for estimation we use cached products
kube = self.get_cached_product_or_read(dataset_id, product_id,
query=query)
kube = self.get_cached_product_or_read(dataset_id, product_id)
self._LOG.debug("original kube len: %s", len(kube))
return Datastore._process_query(kube, geoquery, False).nbytes
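Note: with this hunk, `estimate` no longer pushes the query into the cached read; it reads the cached product as-is and sizes the output of `_process_query(kube, geoquery, False)`. A rough calling sketch; the keyword names, ids, and import paths below are assumptions, not taken from this diff:

    from datastore.datastore import Datastore  # import path is an assumption
    from geoquery.geoquery import GeoQuery     # import path is an assumption

    ds = Datastore()
    # dataset/product ids and the `query=` keyword are placeholders/assumptions
    size_bytes = ds.estimate(
        dataset_id="my-dataset",
        product_id="my-product",
        query=GeoQuery(filters={"experiment": "hist"}),
    )
    print(f"estimated result size: {size_bytes} bytes")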

@@ -422,10 +419,7 @@ def is_product_valid_for_role(
def _process_query(kube, query: GeoQuery, compute: None | bool = False):
if isinstance(kube, Dataset):
Datastore._LOG.debug("filtering with: %s", query.filters)
try:
kube = kube.filter(**query.filters)
except ValueError as err:
Datastore._LOG.warning("could not filter by one of the key: %s", err)
kube = kube.filter(**query.filters)
Datastore._LOG.debug("resulting kube len: %s", len(kube))
if isinstance(kube, Delayed) and compute:
kube = kube.compute()
@@ -440,9 +434,33 @@ def _process_query(kube, query: GeoQuery, compute: None | bool = False):
kube = kube.locations(**query.location)
if query.time:
Datastore._LOG.debug("subsetting by time...")
kube = kube.sel(time=query.time)
kube = kube.sel(
**{
"time": Datastore._maybe_convert_dict_slice_to_slice(
query.time
)
}
)
if query.vertical:
Datastore._LOG.debug("subsetting by vertical...")
method = None if isinstance(query.vertical, slice) else "nearest"
kube = kube.sel(vertical=query.vertical, method=method)
if isinstance(
vertical := Datastore._maybe_convert_dict_slice_to_slice(
query.vertical
),
slice,
):
method = None
else:
method = "nearest"
kube = kube.sel(vertical=vertical, method=method)
return kube.compute() if compute else kube

@staticmethod
def _maybe_convert_dict_slice_to_slice(dict_vals):
if "start" in dict_vals or "stop" in dict_vals:
return slice(
dict_vals.get("start"),
dict_vals.get("stop"),
dict_vals.get("step"),
)
return dict_vals
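Note: the new `_maybe_convert_dict_slice_to_slice` helper lets `time` and `vertical` arrive as `{"start": ..., "stop": ..., "step": ...}` mappings and turns them into `slice` objects before `.sel(...)`; mappings without `start`/`stop` pass through unchanged, and for `vertical` a non-slice value falls back to `method="nearest"`. A standalone restatement of that conversion with a worked example (not an import from the package):

    def maybe_convert_dict_slice_to_slice(dict_vals):
        # Same logic as the helper added above: "start"/"stop" (and optional
        # "step") keys become a slice; any other mapping is returned unchanged.
        if "start" in dict_vals or "stop" in dict_vals:
            return slice(
                dict_vals.get("start"),
                dict_vals.get("stop"),
                dict_vals.get("step"),
            )
        return dict_vals

    assert maybe_convert_dict_slice_to_slice(
        {"start": "2023-01-01", "stop": "2023-01-31"}
    ) == slice("2023-01-01", "2023-01-31", None)
    assert maybe_convert_dict_slice_to_slice({"year": 2023}) == {"year": 2023}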
33 changes: 10 additions & 23 deletions datastore/dbmanager/dbmanager.py
@@ -43,7 +43,6 @@ class RequestStatus(Enum_):
"""Status of the Request"""

PENDING = auto()
QUEUED = auto()
RUNNING = auto()
DONE = auto()
FAILED = auto()
@@ -86,7 +85,7 @@ class User(Base):
String(255), nullable=False, unique=True, default=generate_key
)
contact_name = Column(String(255))
requests = relationship("Request", lazy="dynamic")
requests = relationship("Request")
roles = relationship("Role", secondary=association_table, lazy="selectin")


@@ -97,7 +96,7 @@ class Worker(Base):
host = Column(String(255))
dask_scheduler_port = Column(Integer)
dask_dashboard_address = Column(String(10))
created_on = Column(DateTime, default=datetime.now)
created_on = Column(DateTime, nullable=False)


class Request(Base):
@@ -113,8 +112,8 @@ class Request(Base):
product = Column(String(255))
query = Column(JSON())
estimate_size_bytes = Column(Integer)
created_on = Column(DateTime, default=datetime.now)
last_update = Column(DateTime, default=datetime.now, onupdate=datetime.now)
created_on = Column(DateTime, nullable=False)
last_update = Column(DateTime)
fail_reason = Column(String(1000))
download = relationship("Download", uselist=False, lazy="selectin")

@@ -129,7 +128,7 @@ class Download(Base):
storage_id = Column(Integer, ForeignKey("storages.storage_id"))
location_path = Column(String(255))
size_bytes = Column(Integer)
created_on = Column(DateTime, default=datetime.now)
created_on = Column(DateTime, nullable=False)


class Storage(Base):
@@ -268,18 +267,16 @@ def create_request(
def update_request(
self,
request_id: int,
worker_id: int | None = None,
status: RequestStatus | None = None,
worker_id: int,
status: RequestStatus,
location_path: str = None,
size_bytes: int = None,
fail_reason: str = None,
) -> int:
with self.__session_maker() as session:
request = session.query(Request).get(request_id)
if status:
request.status = status
if worker_id:
request.worker_id = worker_id
request.status = status
request.worker_id = worker_id
request.last_update = datetime.utcnow()
request.fail_reason = fail_reason
session.commit()
@@ -308,17 +305,7 @@ def get_request_status_and_reason(

def get_requests_for_user_id(self, user_id) -> list[Request]:
with self.__session_maker() as session:
return session.query(User).get(user_id).requests.all()

def get_requests_for_user_id_and_status(
self, user_id, status: RequestStatus | tuple[RequestStatus]
) -> list[Request]:
if isinstance(status, RequestStatus):
status = (status,)
with self.__session_maker() as session:
return session.get(User, user_id).requests.filter(
Request.status.in_(status)
)
return session.query(User).get(user_id).requests

def get_download_details_for_request_id(self, request_id) -> Download:
with self.__session_maker() as session:
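Note: `update_request` now takes `worker_id` and `status` as required parameters and assigns them unconditionally, and `RequestStatus.QUEUED` has been removed. A minimal calling sketch; the `DBManager` class name and its constructor are assumptions (not visible in this diff):

    from dbmanager.dbmanager import DBManager, RequestStatus  # class name assumed

    db = DBManager()
    # Mark the request as picked up by worker 1.
    db.update_request(request_id=42, worker_id=1, status=RequestStatus.RUNNING)
    # On failure, the reason is still optional.
    db.update_request(
        request_id=42,
        worker_id=1,
        status=RequestStatus.FAILED,
        fail_reason="worker crashed",
    )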
1 change: 0 additions & 1 deletion datastore/geoquery/geoquery.py
@@ -15,7 +15,6 @@ class GeoQuery(BaseModel, extra="allow"):
vertical: Optional[Union[float, List[float], Dict[str, float]]]
filters: Optional[Dict]
format: Optional[str]
format_args: Optional[Dict]

# TODO: Check if we are going to allow the vertical coordinates inside both
# `area`/`location` nad `vertical`
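Note: with `format_args` gone, a `GeoQuery` carries only the remaining optional fields. A hedged construction example; fields other than `vertical`, `filters`, and `format` (e.g. `time`) are inferred from `Datastore._process_query` and may not match the model exactly:

    from geoquery.geoquery import GeoQuery  # import path is an assumption

    query = GeoQuery(
        time={"start": "2023-01-01", "stop": "2023-01-31"},  # becomes a slice downstream
        vertical=[500.0, 850.0],         # non-slice values select with method="nearest"
        filters={"experiment": "hist"},  # placeholder filter key
        format="netcdf",                 # the separate format_args field was dropped
    )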
15 changes: 6 additions & 9 deletions datastore/workflow/workflow.py
@@ -108,9 +108,9 @@ def _subset(kube: DataCube | None = None) -> DataCube:
return Datastore().query(
dataset_id=dataset_id,
product_id=product_id,
query=(
query if isinstance(query, GeoQuery) else GeoQuery(**query)
),
query=query
if isinstance(query, GeoQuery)
else GeoQuery(**query),
compute=False,
)

@@ -153,21 +153,18 @@ def _average(kube: DataCube | None = None) -> DataCube:
)
self._add_computational_node(task)
return self

def to_regular(
self, id: Hashable, *, dependencies: list[Hashable]
) -> "Workflow":
def _to_regular(kube: DataCube | None = None) -> DataCube:
assert (
kube is not None
), "`kube` cannot be `None` for `to_regular``"
assert kube is not None, "`kube` cannot be `None` for `to_regular``"
return kube.to_regular()

task = _WorkflowTask(
id=id, operator=_to_regular, dependencies=dependencies
)
self._add_computational_node(task)
return self
return self

def add_task(
self,
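Note: in `_subset` the query is now coerced inline (`query if isinstance(query, GeoQuery) else GeoQuery(**query)`), so workflow callers can pass either a `GeoQuery` or a plain mapping, and `to_regular` still returns `self` so calls can be chained. A hedged usage sketch; the import path and the existence of a prior task with id "subset" are assumptions:

    from workflow.workflow import Workflow  # import path is an assumption

    wf = Workflow()
    # ... a subsetting task with id "subset" is assumed to have been added here ...
    wf.to_regular(id="regrid", dependencies=["subset"])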
