diff --git a/snuba/state/cache/abstract.py b/snuba/state/cache/abstract.py index eea947b1c8..09181609b0 100644 --- a/snuba/state/cache/abstract.py +++ b/snuba/state/cache/abstract.py @@ -36,7 +36,6 @@ def get_readthrough( key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timeout: int, timer: Optional[Timer] = None, ) -> TValue: """ diff --git a/snuba/state/cache/redis/backend.py b/snuba/state/cache/redis/backend.py index feefd8a35a..a892f99170 100644 --- a/snuba/state/cache/redis/backend.py +++ b/snuba/state/cache/redis/backend.py @@ -1,27 +1,15 @@ -import concurrent.futures import logging -import random -import uuid -from concurrent.futures import ThreadPoolExecutor from typing import Callable, Optional -from pkg_resources import resource_string - -from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError +from redis.exceptions import ConnectionError, ReadOnlyError from redis.exceptions import TimeoutError as RedisTimeoutError from snuba import environment, settings from snuba.redis import RedisClientType from snuba.state import get_config -from snuba.state.cache.abstract import ( - Cache, - ExecutionError, - ExecutionTimeoutError, - TValue, -) +from snuba.state.cache.abstract import Cache, TValue from snuba.utils.codecs import ExceptionAwareCodec from snuba.utils.metrics.timer import Timer from snuba.utils.metrics.wrapper import MetricsWrapper -from snuba.utils.serializable_exception import SerializableException logger = logging.getLogger(__name__) metrics = MetricsWrapper(environment.metrics, "read_through_cache") @@ -39,21 +27,10 @@ def __init__( client: RedisClientType, prefix: str, codec: ExceptionAwareCodec[bytes, TValue], - executor: ThreadPoolExecutor, ) -> None: self.__client = client self.__prefix = prefix self.__codec = codec - self.__executor = executor - - # TODO: This should probably be lazily instantiated, rather than - # automatically happening at startup. - self.__script_get = client.register_script( - resource_string("snuba", "state/cache/redis/scripts/get.lua") - ) - self.__script_set = client.register_script( - resource_string("snuba", "state/cache/redis/scripts/set.lua") - ) def __build_key( self, key: str, prefix: Optional[str] = None, suffix: Optional[str] = None @@ -76,221 +53,11 @@ def set(self, key: str, value: TValue) -> None: ex=get_config("cache_expiry_sec", 1), ) - def __get_readthrough( - self, - key: str, - function: Callable[[], TValue], - record_cache_hit_type: Callable[[int], None], - timeout: int, - timer: Optional[Timer] = None, - ) -> TValue: - # This method is designed with the following goals in mind: - # 1. The value generation function is only executed when no value - # already exists for the key. - # 2. Only one client can execute the value generation function at a - # time (up to a deadline, at which point the client is assumed to be - # dead and its results are no longer valid.) - # 3. The other clients waiting for the result of the value generation - # function receive a result as soon as it is available. - # 4. This remains compatible with the existing get/set API (at least - # for the time being.) - - # This method shares the same keyspace as the conventional get and set - # methods, which restricts this key to only containing the cache value - # (or lack thereof.) - result_key = self.__build_key(key) - - # if we hit an error, we want to communicate that to waiting clients - # but we do not want it to be considered a true value. 
hence we store - # the error info in a different key - error_key = self.__build_key(key, "error") - - # The wait queue (a Redis list) is used to identify clients that are - # currently "subscribed" to the evaluation of the function and awaiting - # its result. The first member of this queue is a special case -- it is - # the client responsible for executing the function and notifying the - # subscribed clients of its completion. Only one wait queue should be - # associated with a cache key at any time. - wait_queue_key = self.__build_key(key, "tasks", "wait") - - # The task identity (a Redis bytestring) is used to store a unique - # identifier for a single task evaluation and notification cycle. Only - # one task should be associated with a cache key at any time. - task_ident_key = self.__build_key(key, "tasks") - - # The notify queue (a Redis list) is used to unblock clients that are - # waiting for the task to complete. **This implementation requires that - # the number of clients waiting for responses via the notify queue is - # no greater than the number of clients in the wait queue** (minus one - # client, who is doing the work and not waiting.) If there are more - # clients waiting for notifications than the members of the wait queue - # (for any reason), some set of clients will never be notified. To be - # safe and ensure that each client only waits for notifications from - # tasks where it was also a member of the wait queue, the notify queue - # includes the unique task identity as part of it's key. - def build_notify_queue_key(task_ident: str) -> str: - return self.__build_key(key, "tasks", f"notify/{task_ident}") - - # At this point, we have all of the information we need to figure out - # if the key exists, and if it doesn't, if we should start working or - # wait for a different client to finish the work. We have to pass the - # task creation parameters -- the timeout (execution deadline) and a - # new task identity just in case we are the first in line. - result = self.__script_get( - [result_key, wait_queue_key, task_ident_key], [timeout, uuid.uuid1().hex] - ) - - if timer is not None: - timer.mark("cache_get") - metric_tags = timer.tags if timer is not None else {} - - # This updates the stats object and querylog - record_cache_hit_type(result[0]) - - if result[0] == RESULT_VALUE: - # If we got a cache hit, this is easy -- we just return it. - logger.debug("Immediately returning result from cache hit.") - return self.__codec.decode(result[1]) - elif result[0] == RESULT_EXECUTE: - # If we were the first in line, we need to execute the function. - # We'll also get back the task identity to use for sending - # notifications and approximately how long we have to run the - # function. (In practice, these should be the same values as what - # we provided earlier.) - task_ident = result[1].decode("utf-8") - task_timeout = int(result[2]) - logger.debug( - "Executing task (%r) with %s second timeout...", - task_ident, - task_timeout, - ) - redis_key_to_write_to = result_key - - argv = [task_ident, 60] - try: - # The task is run in a thread pool so that we can return - # control to the caller once the timeout is reached. 
- value = self.__executor.submit(function).result(task_timeout) - argv.extend( - [self.__codec.encode(value), get_config("cache_expiry_sec", 1)] - ) - except concurrent.futures.TimeoutError as error: - metrics.increment("execute_timeout", tags=metric_tags) - raise TimeoutError("timed out while running query") from error - except Exception as e: - metrics.increment("execute_error", tags=metric_tags) - error_value = SerializableException.from_standard_exception_instance(e) - argv.extend( - [ - self.__codec.encode_exception(error_value), - # the error data only needs to be present for long enough such that - # the waiting clients know that they all have their queries rejected. - # thus we set it to only three seconds - get_config("error_cache_expiry_sec", 3), - ] - ) - # we want the result key to only store real query results in it as the TTL - # of a cached query can be fairly long (minutes). - redis_key_to_write_to = error_key - raise e - finally: - # Regardless of whether the function succeeded or failed, we - # need to mark the task as completed. If there is no result - # value, other clients will know that we raised an exception. - logger.debug("Setting result and waking blocked clients...") - try: - self.__script_set( - [ - redis_key_to_write_to, - wait_queue_key, - task_ident_key, - build_notify_queue_key(task_ident), - ], - argv, - ) - except ResponseError: - # An error response here indicates that we overran our - # deadline, or there was some other issue when trying to - # put the result value in the cache. This doesn't affect - # _our_ evaluation of the task, so log it and move on. - metrics.increment("cache_set_fail", tags=metric_tags) - logger.warning("Error setting cache result!", exc_info=True) - else: - if timer is not None: - timer.mark("cache_set") - return value - elif result[0] == RESULT_WAIT: - # If we were not the first in line, we need to wait for the first - # client to finish and populate the cache with the result value. - # We use the provided task identity to figure out where to listen - # for notifications, and the task timeout remaining informs us the - # maximum amount of time that we should expect to wait. - task_ident = result[1].decode("utf-8") - task_timeout_remaining = int(result[2]) - effective_timeout = min(task_timeout_remaining, timeout) - metrics.increment("task_waiting", tags=metric_tags) - logger.debug( - "Waiting for task result (%r) for up to %s seconds...", - task_ident, - effective_timeout, - ) - notification_received = ( - self.__client.blpop( - build_notify_queue_key(task_ident), effective_timeout - ) - is not None - ) - - if timer is not None: - timer.mark("dedupe_wait") - - if notification_received: - # There should be a value waiting for us at the result key. - raw_value, upsteam_error_payload = self.__client.mget( - [result_key, error_key] - ) - # If there is no value, that means that the client responsible - # for generating the cache value errored while generating it. 
- if raw_value is None: - if upsteam_error_payload: - metrics.increment("readthrough_error", tags=metric_tags) - return self.__codec.decode(upsteam_error_payload) - else: - metrics.increment("no_value_at_key", tags=metric_tags) - raise ExecutionError( - "no value at key: this means the original process executing the query crashed before the exception could be handled or an error was thrown setting the cache result" - ) - else: - return self.__codec.decode(raw_value) - else: - # We timed out waiting for the notification -- something went - # wrong with the client that was generating the cache value. - if effective_timeout == task_timeout_remaining: - # If the effective timeout was the remaining task timeout, - # this means that the client responsible for generating the - # cache value didn't do so before it promised to. - metrics.increment("notification_wait_timeout", tags=metric_tags) - raise ExecutionTimeoutError( - "result not available before execution deadline" - ) - else: - # If the effective timeout was the timeout provided to this - # method, that means that our timeout was stricter - # (smaller) than the execution timeout. The other client - # may still be working, but we're not waiting. - metrics.increment( - "notification_timeout_too_strict", tags=metric_tags - ) - raise TimeoutError("timed out waiting for result") - else: - raise ValueError("unexpected result from script") - def __get_value_with_simple_readthrough( self, key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timeout: int, timer: Optional[Timer] = None, ) -> TValue: record_cache_hit_type(SIMPLE_READTHROUGH) @@ -325,7 +92,6 @@ def get_readthrough( key: str, function: Callable[[], TValue], record_cache_hit_type: Callable[[int], None], - timeout: int, timer: Optional[Timer] = None, ) -> TValue: # in case something is wrong with redis, we want to be able to @@ -335,20 +101,9 @@ def get_readthrough( try: # set disable_lua_scripts to use the simple read-through cache without queueing. - sample_rate = get_config( - "read_through_cache.disable_lua_scripts_sample_rate", 0 - ) - disable_lua_scripts = sample_rate is not None and random.random() < float( - sample_rate + return self.__get_value_with_simple_readthrough( + key, function, record_cache_hit_type, timer ) - if disable_lua_scripts: - return self.__get_value_with_simple_readthrough( - key, function, record_cache_hit_type, timeout, timer - ) - else: - return self.__get_readthrough( - key, function, record_cache_hit_type, timeout, timer - ) except (ConnectionError, ReadOnlyError, RedisTimeoutError, ValueError): if settings.RAISE_ON_READTHROUGH_CACHE_REDIS_FAILURES: raise diff --git a/snuba/state/cache/redis/scripts/get.lua b/snuba/state/cache/redis/scripts/get.lua deleted file mode 100644 index 3a2369f78d..0000000000 --- a/snuba/state/cache/redis/scripts/get.lua +++ /dev/null @@ -1,36 +0,0 @@ --- KEYS[1]: The value key. --- KEYS[2]: The waiting queue key. --- KEYS[3]: The task unique ID key. --- ARGV[1]: The task execution timeout. Only used when creating a new task. --- ARGV[2]: The task unique ID. Only used when creating a new task. -local value_key = KEYS[1] -local wait_queue_key = KEYS[2] -local task_id_key = KEYS[3] -local task_timeout = ARGV[1] -local task_id = ARGV[2] - -local CODE_RESULT_VALUE = 0 -local CODE_RESULT_EXECUTE = 1 -local CODE_RESULT_WAIT = 2 - --- Check to see if a value already exists at the result key. If one does, we --- don't have to do anything other than return it and exit. 
-local value = redis.call('GET', value_key) -if value then - return {CODE_RESULT_VALUE, value} -end - --- Check to see if a waiting queue has already been established. If we are the --- only member of the queue, we can proceed with the task. Otherwise, we need to --- wait to be notified of task completion, or for the timeout to be reached, --- whichever comes first. -local waiting = redis.call('RPUSH', wait_queue_key, '') -if waiting == 1 then - redis.call('EXPIRE', wait_queue_key, task_timeout) - -- We shouldn't be overwriting an existing task here, but it's safe if we - -- do, given that the queue was empty. - redis.call('SETEX', task_id_key, task_timeout, task_id) - return {CODE_RESULT_EXECUTE, task_id, task_timeout} -else - return {CODE_RESULT_WAIT, redis.call('GET', task_id_key), redis.call('TTL', task_id_key)} -end diff --git a/snuba/state/cache/redis/scripts/set.lua b/snuba/state/cache/redis/scripts/set.lua deleted file mode 100644 index 2646aac083..0000000000 --- a/snuba/state/cache/redis/scripts/set.lua +++ /dev/null @@ -1,41 +0,0 @@ --- KEYS[1]: The value key. --- KEYS[2]: The waiting queue key. --- KEYS[3]: The task unique ID key. --- KEYS[4]: The notify queue key. -local value_key = KEYS[1] -local wait_queue_key = KEYS[2] -local task_id_key = KEYS[3] -local notify_queue_key = KEYS[4] - --- ARGV[1]: The task unique ID. -local task_id = ARGV[1] -local notify_queue_ttl = ARGV[2] -local value = ARGV[3] -local value_ttl = ARGV[4] --- ARGV[2]: The notify queue TTL. --- ARGV[3]: The value. (optional) --- ARGV[4]: The value TTL. (optional) - --- Check to make sure that the current task is still the task that we are --- responsible for executing. If it doesn't exist or does not match the current --- task ID, we must have overrun the timeout. --- TODO: This may still be able to safely set the cache value? -local cached_task_id = redis.call('GET', task_id_key) -if not cached_task_id or cached_task_id ~= task_id then - return {err="invalid task id"} -end - --- Update the cache value. -if value ~= nil then - redis.call('SETEX', value_key, value_ttl, value) -end - --- Move the data from the waiting queue to the notify queue. -redis.call('RENAME', wait_queue_key, notify_queue_key) -redis.call('EXPIRE', notify_queue_key, notify_queue_ttl) - --- Remove one item (representing our own entry) from the notify queue. -redis.call('LPOP', notify_queue_key) - --- Delete the task unique key. -redis.call('DEL', task_id_key) diff --git a/snuba/web/db_query.py b/snuba/web/db_query.py index f8e77fa4a4..1a983bbe29 100644 --- a/snuba/web/db_query.py +++ b/snuba/web/db_query.py @@ -3,7 +3,6 @@ import logging import random import uuid -from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial from hashlib import md5 @@ -109,7 +108,6 @@ def encode_exception(self, value: SerializableException) -> bytes: redis_cache_client, "snuba-query-cache:", ResultCacheCodec(), - ThreadPoolExecutor(), ) } # This lock prevents us from initializing the cache twice. 
The cache is initialized @@ -301,7 +299,6 @@ def _get_cache_partition(reader: Reader) -> Cache[Result]: redis_cache_client, f"snuba-query-cache:{partition_id}:", ResultCacheCodec(), - ThreadPoolExecutor(), ) return cache_partitions[ @@ -401,10 +398,7 @@ def execute_query_with_readthrough_caching( robust, ) - if state.get_config("disable_lua_randomize_query_id", 0): - clickhouse_query_settings["query_id"] = f"randomized-{uuid.uuid4().hex}" - else: - clickhouse_query_settings["query_id"] = query_id + clickhouse_query_settings["query_id"] = f"randomized-{uuid.uuid4().hex}" if span: span.set_data("query_id", query_id) @@ -442,33 +436,10 @@ def record_cache_hit_type(hit_type: int) -> None: robust, ), record_cache_hit_type=record_cache_hit_type, - timeout=_get_cache_wait_timeout(clickhouse_query_settings, reader), timer=timer, ) -def _get_cache_wait_timeout( - query_settings: MutableMapping[str, Any], reader: Reader -) -> int: - """ - Helper function to determine how long a query should wait when doing - a readthrough caching. - - The overrides are primarily used for debugging the ExecutionTimeoutError - raised by the readthrough caching system on the tigers cluster. When we - have root caused the problem we can remove the overrides. - """ - cache_wait_timeout: int = int(query_settings.get("max_execution_time", 30)) - if reader.cache_partition_id and reader.cache_partition_id in { - "tiger_errors", - "tiger_transactions", - }: - tiger_wait_timeout_config = state.get_config("tiger-cache-wait-time") - if tiger_wait_timeout_config: - cache_wait_timeout = tiger_wait_timeout_config - return cache_wait_timeout - - def _get_query_settings_from_config( override_prefix: Optional[str], async_override: bool, diff --git a/tests/state/test_cache.py b/tests/state/test_cache.py index e9179b604d..4ea2d860c8 100644 --- a/tests/state/test_cache.py +++ b/tests/state/test_cache.py @@ -2,10 +2,9 @@ import random import time -from concurrent.futures import Future, ThreadPoolExecutor -from functools import partial +from concurrent.futures import Future from threading import Thread -from typing import Any, Callable, cast +from typing import Any, Callable from unittest import mock import pytest @@ -13,9 +12,9 @@ from sentry_redis_tools.failover_redis import FailoverRedis from redis.exceptions import ReadOnlyError -from snuba.redis import RedisClientKey, RedisClientType, get_redis_client +from snuba.redis import RedisClientKey, get_redis_client from snuba.state import set_config -from snuba.state.cache.abstract import Cache, ExecutionError, ExecutionTimeoutError +from snuba.state.cache.abstract import Cache from snuba.state.cache.redis.backend import RedisCache from snuba.utils.codecs import ExceptionAwareCodec from snuba.utils.serializable_exception import ( @@ -80,9 +79,7 @@ def encode_exception(self, value: SerializableException) -> bytes: @pytest.fixture def backend() -> Cache[bytes]: codec = PassthroughCodec() - backend: Cache[bytes] = RedisCache( - redis_client, "test", codec, ThreadPoolExecutor() - ) + backend: Cache[bytes] = RedisCache(redis_client, "test", codec) return backend @@ -94,15 +91,13 @@ class BadClient(FailoverRedis): def __init__(self, client: Any) -> None: self._client = client - def evalsha(self, *args: str, **kwargs: str) -> None: + def get(self, *args: str, **kwargs: str) -> None: raise ReadOnlyError("Failed") def __getattr__(self, attr: str) -> Any: return getattr(self._client, attr) - backend: Cache[bytes] = RedisCache( - BadClient(redis_client), "test", codec, ThreadPoolExecutor() - ) + 
backend: Cache[bytes] = RedisCache(BadClient(redis_client), "test", codec) return backend @@ -120,12 +115,12 @@ def test_short_circuit(backend: Cache[bytes]) -> None: assert backend.get(key) is None with assert_changes(lambda: function.call_count, 0, 1): - backend.get_readthrough(key, function, noop, 5) == value + backend.get_readthrough(key, function, noop) == value assert backend.get(key) is None with assert_changes(lambda: function.call_count, 1, 2): - backend.get_readthrough(key, function, noop, 5) == value + backend.get_readthrough(key, function, noop) == value @pytest.mark.redis_db @@ -134,42 +129,7 @@ def test_fail_open(bad_backend: Cache[bytes]) -> None: value = b"value" function = mock.MagicMock(return_value=value) with mock.patch("snuba.settings.RAISE_ON_READTHROUGH_CACHE_REDIS_FAILURES", False): - assert bad_backend.get_readthrough(key, function, noop, 5) == value - - -@pytest.mark.redis_db -def test_get_readthrough_with_disable_lua_scripts(backend: Cache[bytes]) -> None: - set_config("read_through_cache.disable_lua_scripts_sample_rate", 1) - key = "key" - value = b"value" - function = mock.MagicMock(return_value=value) - - assert backend.get(key) is None - - with assert_changes(lambda: function.call_count, 0, 1): - assert backend.get_readthrough(key, function, noop, 5) == value - - assert backend.get(key) == value - - with assert_does_not_change(lambda: function.call_count, 1): - assert backend.get_readthrough(key, function, noop, 5) == value - - -@pytest.mark.redis_db -def test_get_readthrough_exception_with_disable_lua_scripts( - backend: Cache[bytes], -) -> None: - set_config("read_through_cache.disable_lua_scripts_sample_rate", 1) - key = "key" - - class CustomException(SerializableException): - pass - - def function() -> bytes: - raise CustomException("error") - - with pytest.raises(CustomException): - backend.get_readthrough(key, SingleCallFunction(function), noop, 1) + assert bad_backend.get_readthrough(key, function, noop) == value @pytest.mark.redis_db @@ -181,27 +141,12 @@ def test_get_readthrough(backend: Cache[bytes]) -> None: assert backend.get(key) is None with assert_changes(lambda: function.call_count, 0, 1): - backend.get_readthrough(key, function, noop, 5) == value + backend.get_readthrough(key, function, noop) == value assert backend.get(key) == value with assert_does_not_change(lambda: function.call_count, 1): - backend.get_readthrough(key, function, noop, 5) == value - - -@pytest.mark.redis_db -def test_get_readthrough_missed_deadline(backend: Cache[bytes]) -> None: - key = "key" - value = b"value" - - def function() -> bytes: - time.sleep(1.5) - return value - - with pytest.raises(TimeoutError): - backend.get_readthrough(key, function, noop, 1) - - assert backend.get(key) is None + backend.get_readthrough(key, function, noop) == value @pytest.mark.redis_db @@ -215,7 +160,7 @@ def function() -> bytes: raise CustomException("error") with pytest.raises(CustomException): - backend.get_readthrough(key, SingleCallFunction(function), noop, 1) + backend.get_readthrough(key, SingleCallFunction(function), noop) @pytest.mark.redis_db @@ -227,12 +172,12 @@ def function() -> bytes: return f"{random.random()}".encode("utf-8") def worker() -> bytes: - return backend.get_readthrough(key, function, noop, 10) + return backend.get_readthrough(key, function, noop) - setter = execute(worker) - waiter = execute(worker) + setter = worker() + waiter = worker() - assert setter.result() == waiter.result() + assert setter == waiter @pytest.mark.redis_db @@ -246,7 +191,7 @@ def 
function() -> bytes: raise ReadThroughCustomException("error") def worker() -> bytes: - return backend.get_readthrough(key, SingleCallFunction(function), noop, 10) + return backend.get_readthrough(key, SingleCallFunction(function), noop) setter = execute(worker) time.sleep(0.5) @@ -267,39 +212,11 @@ def worker() -> bytes: "backend", [ pytest.param( - RedisCache(redis_client, "test", PassthroughCodec(), ThreadPoolExecutor()), + RedisCache(redis_client, "test", PassthroughCodec()), id="regular cluster", ), ], ) -@pytest.mark.redis_db -def test_get_readthrough_set_wait_timeout(backend: Cache[bytes]) -> None: - key = "key" - value = b"value" - - def function(id: int) -> bytes: - time.sleep(2.5) - return value + f"{id}".encode() - - def worker(timeout: int) -> bytes: - return backend.get_readthrough(key, partial(function, timeout), noop, timeout) - - setter = execute(partial(worker, 2)) - time.sleep(0.1) - waiter_fast = execute(partial(worker, 1)) - time.sleep(0.1) - waiter_slow = execute(partial(worker, 3)) - - with pytest.raises(TimeoutError): - assert setter.result() - - with pytest.raises(TimeoutError): - waiter_fast.result() - - with pytest.raises((ExecutionError, ExecutionTimeoutError)): - waiter_slow.result() - - @pytest.mark.redis_db def test_transient_error(backend: Cache[bytes]) -> None: key = "key" @@ -314,14 +231,10 @@ def normal_function() -> bytes: return b"hello" def transient_error() -> bytes: - return backend.get_readthrough( - key, SingleCallFunction(error_function), noop, 10 - ) + return backend.get_readthrough(key, SingleCallFunction(error_function), noop) def functioning_query() -> bytes: - return backend.get_readthrough( - key, SingleCallFunction(normal_function), noop, 10 - ) + return backend.get_readthrough(key, SingleCallFunction(normal_function), noop) setter = execute(transient_error) # if this sleep were removed, the waiter would also raise @@ -334,67 +247,3 @@ def functioning_query() -> bytes: setter.result() assert waiter.result() == b"hello" - - -@pytest.mark.redis_db -def test_notify_queue_ttl() -> None: - # Tests that waiting clients can be notified of the cache status - # even with network delays. This test will break if the notify queue - # TTL is set below 200ms - - pop_calls = 0 - num_waiters = 9 - - class DelayedRedisClient: - def __init__(self, redis_client: RedisClientType) -> None: - self._client = redis_client - - def __getattr__(self, attr: str) -> Any: - # simulate the queue pop taking longer than expected. 
- # the notification queue TTL is 60 seconds so running into a timeout - # shouldn't happen (unless something has drastically changed in the TTL - # time or use) - if attr == "blpop": - nonlocal pop_calls - pop_calls += 1 - time.sleep(0.5) - return getattr(self._client, attr) - - codec = PassthroughCodec() - - delayed_backend: Cache[bytes] = RedisCache( - cast(RedisClientType, DelayedRedisClient(redis_client)), - "test", - codec, - ThreadPoolExecutor(), - ) - key = "key" - - def normal_function() -> bytes: - # this sleep makes sure that all waiting clients - # are put into the waiting queue - time.sleep(0.5) - return b"hello-cached" - - def normal_function_uncached() -> bytes: - return b"hello-not-cached" - - def cached_query() -> bytes: - return delayed_backend.get_readthrough(key, normal_function, noop, 10) - - def uncached_query() -> bytes: - return delayed_backend.get_readthrough(key, normal_function_uncached, noop, 10) - - setter = execute(cached_query) - waiters = [] - time.sleep(0.1) - for _ in range(num_waiters): - waiters.append(execute(uncached_query)) - - # make sure that all clients actually did hit the cache - assert setter.result() == b"hello-cached" - for w in waiters: - assert w.result() == b"hello-cached" - # make sure that all the waiters actually did hit the notification queue - # and didn't just get a direct cache hit - assert pop_calls == num_waiters diff --git a/tests/web/test_cache_partitions.py b/tests/web/test_cache_partitions.py index f11fd64a16..82884b3780 100644 --- a/tests/web/test_cache_partitions.py +++ b/tests/web/test_cache_partitions.py @@ -1,8 +1,7 @@ import pytest -from snuba import state from snuba.clickhouse.native import ClickhousePool, NativeDriverReader -from snuba.web.db_query import _get_cache_partition, _get_cache_wait_timeout +from snuba.web.db_query import _get_cache_partition @pytest.mark.redis_db @@ -23,23 +22,3 @@ def test_cache_partition() -> None: assert id(nondefault_cache) == id(another_nondefault_cache) assert id(default_cache) != id(nondefault_cache) - - -@pytest.mark.redis_db -def test_cache_wait_timeout() -> None: - pool = ClickhousePool("127.0.0.1", 9000, "", "", "") - default_reader = NativeDriverReader(None, pool, None) - tiger_errors_reader = NativeDriverReader("tiger_errors", pool, None) - tiger_transactions_reader = NativeDriverReader("tiger_transactions", pool, None) - - query_settings = {"max_execution_time": 30} - assert _get_cache_wait_timeout(query_settings, default_reader) == 30 - assert _get_cache_wait_timeout(query_settings, tiger_errors_reader) == 30 - assert _get_cache_wait_timeout(query_settings, tiger_transactions_reader) == 30 - - state.set_config("tiger-cache-wait-time", 60) - assert _get_cache_wait_timeout(query_settings, default_reader) == 30 - assert _get_cache_wait_timeout(query_settings, tiger_errors_reader) == 60 - assert _get_cache_wait_timeout(query_settings, tiger_transactions_reader) == 60 - - state.delete_config("tiger-cache-wait-time") diff --git a/tests/web/test_db_query.py b/tests/web/test_db_query.py index 8bb7859b11..997d0bbabc 100644 --- a/tests/web/test_db_query.py +++ b/tests/web/test_db_query.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Dict, Mapping, MutableMapping, Optional +from typing import Any, Mapping, MutableMapping, Optional from unittest import mock import pytest @@ -33,12 +33,7 @@ from snuba.utils.metrics.backends.testing import get_recorded_metric_calls from snuba.utils.metrics.timer import Timer from snuba.web import QueryException -from 
snuba.web.db_query import ( - _get_cache_partition, - _get_query_settings_from_config, - db_query, - execute_query_with_readthrough_caching, -) +from snuba.web.db_query import _get_query_settings_from_config, db_query test_data = [ pytest.param( @@ -885,52 +880,6 @@ def test_db_query_ignore_consistent() -> None: assert result.extra["stats"]["max_threads"] == 5 -@pytest.mark.redis_db -@pytest.mark.clickhouse_db -@pytest.mark.parametrize( - "disable_lua_randomize_query_id, disable_lua_scripts_sample_rate, expected_startswith, test_cache_hit_simple", - [ - (0, 0, "test_query_id", False), - (1, 1, "randomized-", True), - ], -) -def test_clickhouse_settings_applied_to_query_id( - disable_lua_randomize_query_id: int, - disable_lua_scripts_sample_rate: int, - expected_startswith: str, - test_cache_hit_simple: bool, -) -> None: - query, storage, attribution_info = _build_test_query("count(distinct(project_id))") - state.set_config("disable_lua_randomize_query_id", disable_lua_randomize_query_id) - state.set_config( - "read_through_cache.disable_lua_scripts_sample_rate", - disable_lua_scripts_sample_rate, - ) - - formatted_query = format_query(query) - reader = storage.get_cluster().get_reader() - clickhouse_query_settings: Dict[str, Any] = {} - query_id = "test_query_id" - stats: dict[str, Any] = {} - - execute_query_with_readthrough_caching( - clickhouse_query=query, - query_settings=HTTPQuerySettings(), - formatted_query=formatted_query, - reader=reader, - timer=Timer("foo"), - stats=stats, - clickhouse_query_settings=clickhouse_query_settings, - robust=False, - query_id=query_id, - referrer="test", - ) - - assert ("cache_hit_simple" in stats) == test_cache_hit_simple - assert clickhouse_query_settings["query_id"].startswith(expected_startswith) - assert _get_cache_partition(reader).get("test_query_id") is not None - - @pytest.mark.clickhouse_db @pytest.mark.redis_db def test_cache_metrics_with_simple_readthrough() -> None: diff --git a/tests/web/test_query_cache.py b/tests/web/test_query_cache.py deleted file mode 100644 index 552862a5a2..0000000000 --- a/tests/web/test_query_cache.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Any, Callable - -import pytest -from clickhouse_driver.errors import ErrorCodes - -from snuba.attribution import get_app_id -from snuba.attribution.attribution_info import AttributionInfo -from snuba.clickhouse.errors import ClickhouseError -from snuba.datasets.entities.entity_key import EntityKey -from snuba.datasets.entities.factory import get_entity -from snuba.datasets.factory import get_dataset -from snuba.query import SelectedExpression -from snuba.query.conditions import in_condition -from snuba.query.data_source.simple import Entity -from snuba.query.expressions import Column, Literal -from snuba.query.logical import Query -from snuba.query.query_settings import HTTPQuerySettings -from snuba.request import Request -from snuba.utils.metrics.timer import Timer -from snuba.web import QueryException -from snuba.web.query import run_query as _run_query - - -def run_query() -> None: - events_storage = get_entity(EntityKey.EVENTS).get_writable_storage() - assert events_storage is not None - - query = Query( - Entity(EntityKey.EVENTS, get_entity(EntityKey.EVENTS).get_data_model()), - selected_columns=[ - SelectedExpression("event_id", Column("_snuba_event_id", None, "event_id")), - ], - condition=in_condition(Column(None, None, "project_id"), [Literal(None, 123)]), - ) - - query_settings = HTTPQuerySettings(referrer="asd") - - dataset = get_dataset("events") - 
timer = Timer("test") - - result = _run_query( - dataset, - Request( - id="asd", - original_body={}, - query=query, - query_settings=query_settings, - attribution_info=AttributionInfo( - get_app_id("blah"), - {"referrer": "r", "organization_id": 1234}, - "blah", - None, - None, - None, - ), - ), - timer, - ) - - assert result.result["data"] == [] - - -@pytest.mark.clickhouse_db -@pytest.mark.redis_db -def test_cache_retries_on_bad_query_id( - monkeypatch: pytest.MonkeyPatch, snuba_set_config: Callable[[str, Any], None] -) -> None: - from snuba.web import db_query - - calls = [] - - old_excecute_query_with_rate_limits = db_query.execute_query_with_rate_limits - - def execute_query_with_rate_limits(*args: Any) -> Any: - calls.append(args[-2]["query_id"]) - - if len(calls) == 1: - raise ClickhouseError( - "duplicate query!", - True, - code=ErrorCodes.QUERY_WITH_SAME_ID_IS_ALREADY_RUNNING, - ) - - return old_excecute_query_with_rate_limits(*args) - - monkeypatch.setattr( - db_query, "execute_query_with_rate_limits", execute_query_with_rate_limits - ) - - with pytest.raises(QueryException): - run_query() - - assert len(calls) == 1 - calls.clear() - - snuba_set_config("retry_duplicate_query_id", True) - - run_query() - - assert len(calls) == 2 - assert "randomized" not in calls[0] - assert "randomized" in calls[1]
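
Net effect, for readers skimming the patch: `get_readthrough` loses its `timeout` parameter, the Lua get/set scripts, `ThreadPoolExecutor`, and the wait/notify queue coordination are deleted, and every caller now takes the simple read-through path. The sketch below is a minimal, self-contained illustration of that remaining call shape. It is an assumption-labeled stand-in, not Snuba's actual `RedisCache.__get_value_with_simple_readthrough` (which uses Redis, an `ExceptionAwareCodec`, and the `cache_expiry_sec` config): the class name, hit-type codes, and in-memory expiry handling here are invented for illustration only.

```python
import time
from typing import Callable, Dict, Generic, Optional, Tuple, TypeVar

TValue = TypeVar("TValue")


class SimpleReadthroughCache(Generic[TValue]):
    """In-memory stand-in for a Redis-backed read-through cache (illustrative only)."""

    def __init__(self, expiry_sec: float = 60.0) -> None:
        # key -> (value, absolute expiry time); the real backend stores values in Redis.
        self._store: Dict[str, Tuple[TValue, float]] = {}
        self._expiry_sec = expiry_sec

    def get(self, key: str) -> Optional[TValue]:
        entry = self._store.get(key)
        if entry is None:
            return None
        value, expires_at = entry
        if time.monotonic() > expires_at:
            del self._store[key]
            return None
        return value

    def set(self, key: str, value: TValue) -> None:
        self._store[key] = (value, time.monotonic() + self._expiry_sec)

    def get_readthrough(
        self,
        key: str,
        function: Callable[[], TValue],
        record_cache_hit_type: Callable[[int], None],
    ) -> TValue:
        # No `timeout` argument and no cross-client wait/notify queues: a miss
        # simply runs `function` inline in the calling thread and caches the result.
        cached = self.get(key)
        if cached is not None:
            record_cache_hit_type(0)  # placeholder "hit" code, not Snuba's constant
            return cached
        record_cache_hit_type(1)      # placeholder "miss" code
        value = function()
        self.set(key, value)
        return value


if __name__ == "__main__":
    cache: SimpleReadthroughCache[bytes] = SimpleReadthroughCache()
    hits: list[int] = []
    assert cache.get_readthrough("key", lambda: b"value", hits.append) == b"value"
    assert cache.get("key") == b"value"  # populated by the first read-through
    # A second call is served from the cache and never runs the new function.
    assert cache.get_readthrough("key", lambda: b"other", hits.append) == b"value"
```

The trade-off this patch accepts, per the goals spelled out in the deleted `__get_readthrough` comments, is that concurrent clients missing the same key may each execute the query instead of deduplicating behind a single worker, in exchange for dropping the Lua scripts, the executor, and the `ExecutionTimeoutError`/timeout plumbing.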