Skip to content

Commit

Permalink
[resotocore][fix] Close the arangoconnection on stop (#1884)
Browse files Browse the repository at this point in the history
  • Loading branch information
aquamatthias authored Jan 18, 2024
1 parent de596ad commit 01dc405
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 28 deletions.
58 changes: 31 additions & 27 deletions resotocore/resotocore/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from resotocore.db.db_access import DbAccess
from resotocore.db.system_data_db import JwtSigningKeyHolder
from resotocore.graph_manager.graph_manager import GraphManager
from resotocore.infra_apps.local_runtime import LocalResotocoreAppRuntime
from resotocore.infra_apps.package_manager import PackageManager
from resotocore.infra_apps.runtime import Runtime
from resotocore.message_bus import MessageBus
Expand Down Expand Up @@ -97,6 +96,7 @@ class Dependencies(Service):
def __init__(self, **deps: Any) -> None:
super().__init__()
self.lookup: Dict[str, Any] = deps
self.on_stop_callbacks: List[Callable[[], None]] = []

def add(self, name: str, service: T) -> "T":
self.lookup[name] = service
Expand All @@ -112,6 +112,9 @@ def get(self, name: str) -> Optional[Any]:
def all(self) -> Dict[str, Any]:
return self.lookup

def register_on_stop_callback(self, callback: Callable[[], None]) -> None:
self.on_stop_callbacks.append(callback)

@property
def services(self) -> List[Service]:
return [v for _, v in self.all().items() if isinstance(v, Service)]
Expand Down Expand Up @@ -190,6 +193,8 @@ async def stop(self) -> None:
await session.close()
for service in reversed(self.services):
await service.stop()
for callback in self.on_stop_callbacks:
callback()


class TenantDependencies(Dependencies):
Expand Down Expand Up @@ -311,30 +316,32 @@ async def _lock_for(self, key: str) -> asyncio.Lock:

async def _expire(self) -> None:
now = self._time()
to_delete = []
for key, (timestamp, value) in list(self._cache.items()):
lock = await self._lock_for(key)
async with lock:
if now - timestamp > self._ttl:
to_delete.append((key, value))
if now - timestamp > self._ttl:
lock = await self._lock_for(key)
async with lock:
log.info(f"Stop tenant dependencies for {key}")
self._cache.pop(key, None)
# content lock is not removed on purpose.
for key, value in to_delete:
log.info(f"Stop tenant dependencies for {key}")
await value.stop()
await value.stop()
log.info(f"Tenant dependencies for {key} stopped.")

async def get(self, key: str, if_empty: Callable[[], Awaitable[TenantDependencies]]) -> TenantDependencies:
now = self._time()
lck = await self._lock_for(key)
async with lck:
if result := self._cache.get(key):
_, value = result
else:
log.info(f"Create and start new tenant dependencies for {key}")
value = await if_empty()
await value.start()
self._cache[key] = (now, value)
return value
try:
async with lck:
if result := self._cache.get(key):
_, value = result
else:
log.info(f"Create and start new tenant dependencies for {key}")
value = await if_empty()
await value.start()
log.info(f"Tenant dependencies for {key} created.")
self._cache[key] = (now, value)
return value
except Exception as e:
log.exception(f"Failed to create tenant dependencies for {key}: {e}", exc_info=True)
raise


@define
Expand Down Expand Up @@ -386,13 +393,14 @@ async def create_tenant_dependencies(self, tenant_hash: str, access: GraphDbAcce
args = dp.config.args
message_bus = dp.message_bus
event_sender = dp.event_sender
deps = dp.tenant_dependencies(tenant_hash=tenant_hash, access=access)

def standard_database() -> StandardDatabase:
http_client = ArangoHTTPClient(args.graphdb_request_timeout, verify=dp.config.run.verify)
client = ArangoClient(hosts=access.server, http_client=http_client)
deps.register_on_stop_callback(client.close)
return client.db(name=access.database, username=access.username, password=access.password)

deps = self._dependencies.tenant_dependencies(tenant_hash=tenant_hash, access=access)
# direct db access
sdb = deps.add(ServiceNames.system_database, await run_async(standard_database))
db = deps.add(ServiceNames.db_access, DbAccess(sdb, dp.event_sender, NoAdjust(), config))
Expand Down Expand Up @@ -425,13 +433,9 @@ def standard_database() -> StandardDatabase:
ServiceNames.core_config_handler,
CoreConfigHandler(config, message_bus, worker_task_queue, config_handler, event_sender, inspector),
)
deps.add(ServiceNames.infra_apps_runtime, LocalResotocoreAppRuntime(cli))
deps.add(
ServiceNames.infra_apps_package_manager,
PackageManager(
db.package_entity_db, config_handler, cli.register_infra_app_alias, cli.unregister_infra_app_alias
),
)
# Enable package manager and runtime for infra apps when required
# deps.add(ServiceNames.infra_apps_runtime, LocalResotocoreAppRuntime(cli))
# deps.add(ServiceNames.infra_apps_package_manager, PackageManager(db.package_entity_db, config_handler, cli.register_infra_app_alias, cli.unregister_infra_app_alias)) # noqa
graph_merger = deps.add(ServiceNames.graph_merger, GraphMerger(model, event_sender, config, message_bus))
task_handler = deps.add(
ServiceNames.task_handler,
Expand Down
77 changes: 76 additions & 1 deletion resotocore/tests/resotocore/dependencies_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import Tuple, List, Any

import pytest
from aiohttp.test_utils import make_mocked_request

from resotocore.dependencies import Dependencies, TenantDependencyCache, TenantDependencies
from resotocore.async_extensions import run_async
from resotocore.dependencies import (
Dependencies,
TenantDependencyCache,
TenantDependencies,
FromRequestTenantDependencyProvider,
ServiceNames,
)
from resotocore.service import Service
from resotocore.system_start import parse_args
from resotocore.types import JsonElement
Expand Down Expand Up @@ -34,8 +45,15 @@ async def stop(self) -> None:

@pytest.mark.asyncio
async def test_nested_dependencies() -> None:
stopped = False

def on_stop_callback() -> None:
nonlocal stopped
stopped = True

deps = Dependencies(a=ExampleService("a"), b=ExampleService("b"))
async with deps.tenant_dependencies(a=ExampleService("na"), c=ExampleService("c")) as td:
td.register_on_stop_callback(on_stop_callback)
assert len(deps.services) == 2 # manages a and b
assert len(td.services) == 2 # manages a and c
# Deps has a and b, but not c
Expand All @@ -59,19 +77,53 @@ async def test_nested_dependencies() -> None:
assert a.started is True
assert b.started is False
assert c.started is True
assert stopped is True


@pytest.mark.asyncio
async def test_dependency_cache_access_safety() -> None:
created = 0
deps = Dependencies(a=ExampleService("a"))

async def tenant_deps() -> TenantDependencies:
nonlocal created
created += 1
await asyncio.sleep(0.1)
await run_async(time.sleep, 0.1)
return deps.tenant_dependencies(b=ExampleService("b"))

async with TenantDependencyCache(timedelta(days=1), timedelta(days=1)) as cache:

async def get_concurrently() -> None:
tasks = [asyncio.create_task(cache.get("a", tenant_deps)) for _ in range(100)]
await asyncio.gather(*tasks)

# 10 worker, trying 100 times using 100 tasks to access the cache concurrently
with ThreadPoolExecutor(max_workers=10) as executor:
for _ in range(100):
executor.submit(asyncio.run, get_concurrently())

# Check that only one tenant dependency was created
assert created == 1


@pytest.mark.asyncio
async def test_dependency_cache() -> None:
time = 1
created = 0
failing = 0
deps = Dependencies(a=ExampleService("a"))

async def tenant_deps() -> TenantDependencies:
nonlocal created
created += 1
return deps.tenant_dependencies(b=ExampleService("b"))

async def failing_tenant_deps() -> Any:
nonlocal failing
failing += 1
raise Exception("failing_tenant_deps")

async with TenantDependencyCache(timedelta(seconds=1), timedelta(days=1), lambda: time) as cache:
# getting the value from the cache will create a new value
assert created == 0
Expand All @@ -95,5 +147,28 @@ async def tenant_deps() -> TenantDependencies:
# all services from td_1 are stopped, all services from td_3 are started
assert td_1.service("b", ExampleService).started is False
assert td_3.service("b", ExampleService).started is True
# a dependency which failed to create is not cached
for _ in range(10):
with pytest.raises(Exception):
await cache.get("never", failing_tenant_deps)
assert failing == 10
# when the cache is stopped, all started services are stopped
assert td_3.service("b", ExampleService).started is False


@pytest.mark.asyncio
async def test_tenant_dependency_test_creation(dependencies: Dependencies) -> None:
async with FromRequestTenantDependencyProvider(dependencies) as provider:
deps = await provider.dependencies(
make_mocked_request(
"GET",
"test",
headers=dict(
FixGraphDbServer="http://localhost:8529",
FixGraphDbDatabase="test",
FixGraphDbUsername="test",
FixGraphDbPassword="test",
),
)
)
assert deps.get(ServiceNames.cli) is not None

0 comments on commit 01dc405

Please sign in to comment.