From 43d1935b2a717923c81698d3d091d85a3e5a07d9 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Wed, 20 Nov 2024 17:51:46 +0900 Subject: [PATCH] chore: type hints for wait_for_idle specifically --- juju/model.py | 60 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/juju/model.py b/juju/model.py index e95e6248..81648a89 100644 --- a/juju/model.py +++ b/juju/model.py @@ -28,8 +28,7 @@ from .annotationhelper import _get_annotations, _set_annotations from .bundle import BundleHandler, get_charm_series, is_local_charm from .charmhub import CharmHub -from .client import client, connector -from .client.connection import Connection +from .client import client, connection, connector from .client.overrides import Caveat, Macaroon from .constraints import parse as parse_constraints from .controller import ConnectedController, Controller @@ -60,6 +59,7 @@ if TYPE_CHECKING: from .application import Application + from .client._definitions import FullStatus from .machine import Machine from .relation import Relation from .remoteapplication import ApplicationOffer, RemoteApplication @@ -267,7 +267,9 @@ def apply_delta(self, delta): entity = self.get_entity(delta.entity, delta.get_id()) return entity.previous(), entity - def get_entity(self, entity_type, entity_id, history_index=-1, connected=True): + def get_entity( + self, entity_type, entity_id, history_index=-1, connected=True + ) -> ModelEntity | None: """Return an object instance for the given entity_type and id. By default the object state matches the most recent state from @@ -295,6 +297,11 @@ class ModelEntity: """An object in the Model tree""" entity_id: str + model: Model + _history_index: int + connected: bool + connection: connection.Connection + _status: str def __init__( self, @@ -616,6 +623,9 @@ async def resolve( class Model: """The main API for interacting with a Juju model.""" + connector: connector.Connector + state: ModelState + def __init__( self, max_frame_size=None, @@ -660,7 +670,7 @@ def is_connected(self): """Reports whether the Model is currently connected.""" return self._connector.is_connected() - def connection(self) -> Connection: + def connection(self) -> connection.Connection: """Return the current Connection object. It raises an exception if the Model is disconnected """ @@ -914,7 +924,10 @@ def add_local_charm(self, charm_file, series="", size=None): instead. """ - conn, headers, path_prefix = self.connection().https_connection() + connection = self.connection() + assert connection + + conn, headers, path_prefix = connection.https_connection() path = "%s/charms?series=%s" % (path_prefix, series) headers["Content-Type"] = "application/zip" if size: @@ -1212,11 +1225,12 @@ def name(self): return self._info.name @property - def info(self): + def info(self) -> ModelInfo: """Return the cached client.ModelInfo object for this Model. If Model.get_info() has not been called, this will return None. """ + assert self._info is not None return self._info @property @@ -1306,11 +1320,13 @@ async def _all_watcher(): del allwatcher.Id continue except websockets.ConnectionClosed: - monitor = self.connection().monitor + connection = self.connection() + assert connection + monitor = connection.monitor if monitor.status == monitor.ERROR: # closed unexpectedly, try to reopen log.warning("Watcher: connection closed, reopening") - await self.connection().reconnect() + await connection.reconnect() if monitor.status != monitor.CONNECTED: # reconnect failed; abort and shutdown log.error( @@ -2624,7 +2640,7 @@ async def get_action_status(self, uuid_or_prefix=None, name=None): results[tag.untag("action-", a.action.tag)] = a.status return results - async def get_status(self, filters=None, utc=False): + async def get_status(self, filters=None, utc=False) -> FullStatus: """Return the status of the model. :param str filters: Optional list of applications, units, or machines @@ -2959,15 +2975,15 @@ async def _get_source_api(self, url): async def wait_for_idle( self, apps: list[str] | None = None, - raise_on_error=True, - raise_on_blocked=False, - wait_for_active=False, - timeout=10 * 60, - idle_period=15, - check_freq=0.5, - status=None, - wait_for_at_least_units=None, - wait_for_exact_units=None, + raise_on_error: bool = True, + raise_on_blocked: bool = False, + wait_for_active: bool = False, + timeout: float | None = 10 * 60, + idle_period: float = 15, + check_freq: float = 0.5, + status: str | None = None, + wait_for_at_least_units: int | None = None, + wait_for_exact_units: int | None = None, ) -> None: """Wait for applications in the model to settle into an idle state. @@ -3035,12 +3051,12 @@ async def wait_for_idle( raise JujuError(f"Expected a List[str] for apps, given {apps}") apps = apps or self.applications - idle_times = {} - units_ready = set() # The units that are in the desired state - last_log_time = None + idle_times: dict[str, datetime] = {} + units_ready: set[str] = set() # The units that are in the desired state + last_log_time: datetime | None = None log_interval = timedelta(seconds=30) - def _raise_for_status(entities, status): + def _raise_for_status(entities: dict[str, list[str]], status: Any): if not entities: return for entity_name, error_type in (