Skip to content

Commit

Permalink
fix: Refresh sql registry in background (#42)
Browse files Browse the repository at this point in the history
* fix: Refresh sql registry in background and project creation moved to init

* enabled http registry to use cache

* stopping tests for sqlite as it's not a production-grade database

---------

Co-authored-by: Bhargav Dodla <[email protected]>
  • Loading branch information
EXPEbdodla and Bhargav Dodla authored Sep 30, 2023
1 parent aaab905 commit a8c8dbf
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 50 deletions.
13 changes: 11 additions & 2 deletions sdk/python/feast/infra/registry/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __init__(
headers={"Content-Type": "application/json"},
)
self.project = project
self.apply_project(self.project)
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
self.cached_registry_proto_created = datetime.utcnow()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
Expand All @@ -102,6 +102,15 @@ def _send_request(self, method: str, url: str, params=None, data=None):
except Exception as exception:
self._handle_exception(exception)

def apply_project(self, project: str, commit: bool = True) -> ProjectMetadataModel:
    """Register ``project`` with the remote registry server.

    Issues ``PUT {base_url}/projects`` with ``project`` and ``commit`` as
    query parameters and parses the response into a ``ProjectMetadataModel``.

    Any exception raised while sending the request or parsing the response
    is handed to ``self._handle_exception``; if that handler returns instead
    of raising, this method falls through and returns ``None``.
    """
    try:
        payload = self._send_request(
            "PUT",
            f"{self.base_url}/projects",
            params={"project": project, "commit": commit},
        )
        return ProjectMetadataModel.parse_obj(payload)
    except Exception as exception:
        self._handle_exception(exception)

def apply_entity(self, entity: Entity, project: str, commit: bool = True):
try:
url = f"{self.base_url}/projects/{project}/entities"
Expand Down Expand Up @@ -608,7 +617,7 @@ def proto(self) -> RegistryProto:
(self.list_validation_references, r.validation_references),
(self.list_project_metadata, r.project_metadata),
]:
objs: List[Any] = lister(project, False) # type: ignore
objs: List[Any] = lister(project, True) # type: ignore
if objs:
obj_protos = [obj.to_proto() for obj in objs]
for obj_proto in obj_protos:
Expand Down
118 changes: 87 additions & 31 deletions sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import concurrent.futures
import logging
import threading
import time
import uuid
from datetime import datetime, timedelta
from enum import Enum
Expand Down Expand Up @@ -184,6 +187,9 @@ class FeastMetadataKeys(Enum):

logger = logging.getLogger(__name__)

CACHE_REFRESH_THRESHOLD_SECONDS = 300
MAX_WORKERS = 5


class SqlRegistryConfig(RegistryConfig):
registry_type: StrictStr = "sql"
Expand All @@ -198,8 +204,8 @@ class SqlRegistry(BaseRegistry):
def __init__(
self,
registry_config: Optional[Union[RegistryConfig, SqlRegistryConfig]],
project: str,
repo_path: Optional[Path],
project: str = None,
repo_path: Optional[Path] = None,
):
assert registry_config is not None, "SqlRegistry needs a valid registry_config"
# pool_recycle will recycle connections after the given number of seconds has passed
Expand All @@ -208,16 +214,33 @@ def __init__(
registry_config.path, echo=False, pool_recycle=3600
)
metadata.create_all(self.engine)
self.project = project
if project is not None:
self.create_project_if_not_exists(self.project)
self.cached_registry_proto = self.proto()
proto_registry_utils.init_project_metadata(self.cached_registry_proto, project)
self.cached_registry_proto_created = datetime.utcnow()
self._refresh_lock = Lock()
self.cached_registry_proto_ttl = timedelta(
seconds=registry_config.cache_ttl_seconds
if registry_config.cache_ttl_seconds is not None
else 0
)
self.project = project
self.stop_thread = False
self.refresh_cache_thread = threading.Thread(target=self._refresh_cache)
self.refresh_cache_thread.start()

def _refresh_cache(self):
while not self.stop_thread:
self.refresh()
# Sleep for cached_registry_proto_ttl - 10 seconds
time.sleep(self.cached_registry_proto_ttl.total_seconds() - 10)

def close(self):
    """Ask the background refresh thread to stop and wait for it to exit.

    Safe to call repeatedly, and safe on an instance whose ``__init__``
    failed before the thread was created.
    """
    self.stop_thread = True
    worker = getattr(self, "refresh_cache_thread", None)
    if worker is None:
        # __init__ never got as far as creating the thread.
        return
    # Thread.join() raises RuntimeError for a thread that was never started
    # and for the calling thread itself, so skip both cases.
    if worker is threading.current_thread() or not worker.is_alive():
        return
    worker.join()

def __del__(self):
self.close()

def teardown(self):
for t in {
Expand All @@ -234,6 +257,8 @@ def teardown(self):
stmt = delete(t)
conn.execute(stmt)

self.close()

def refresh(self, project: Optional[str] = None):
if project:
project_metadata = proto_registry_utils.get_project_metadata(
Expand Down Expand Up @@ -269,11 +294,30 @@ def _refresh_cached_registry_if_necessary(self):
logger.info("Registry cache expired, so refreshing")
self.refresh()

def _check_if_registry_refreshed(self):
    """Warn when the cached registry looks stale; never refreshes it.

    A warning is logged when the cache has never been populated, or when it
    is past its ttl and more than ``CACHE_REFRESH_THRESHOLD_SECONDS`` have
    elapsed since the last refresh. A ttl of 0 means the cache never
    expires, so no staleness warning is produced in that case.
    """
    created = self.cached_registry_proto_created
    if self.cached_registry_proto is None or created is None:
        # There is no refresh timestamp to measure an age from; subtracting
        # None from a datetime would raise TypeError.
        logger.warning("Registry cache has not been populated yet")
        return

    ttl = self.cached_registry_proto_ttl
    if ttl.total_seconds() <= 0:  # 0 ttl means infinity
        return

    now = datetime.utcnow()
    if now <= created + ttl:
        return

    seconds_since_last_refresh = (now - created).total_seconds()
    if seconds_since_last_refresh > CACHE_REFRESH_THRESHOLD_SECONDS:
        logger.warning(
            f"Cache is stale: {seconds_since_last_refresh} seconds since last refresh"
        )

def get_stream_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_stream_feature_view(
self.cached_registry_proto, name, project
)
Expand All @@ -292,7 +336,7 @@ def list_stream_feature_views(
self, project: str, allow_cache: bool = False
) -> List[StreamFeatureView]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_stream_feature_views(
self.cached_registry_proto, project
)
Expand All @@ -304,6 +348,10 @@ def list_stream_feature_views(
"feature_view_proto",
)

def apply_project(self, project: str, commit: bool = True) -> ProjectMetadataModel:
    """Ensure ``project`` exists in the registry and return its metadata.

    Args:
        project: Name of the project to create if it is missing.
        commit: Accepted for signature parity with the other ``apply_*``
            methods; this implementation does not read it. Defaults to
            ``True`` like those methods.

    Returns:
        Whatever ``self.get_project_metadata(project)`` returns after the
        project has been created if needed.
    """
    self.create_project_if_not_exists(project)
    return self.get_project_metadata(project)

def apply_entity(self, entity: Entity, project: str, commit: bool = True):
return self._apply_object(
table=entities,
Expand All @@ -315,7 +363,7 @@ def apply_entity(self, entity: Entity, project: str, commit: bool = True):

def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_entity(
self.cached_registry_proto, name, project
)
Expand All @@ -334,7 +382,7 @@ def get_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureView:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_feature_view(
self.cached_registry_proto, name, project
)
Expand All @@ -353,7 +401,7 @@ def get_on_demand_feature_view(
self, name: str, project: str, allow_cache: bool = False
) -> OnDemandFeatureView:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_on_demand_feature_view(
self.cached_registry_proto, name, project
)
Expand All @@ -372,7 +420,7 @@ def get_request_feature_view(
self, name: str, project: str, allow_cache: bool = False
):
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_request_feature_view(
self.cached_registry_proto, name, project
)
Expand All @@ -391,7 +439,7 @@ def get_feature_service(
self, name: str, project: str, allow_cache: bool = False
) -> FeatureService:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_feature_service(
self.cached_registry_proto, name, project
)
Expand All @@ -410,7 +458,7 @@ def get_saved_dataset(
self, name: str, project: str, allow_cache: bool = False
) -> SavedDataset:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_saved_dataset(
self.cached_registry_proto, name, project
)
Expand All @@ -429,7 +477,7 @@ def get_validation_reference(
self, name: str, project: str, allow_cache: bool = False
) -> ValidationReference:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_validation_reference(
self.cached_registry_proto, name, project
)
Expand All @@ -448,7 +496,7 @@ def list_validation_references(
self, project: str, allow_cache: bool = False
) -> List[ValidationReference]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_validation_references(
self.cached_registry_proto, project
)
Expand All @@ -462,7 +510,7 @@ def list_validation_references(

def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_entities(
self.cached_registry_proto, project
)
Expand Down Expand Up @@ -502,7 +550,7 @@ def get_data_source(
self, name: str, project: str, allow_cache: bool = False
) -> DataSource:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.get_data_source(
self.cached_registry_proto, name, project
)
Expand All @@ -521,7 +569,7 @@ def list_data_sources(
self, project: str, allow_cache: bool = False
) -> List[DataSource]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_data_sources(
self.cached_registry_proto, project
)
Expand Down Expand Up @@ -570,7 +618,7 @@ def list_feature_services(
self, project: str, allow_cache: bool = False
) -> List[FeatureService]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_feature_services(
self.cached_registry_proto, project
)
Expand All @@ -586,7 +634,7 @@ def list_feature_views(
self, project: str, allow_cache: bool = False
) -> List[FeatureView]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_feature_views(
self.cached_registry_proto, project
)
Expand All @@ -598,7 +646,7 @@ def list_saved_datasets(
self, project: str, allow_cache: bool = False
) -> List[SavedDataset]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_saved_datasets(
self.cached_registry_proto, project
)
Expand All @@ -614,7 +662,7 @@ def list_request_feature_views(
self, project: str, allow_cache: bool = False
) -> List[RequestFeatureView]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_request_feature_views(
self.cached_registry_proto, project
)
Expand All @@ -630,7 +678,7 @@ def list_on_demand_feature_views(
self, project: str, allow_cache: bool = False
) -> List[OnDemandFeatureView]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_on_demand_feature_views(
self.cached_registry_proto, project
)
Expand All @@ -646,7 +694,7 @@ def list_project_metadata(
self, project: str, allow_cache: bool = False
) -> List[ProjectMetadata]:
if allow_cache:
self._refresh_cached_registry_if_necessary()
self._check_if_registry_refreshed()
return proto_registry_utils.list_project_metadata(
self.cached_registry_proto, project
)
Expand Down Expand Up @@ -836,8 +884,9 @@ def get_user_metadata(
def proto(self) -> RegistryProto:
r = RegistryProto()
last_updated_timestamps = []
projects = self._get_all_projects()
for project in projects:

def process_project(project):
nonlocal r, last_updated_timestamps
for lister, registry_proto_field in [
(self.list_entities, r.entities),
(self.list_feature_views, r.feature_views),
Expand Down Expand Up @@ -865,6 +914,17 @@ def proto(self) -> RegistryProto:
r.infra.CopyFrom(self.get_infra(project).to_proto())
last_updated_timestamps.append(self._get_last_updated_metadata(project))

if self.project is None:
projects = self._get_all_projects()
else:
projects = set([self.project])

# Use a ThreadPoolExecutor to process projects concurrently
with concurrent.futures.ThreadPoolExecutor(
max_workers=MAX_WORKERS
) as executor: # Adjust max_workers as needed
executor.map(process_project, projects)

if last_updated_timestamps:
r.last_updated.FromDatetime(max(last_updated_timestamps))

Expand All @@ -883,8 +943,6 @@ def _apply_object(
proto_field_name: str,
name: Optional[str] = None,
):
self._maybe_init_project_metadata(project)

name = name or (obj.name if hasattr(obj, "name") else None)
assert name, f"name needs to be provided for {obj}"

Expand Down Expand Up @@ -932,7 +990,7 @@ def _apply_object(

self._set_last_updated_metadata(update_datetime, project)

def _maybe_init_project_metadata(self, project):
def create_project_if_not_exists(self, project):
# Initialize project metadata if needed
with self.engine.connect() as conn:
update_datetime = datetime.utcnow()
Expand All @@ -955,6 +1013,7 @@ def _maybe_init_project_metadata(self, project):
insert_stmt = insert(feast_metadata).values(values)
conn.execute(insert_stmt)
usage.set_current_project_uuid(new_project_uuid)
self._set_last_updated_metadata(update_datetime, project)

def _delete_object(
self,
Expand Down Expand Up @@ -986,8 +1045,6 @@ def _get_object(
proto_field_name: str,
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)

with self.engine.connect() as conn:
stmt = select(table).where(
getattr(table.c, id_field_name) == name, table.c.project_id == project
Expand All @@ -1009,7 +1066,6 @@ def _list_objects(
python_class: Any,
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with self.engine.connect() as conn:
stmt = select(table).where(table.c.project_id == project)
rows = conn.execute(stmt).all()
Expand Down
Loading

0 comments on commit a8c8dbf

Please sign in to comment.