Skip to content

Commit

Permalink
update cohort_sync_config fields: include polling and remove request …
Browse files Browse the repository at this point in the history
…delay, use enum for serverzone, update tests accordingly
  • Loading branch information
tyiuhc committed Aug 6, 2024
1 parent 1d974f1 commit 06e693e
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 20 deletions.
13 changes: 7 additions & 6 deletions src/amplitude_experiment/cohort/cohort_download_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from ..connection_pool import HTTPConnectionPool
from ..exception import HTTPErrorResponseException, CohortTooLargeException

COHORT_REQUEST_RETRY_DELAY_MILLIS = 100


class CohortDownloadApi:

Expand All @@ -17,13 +19,11 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort:


class DirectCohortDownloadApi(CohortDownloadApi):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, cohort_request_delay_millis: int,
server_url: str, logger: logging.Logger):
def __init__(self, api_key: str, secret_key: str, max_cohort_size: int, server_url: str, logger: logging.Logger):
super().__init__()
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_request_delay_millis = cohort_request_delay_millis
self.server_url = server_url
self.logger = logger
self.__setup_connection_pool()
Expand All @@ -48,10 +48,11 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None
group_type=cohort_info['groupType'],
)
elif response.status == 204:
self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified" )
self.logger.debug(f"getCohortMembers({cohort_id}): Cohort not modified")
return
elif response.status == 413:
raise CohortTooLargeException(f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
raise CohortTooLargeException(
f"Cohort exceeds max cohort size of {self.max_cohort_size}: {response.status}")
elif response.status != 202:
raise HTTPErrorResponseException(response.status,
f"Unexpected response code: {response.status}")
Expand All @@ -61,7 +62,7 @@ def get_cohort(self, cohort_id: str, cohort: Optional[Cohort]) -> Cohort or None
self.logger.debug(f"getCohortMembers({cohort_id}): request-status error {errors} - {e}")
if errors >= 3 or isinstance(e, CohortTooLargeException):
raise e
time.sleep(self.cohort_request_delay_millis/1000)
time.sleep(COHORT_REQUEST_RETRY_DELAY_MILLIS / 1000)

def _get_cohort_members_request(self, cohort_id: str, last_modified: int) -> HTTPResponse:
headers = {
Expand Down
7 changes: 4 additions & 3 deletions src/amplitude_experiment/cohort/cohort_sync_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ class CohortSyncConfig:
api_key (str): The project API Key
secret_key (str): The project Secret Key
max_cohort_size (int): The maximum cohort size that can be downloaded
cohort_request_delay_millis (int): The delay in milliseconds between cohort download requests
cohort_polling_interval_millis (int): The interval, in milliseconds, at which to poll for
cohort updates, minimum 60000
cohort_server_url (str): The server endpoint from which to request cohorts
"""

def __init__(self, api_key: str, secret_key: str, max_cohort_size: int = 2147483647,
cohort_request_delay_millis: int = 5000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
cohort_polling_interval_millis: int = 60000, cohort_server_url: str = DEFAULT_COHORT_SYNC_URL):
self.api_key = api_key
self.secret_key = secret_key
self.max_cohort_size = max_cohort_size
self.cohort_request_delay_millis = cohort_request_delay_millis
self.cohort_polling_interval_millis = max(cohort_polling_interval_millis, 60000)
self.cohort_server_url = cohort_server_url
4 changes: 1 addition & 3 deletions src/amplitude_experiment/deployment/deployment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from ..local.poller import Poller
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags

COHORT_POLLING_INTERVAL_MILLIS = 60000


class DeploymentRunner:
def __init__(
Expand All @@ -31,7 +29,7 @@ def __init__(
self.lock = threading.Lock()
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
if self.cohort_loader:
self.cohort_poller = Poller(COHORT_POLLING_INTERVAL_MILLIS / 1000,
self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000,
self.__update_cohorts)
self.logger = logger

Expand Down
1 change: 0 additions & 1 deletion src/amplitude_experiment/local/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self, api_key: str, config: LocalEvaluationConfig = None):
cohort_download_api = DirectCohortDownloadApi(self.config.cohort_sync_config.api_key,
self.config.cohort_sync_config.secret_key,
self.config.cohort_sync_config.max_cohort_size,
self.config.cohort_sync_config.cohort_request_delay_millis,
self.config.cohort_sync_config.cohort_server_url,
self.logger)

Expand Down
2 changes: 1 addition & 1 deletion tests/cohort/cohort_download_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def response(code: int, body: dict = None):
class CohortDownloadApiTest(unittest.TestCase):

def setUp(self):
self.api = DirectCohortDownloadApi('api', 'secret', 15000, 100, "https://example.amplitude.com", mock.create_autospec(logging.Logger))
self.api = DirectCohortDownloadApi('api', 'secret', 15000, "https://example.amplitude.com", mock.create_autospec(logging.Logger))

def test_cohort_download_success(self):
cohort = Cohort(id="1234", last_modified=0, size=1, member_ids={'user'})
Expand Down
5 changes: 3 additions & 2 deletions tests/deployment/deployment_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from src.amplitude_experiment import LocalEvaluationConfig
from src.amplitude_experiment.cohort.cohort_loader import CohortLoader
from src.amplitude_experiment.cohort.cohort_sync_config import CohortSyncConfig
from src.amplitude_experiment.flag.flag_config_api import FlagConfigApi
from src.amplitude_experiment.deployment.deployment_runner import DeploymentRunner

Expand Down Expand Up @@ -41,7 +42,7 @@ def test_start_throws_if_first_flag_config_load_fails(self):
logger = mock.create_autospec(logging.Logger)
cohort_loader = CohortLoader(cohort_download_api, cohort_storage)
runner = DeploymentRunner(
LocalEvaluationConfig(),
LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')),
flag_api,
flag_config_storage,
cohort_storage,
Expand All @@ -61,7 +62,7 @@ def test_start_does_not_throw_if_cohort_load_fails(self):
logger = mock.create_autospec(logging.Logger)
cohort_loader = CohortLoader(cohort_download_api, cohort_storage)
runner = DeploymentRunner(
LocalEvaluationConfig(),
LocalEvaluationConfig(cohort_sync_config=CohortSyncConfig('api_key', 'secret_key')),
flag_api, flag_config_storage,
cohort_storage,
logger,
Expand Down
3 changes: 1 addition & 2 deletions tests/local/client_eu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def setUpClass(cls) -> None:
api_key = os.environ['EU_API_KEY']
secret_key = os.environ['EU_SECRET_KEY']
cohort_sync_config = CohortSyncConfig(api_key=api_key,
secret_key=secret_key,
cohort_request_delay_millis=100)
secret_key=secret_key)
cls._local_evaluation_client = (
LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False, server_zone=ServerZone.EU,
cohort_sync_config=cohort_sync_config)))
Expand Down
3 changes: 1 addition & 2 deletions tests/local/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def setUpClass(cls) -> None:
api_key = os.environ['API_KEY']
secret_key = os.environ['SECRET_KEY']
cohort_sync_config = CohortSyncConfig(api_key=api_key,
secret_key=secret_key,
cohort_request_delay_millis=100)
secret_key=secret_key)
cls._local_evaluation_client = (
LocalEvaluationClient(SERVER_API_KEY, LocalEvaluationConfig(debug=False,
cohort_sync_config=cohort_sync_config)))
Expand Down

0 comments on commit 06e693e

Please sign in to comment.