diff --git a/src/amplitude_experiment/cohort/cohort_loader.py b/src/amplitude_experiment/cohort/cohort_loader.py index e5d6639..14f639e 100644 --- a/src/amplitude_experiment/cohort/cohort_loader.py +++ b/src/amplitude_experiment/cohort/cohort_loader.py @@ -6,7 +6,7 @@ from .cohort import Cohort from .cohort_download_api import CohortDownloadApi from .cohort_storage import CohortStorage -from ..exception import CohortUpdateException +from ..exception import CohortsDownloadException class CohortLoader: @@ -30,39 +30,37 @@ def load_cohort(self, cohort_id: str) -> Future: def _remove_job(self, cohort_id: str): if cohort_id in self.jobs: - del self.jobs[cohort_id] + with self.lock_jobs: + self.jobs.pop(cohort_id, None) def download_cohort(self, cohort_id: str) -> Cohort: cohort = self.cohort_storage.get_cohort(cohort_id) return self.cohort_download_api.get_cohort(cohort_id, cohort) - def update_stored_cohorts(self) -> Future: - def update_task(): + def download_cohorts(self, cohort_ids: Set[str]) -> Future: + def update_task(task_cohort_ids): errors = [] - cohort_ids = self.cohort_storage.get_cohort_ids() - futures = [] - with self.lock_jobs: - for cohort_id in cohort_ids: - future = self.load_cohort(cohort_id) - futures.append(future) + for cohort_id in task_cohort_ids: + future = self.load_cohort(cohort_id) + futures.append(future) for future in as_completed(futures): - cohort_id = next(c_id for c_id, f in self.jobs.items() if f == future) try: future.result() except Exception as e: - errors.append((cohort_id, e)) + cohort_id = next((c_id for c_id, f in self.jobs.items() if f == future), None) + if cohort_id: + errors.append((cohort_id, e)) if errors: - raise CohortUpdateException(errors) + raise CohortsDownloadException(errors) - return self.executor.submit(update_task) + return self.executor.submit(update_task, cohort_ids) def __load_cohort_internal(self, cohort_id): try: cohort = self.download_cohort(cohort_id) - # None is returned when cohort is not modified if cohort is not None: self.cohort_storage.put_cohort(cohort) except Exception as e: diff --git a/src/amplitude_experiment/deployment/deployment_runner.py b/src/amplitude_experiment/deployment/deployment_runner.py index e826aed..f61b657 100644 --- a/src/amplitude_experiment/deployment/deployment_runner.py +++ b/src/amplitude_experiment/deployment/deployment_runner.py @@ -8,7 +8,9 @@ from ..flag.flag_config_api import FlagConfigApi from ..flag.flag_config_storage import FlagConfigStorage from ..local.poller import Poller -from ..util.flag_config import get_all_cohort_ids_from_flag +from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags + +COHORT_POLLING_INTERVAL_MILLIS = 60000 class DeploymentRunner: @@ -29,7 +31,7 @@ def __init__( self.lock = threading.Lock() self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update) if self.cohort_loader: - self.cohort_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, + self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000, self.__update_cohorts) self.logger = logger @@ -71,15 +73,12 @@ def __update_flag_configs(self): existing_cohort_ids = self.cohort_storage.get_cohort_ids() cohort_ids_to_download = new_cohort_ids - existing_cohort_ids - cohort_download_errors = [] # download all new cohorts - for cohort_id in cohort_ids_to_download: - try: - self.cohort_loader.load_cohort(cohort_id).result() - except Exception as e: - cohort_download_errors.append((cohort_id, str(e))) - self.logger.warning(f"Download cohort {cohort_id} failed: {e}") + try: + self.cohort_loader.download_cohorts(cohort_ids_to_download).result() + except Exception as e: + self.logger.warning(f"Error while downloading cohorts: {e}") # get updated set of cohort ids updated_cohort_ids = self.cohort_storage.get_cohort_ids() @@ -97,8 +96,9 @@ def __update_flag_configs(self): self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.") def __update_cohorts(self): + cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values())) try: - self.cohort_loader.update_stored_cohorts().result() + self.cohort_loader.download_cohorts(cohort_ids).result() except Exception as e: self.logger.warning(f"Error while updating cohorts: {e}") diff --git a/src/amplitude_experiment/exception.py b/src/amplitude_experiment/exception.py index 7281a0e..92ed0e9 100644 --- a/src/amplitude_experiment/exception.py +++ b/src/amplitude_experiment/exception.py @@ -15,7 +15,7 @@ def __init__(self, status_code, message): self.status_code = status_code -class CohortUpdateException(Exception): +class CohortsDownloadException(Exception): def __init__(self, errors): self.errors = errors super().__init__(self.__str__())