Skip to content

Commit

Permalink
Added RequestCache.wait_for
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink committed Nov 20, 2024
1 parent a257bf6 commit 984e362
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 33 deletions.
8 changes: 3 additions & 5 deletions doc/basics/requestcache_tutorial_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import create_task, run, sleep
from asyncio import run

from ipv8.requestcache import NumberCacheWithName, RequestCache

Expand Down Expand Up @@ -29,12 +29,10 @@ async def bar() -> None:
"""
# Normally, you would add this to a network overlay instance.
request_cache = RequestCache()
request_cache.register_anonymous_task("Add later", foo, request_cache, delay=1.23)

_ = create_task(foo(request_cache))
cache = await request_cache.wait_for(MyState, 0)

while not request_cache.has(MyState, 0):
await sleep(0.1)
cache = request_cache.pop(MyState, 0)
print("I found a cache with the state:", cache.state)


Expand Down
43 changes: 42 additions & 1 deletion ipv8/requestcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from contextlib import contextmanager, suppress
from random import random
from threading import Lock
from typing import TYPE_CHECKING, TypeVar, overload
from typing import TYPE_CHECKING, TypeVar, overload, cast, Union

from typing_extensions import Protocol

from .taskmanager import TaskManager
from .util import succeed

if TYPE_CHECKING:
from collections.abc import Generator, Iterable
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(self) -> None:
self._logger = logging.getLogger(self.__class__.__name__)

self._identifiers: dict[str, NumberCache] = {}
self._waiters: dict[tuple[str, int], Future] = {}
self.lock = Lock()
self._shutdown = False

Expand Down Expand Up @@ -210,6 +212,9 @@ def add(self, cache: ACT) -> ACT | None:
timeout_delay = self._timeout_override

self.register_task(cache, self._on_timeout, cache, delay=timeout_delay)
waiter = self._waiters.pop((cache.prefix, cache.number), None)
if waiter is not None and not waiter.done():
waiter.set_result(cache)
return cache

@overload
Expand All @@ -228,6 +233,42 @@ def has(self, prefix: str | type[CacheTypeVar], number: int) -> bool:
return self._create_identifier(number, prefix) in self._identifiers
return self.has(prefix.name, number)

def _watch_future(self, key: tuple[str, int]) -> None:
"""
Ensure that a given future is killed after some timeout.
"""
future = self._waiters.pop(key, None)
if future is not None and not future.done():
future.cancel()

@overload
def wait_for(self, prefix: str, number: int, timeout: float | None = None) -> Future[NumberCache]:
pass

@overload
def wait_for(self, prefix: type[CacheTypeVar], number: int, timeout: float | None = None) -> Future[CacheTypeVar]:
pass

def wait_for(self, prefix: str | type[CacheTypeVar], number: int,
timeout: float | None = None) -> Future[CacheTypeVar] | Future[NumberCache]:
"""
Returns a future that fires if or when the given cache is registered.
"""
result = self.get(prefix, number)
if result is not None:
# This is just to please ``Mypy``: ``return succeed(result)`` is functionally equivalent in this block.
if isinstance(prefix, str):
return succeed(cast(NumberCache, result))
return succeed(cast(CacheTypeVar, result))
if isinstance(prefix, str):
fut: Future[NumberCache] = Future()
self._waiters[(prefix, number)] = fut
if timeout is not None:
self.register_anonymous_task(f"Watch RequestCache Future {fut}", self._watch_future, (prefix, number),
delay=timeout)
return self.register_anonymous_task(f"RequestCache wait for {prefix}", fut)
return cast(Future[CacheTypeVar], self.wait_for(prefix.name, number))

@overload
def get(self, prefix: str, number: int) -> NumberCache | None:
pass
Expand Down
88 changes: 61 additions & 27 deletions ipv8/test/test_requestcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ def on_timeout(self) -> None:
"""


class MockNamedNumberCache(NumberCache):
"""
A "normal" NumberCache that has a name.
"""

name = "test"


class TestRequestCache(TestBase):
"""
Tests related to the request cache.
Expand All @@ -115,6 +123,13 @@ def setUp(self) -> None:
super().setUp()
self.request_cache = RequestCache()

async def tearDown(self) -> None:
"""
Destroy the request cache.
"""
await self.request_cache.shutdown()
await super().tearDown()

async def test_shutdown(self) -> None:
"""
Test if RequestCache does not allow new Caches after shutdown().
Expand All @@ -134,7 +149,6 @@ async def test_timeout(self) -> None:
cache = MockCache(self.request_cache)
self.request_cache.add(cache)
await cache.timed_out
await self.request_cache.shutdown()

async def test_add_duplicate(self) -> None:
"""
Expand All @@ -145,16 +159,13 @@ async def test_add_duplicate(self) -> None:

self.assertIsNone(self.request_cache.add(cache))

await self.request_cache.shutdown()

async def test_timeout_future_default_value(self) -> None:
"""
Test if a registered future gets set to None on timeout.
"""
cache = MockRegisteredCache(self.request_cache)
self.request_cache.add(cache)
self.assertEqual(None, (await cache.timed_out))
await self.request_cache.shutdown()

async def test_timeout_future_custom_value(self) -> None:
"""
Expand All @@ -166,8 +177,6 @@ async def test_timeout_future_custom_value(self) -> None:
cache.managed_futures[0] = (cache.managed_futures[0][0], 123)
self.assertEqual(123, (await cache.timed_out))

await self.request_cache.shutdown()

async def test_timeout_future_exception(self) -> None:
"""
Test if a registered future raises an exception on timeout.
Expand All @@ -179,8 +188,6 @@ async def test_timeout_future_exception(self) -> None:
with self.assertRaises(RuntimeError):
await cache.timed_out

await self.request_cache.shutdown()

async def test_cancel_future_after_shutdown(self) -> None:
"""
Test if a registered future is cancelled when the RequestCache has shutdown.
Expand Down Expand Up @@ -211,8 +218,6 @@ async def test_passthrough_noargs(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_timeout(self) -> None:
"""
Test if passthrough respects the timeout value.
Expand All @@ -225,8 +230,6 @@ async def test_passthrough_timeout(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_one_match(self) -> None:
"""
Test if passthrough filters correctly with one filter, that matches.
Expand All @@ -239,8 +242,6 @@ async def test_passthrough_filter_one_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_one_mismatch(self) -> None:
"""
Test if passthrough filters correctly with one filter, that doesn't match.
Expand All @@ -253,8 +254,6 @@ async def test_passthrough_filter_one_mismatch(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_many_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, that all match.
Expand All @@ -267,8 +266,6 @@ async def test_passthrough_filter_many_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_some_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, for which some match.
Expand All @@ -281,8 +278,6 @@ async def test_passthrough_filter_some_match(self) -> None:

self.assertTrue(cache.timed_out)

await self.request_cache.shutdown()

async def test_passthrough_filter_no_match(self) -> None:
"""
Test if passthrough filters correctly with many filters, for which none match.
Expand All @@ -295,8 +290,6 @@ async def test_passthrough_filter_no_match(self) -> None:

self.assertFalse(cache.timed_out)

await self.request_cache.shutdown()

async def test_has_by_class(self) -> None:
"""
Check if we can call ``.has()`` by cache class.
Expand All @@ -307,7 +300,52 @@ async def test_has_by_class(self) -> None:

self.assertTrue(self.request_cache.has(MockNamedCache, added.number))

await self.request_cache.shutdown()
async def test_wait_for_by_name(self) -> None:
"""
Check if we can call ``.wait_for()`` by cache name.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337)

added = self.request_cache.add(cache)
result = await fut

self.assertEqual(added, result)

async def test_wait_for_by_class(self) -> None:
"""
Check if we can call ``.wait_for()`` by cache class.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
fut = self.request_cache.wait_for(MockNamedNumberCache, 1337)

added = self.request_cache.add(cache)
result = await fut

self.assertEqual(added, result)

async def test_wait_for_already_added(self) -> None:
"""
Check if we can ``.wait_for()`` returns when a cache is already available.
"""
cache = MockNamedNumberCache(self.request_cache, MockNamedNumberCache.name, 1337)
added = self.request_cache.add(cache)
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337)

result = await fut

self.assertEqual(added, result)

async def test_wait_for_timeout(self) -> None:
"""
Check if ``.wait_for()`` cancels its future when a timeout occurs.
"""
fut = self.request_cache.wait_for(MockNamedNumberCache.name, 1337, timeout=0.0)

await sleep(0)

self.assertTrue(fut.done())
self.assertTrue(fut.cancelled())

async def test_get_by_class(self) -> None:
"""
Expand All @@ -318,8 +356,6 @@ async def test_get_by_class(self) -> None:

self.assertEqual(added, self.request_cache.get(MockNamedCache, added.number))

await self.request_cache.shutdown()

async def test_pop_by_class(self) -> None:
"""
Check if we can call ``.pop()`` by cache class.
Expand All @@ -328,5 +364,3 @@ async def test_pop_by_class(self) -> None:
added = self.request_cache.add(cache)

self.assertEqual(added, self.request_cache.pop(MockNamedCache, added.number))

await self.request_cache.shutdown()

0 comments on commit 984e362

Please sign in to comment.