diff --git a/pyproject.toml b/pyproject.toml index eeb262dd..ff4d68ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "ops-scenario" -version = "2.1.3.5" +version = "2.2" authors = [ { name = "Pietro Pasotti", email = "pietro.pasotti@canonical.com" } ] diff --git a/scenario/mocking.py b/scenario/mocking.py index 8bfda1c4..8b294c58 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -220,7 +220,7 @@ def secret_get( if peek or refresh: revision = max(secret.contents.keys()) if refresh: - secret.revision = revision + secret._set_revision(revision) return secret.contents[revision] @@ -298,6 +298,10 @@ def secret_remove(self, id: str, *, revision: Optional[int] = None): else: secret.contents.clear() + def relation_remote_app_name(self, relation_id: int): + relation = self._get_relation_by_id(relation_id) + return relation.remote_app_name + # TODO: def action_set(self, *args, **kwargs): raise NotImplementedError("action_set") @@ -314,9 +318,6 @@ def storage_add(self, *args, **kwargs): def action_get(self): raise NotImplementedError("action_get") - def relation_remote_app_name(self, *args, **kwargs): - raise NotImplementedError("relation_remote_app_name") - def resource_get(self, *args, **kwargs): raise NotImplementedError("resource_get") diff --git a/scenario/state.py b/scenario/state.py index 1086df7e..cef188e0 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -70,7 +70,7 @@ class StateValidationError(RuntimeError): # **combination** of several parts of the State are. -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class _DCBase: def replace(self, *args, **kwargs): return dataclasses.replace(self, *args, **kwargs) @@ -79,7 +79,7 @@ def copy(self) -> "Self": return copy.deepcopy(self) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Secret(_DCBase): id: str @@ -142,6 +142,11 @@ def remove_event(self): ) return Event(name="secret_removed", secret=self) + def _set_revision(self, revision: int): + """Set a new tracked revision.""" + # bypass frozen dataclass + object.__setattr__(self, "revision", revision) + _RELATION_IDS_CTR = 0 @@ -172,7 +177,17 @@ def deferred(self, handler: Callable, event_id: int = 1) -> "DeferredEvent": return self().deferred(handler=handler, event_id=event_id) -@dataclasses.dataclass +def _generate_new_relation_id(): + global _RELATION_IDS_CTR + _RELATION_IDS_CTR += 1 + logger.info( + f"relation ID unset; automatically assigning {_RELATION_IDS_CTR}. " + f"If there are problems, pass one manually." + ) + return _RELATION_IDS_CTR + + +@dataclasses.dataclass(frozen=True) class RelationBase(_DCBase): endpoint: str @@ -180,7 +195,7 @@ class RelationBase(_DCBase): interface: str = None # Every new Relation instance gets a new one, if there's trouble, override. - relation_id: int = -1 + relation_id: int = dataclasses.field(default_factory=_generate_new_relation_id) local_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) local_unit_data: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -212,15 +227,6 @@ def __post_init__(self): "please use Relation, PeerRelation, or SubordinateRelation" ) - global _RELATION_IDS_CTR - if self.relation_id == -1: - _RELATION_IDS_CTR += 1 - logger.info( - f"relation ID unset; automatically assigning {_RELATION_IDS_CTR}. " - f"If there are problems, pass one manually." - ) - self.relation_id = _RELATION_IDS_CTR - for databag in self._databags: self._validate_databag(databag) @@ -314,9 +320,11 @@ def unify_ids_and_remote_units_data(ids: List[int], data: Dict[int, Any]): return ids, data -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Relation(RelationBase): remote_app_name: str = "remote" + + # fixme: simplify API by deriving remote_unit_ids from remote_units_data. remote_unit_ids: List[int] = dataclasses.field(default_factory=list) # local limit @@ -327,6 +335,16 @@ class Relation(RelationBase): default_factory=dict ) + def __post_init__(self): + super().__post_init__() + + remote_unit_ids, remote_units_data = unify_ids_and_remote_units_data( + self.remote_unit_ids, self.remote_units_data + ) + # bypass frozen dataclass + object.__setattr__(self, "remote_unit_ids", remote_unit_ids) + object.__setattr__(self, "remote_units_data", remote_units_data) + @property def _remote_app_name(self) -> str: """Who is on the other end of this relation?""" @@ -349,14 +367,8 @@ def _databags(self): yield self.remote_app_data yield from self.remote_units_data.values() - def __post_init__(self): - super().__post_init__() - self.remote_unit_ids, self.remote_units_data = unify_ids_and_remote_units_data( - self.remote_unit_ids, self.remote_units_data - ) - -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class SubordinateRelation(RelationBase): # todo: consider renaming them to primary_*_data remote_app_data: Dict[str, str] = dataclasses.field(default_factory=dict) @@ -393,7 +405,7 @@ def primary_name(self) -> str: return f"{self.primary_app_name}/{self.primary_id}" -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class PeerRelation(RelationBase): peers_data: Dict[int, Dict[str, str]] = dataclasses.field(default_factory=dict) @@ -423,9 +435,12 @@ def _get_databag_for_remote(self, unit_id: int) -> Dict[str, str]: return self.peers_data[unit_id] def __post_init__(self): - self.peers_ids, self.peers_data = unify_ids_and_remote_units_data( + peers_ids, peers_data = unify_ids_and_remote_units_data( self.peers_ids, self.peers_data ) + # bypass frozen dataclass guards + object.__setattr__(self, "peers_ids", peers_ids) + object.__setattr__(self, "peers_data", peers_data) def _random_model_name(): @@ -436,7 +451,7 @@ def _random_model_name(): return "".join(random.choice(space) for _ in range(20)) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Model(_DCBase): name: str = _random_model_name() uuid: str = str(uuid4()) @@ -453,31 +468,39 @@ class Model(_DCBase): _CHANGE_IDS = 0 -@dataclasses.dataclass +def _generate_new_change_id(): + global _CHANGE_IDS + _CHANGE_IDS += 1 + logger.info( + f"change ID unset; automatically assigning {_CHANGE_IDS}. " + f"If there are problems, pass one manually." + ) + return _CHANGE_IDS + + +@dataclasses.dataclass(frozen=True) class ExecOutput: return_code: int = 0 stdout: str = "" stderr: str = "" # change ID: used internally to keep track of mocked processes - _change_id: int = -1 + _change_id: int = dataclasses.field(default_factory=_generate_new_change_id) def _run(self) -> int: - global _CHANGE_IDS - _CHANGE_IDS = self._change_id = _CHANGE_IDS + 1 - return _CHANGE_IDS + return self._change_id _ExecMock = Dict[Tuple[str, ...], ExecOutput] -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Mount(_DCBase): location: Union[str, PurePosixPath] src: Union[str, Path] -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Container(_DCBase): name: str can_connect: bool = False @@ -583,7 +606,7 @@ def pebble_ready_event(self): return Event(name=normalize_name(self.name + "-pebble-ready"), container=self) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Address(_DCBase): hostname: str value: str @@ -591,7 +614,7 @@ class Address(_DCBase): address: str = "" # legacy -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class BindAddress(_DCBase): interface_name: str addresses: List[Address] @@ -609,7 +632,7 @@ def hook_tool_output_fmt(self): return dct -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Network(_DCBase): name: str @@ -654,7 +677,7 @@ def default( ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class _EntityStatus(_DCBase): """This class represents StatusBase and should not be interacted with directly.""" @@ -690,17 +713,13 @@ def _status_to_entitystatus(obj: StatusBase) -> _EntityStatus: return _EntityStatus(obj.name, obj.message) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Status(_DCBase): """Represents the 'juju statuses' of the application/unit being tested.""" # the current statuses. Will be cast to _EntitiyStatus in __post_init__ - app: Union[StatusBase, _EntityStatus] = dataclasses.field( - default_factory=lambda: _EntityStatus("unknown") - ) - unit: Union[StatusBase, _EntityStatus] = dataclasses.field( - default_factory=lambda: _EntityStatus("unknown") - ) + app: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown") + unit: Union[StatusBase, _EntityStatus] = _EntityStatus("unknown") app_version: str = "" # most to least recent statuses; do NOT include the current one. @@ -714,14 +733,14 @@ def __post_init__(self): if isinstance(val, _EntityStatus): pass elif isinstance(val, StatusBase): - setattr(self, name, _status_to_entitystatus(val)) + object.__setattr__(self, name, _status_to_entitystatus(val)) elif isinstance(val, tuple): logger.warning( "Initializing Status.[app/unit] with Tuple[str, str] is deprecated " "and will be removed soon. \n" f"Please pass a StatusBase instance: `StatusBase(*{val})`" ) - setattr(self, name, _EntityStatus(*val)) + object.__setattr__(self, name, _EntityStatus(*val)) else: raise TypeError(f"Invalid status.{name}: {val!r}") @@ -729,8 +748,10 @@ def _update_app_version(self, new_app_version: str): """Update the current app version and record the previous one.""" # We don't keep a full history because we don't expect the app version to change more # than once per hook. - self.previous_app_version = self.app_version - self.app_version = new_app_version + + # bypass frozen dataclass + object.__setattr__(self, "previous_app_version", self.app_version) + object.__setattr__(self, "app_version", new_app_version) def _update_status( self, new_status: str, new_message: str = "", is_app: bool = False @@ -738,13 +759,15 @@ def _update_status( """Update the current app/unit status and add the previous one to the history.""" if is_app: self.app_history.append(self.app) - self.app = _EntityStatus(new_status, new_message) + # bypass frozen dataclass + object.__setattr__(self, "app", _EntityStatus(new_status, new_message)) else: self.unit_history.append(self.unit) - self.unit = _EntityStatus(new_status, new_message) + # bypass frozen dataclass + object.__setattr__(self, "unit", _EntityStatus(new_status, new_message)) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class StoredState(_DCBase): # /-separated Object names. E.g. MyCharm/MyCharmLib. # if None, this StoredState instance is owned by the Framework. @@ -760,7 +783,7 @@ def handle_path(self): return f"{self.owner_path or ''}/{self.data_type_name}[{self.name}]" -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class State(_DCBase): """Represents the juju-owned portion of a unit's state. @@ -874,7 +897,7 @@ def trigger( ) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class _CharmSpec(_DCBase): """Charm spec.""" @@ -918,7 +941,7 @@ def sort_patch(patch: List[Dict], key=lambda obj: obj["path"] + obj["op"]): return sorted(patch, key=key) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class DeferredEvent(_DCBase): handle_path: str owner: str @@ -932,7 +955,7 @@ def name(self): return self.handle_path.split("/")[-1].split("[")[0] -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Event(_DCBase): name: str args: Tuple[Any] = () @@ -965,7 +988,9 @@ def __call__(self, remote_unit_id: Optional[int] = None) -> "Event": def __post_init__(self): if "-" in self.name: logger.warning(f"Only use underscores in event names. {self.name!r}") - self.name = normalize_name(self.name) + + # bypass frozen dataclass + object.__setattr__(self, "name", normalize_name(self.name)) @property def _is_relation_event(self) -> bool: @@ -1089,14 +1114,14 @@ def deferred( return event.deferred(handler=handler, event_id=event_id) -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class Inject(_DCBase): """Base class for injectors: special placeholders used to tell harness_ctx to inject instances that can't be retrieved in advance in event args or kwargs. """ -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class InjectRelation(Inject): relation_name: str relation_id: Optional[int] = None diff --git a/tests/test_e2e/test_pebble.py b/tests/test_e2e/test_pebble.py index 3cff2e17..d6fb3889 100644 --- a/tests/test_e2e/test_pebble.py +++ b/tests/test_e2e/test_pebble.py @@ -128,9 +128,8 @@ def callback(self: CharmBase): assert file.read() == text else: # nothing has changed - out.juju_log = [] - out.stored_state = state.stored_state # ignore stored state in delta. - assert not out.jsonpatch_delta(state) + out_purged = out.replace(juju_log=[], stored_state=state.stored_state) + assert not out_purged.jsonpatch_delta(state) LS = """ diff --git a/tests/test_e2e/test_play_assertions.py b/tests/test_e2e/test_play_assertions.py index ddcd05e5..a7ee4175 100644 --- a/tests/test_e2e/test_play_assertions.py +++ b/tests/test_e2e/test_play_assertions.py @@ -58,9 +58,8 @@ def post_event(charm): assert out.status.unit == ActiveStatus("yabadoodle") - out.juju_log = [] # exclude juju log from delta - out.stored_state = initial_state.stored_state # ignore stored state in delta. - assert out.jsonpatch_delta(initial_state) == [ + out_purged = out.replace(juju_log=[], stored_state=initial_state.stored_state) + assert out_purged.jsonpatch_delta(initial_state) == [ { "op": "replace", "path": "/status/unit/message", diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index 695a5518..9aebbab7 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -60,9 +60,8 @@ def state(): def test_bare_event(state, mycharm): out = state.trigger("start", mycharm, meta={"name": "foo"}) - out.juju_log = [] # ignore logging output in the delta - out.stored_state = state.stored_state # ignore stored state in delta. - assert state.jsonpatch_delta(out) == [] + out_purged = out.replace(juju_log=[], stored_state=state.stored_state) + assert state.jsonpatch_delta(out_purged) == [] def test_leader_get(state, mycharm): @@ -88,14 +87,15 @@ def call(charm: CharmBase, _): "start", mycharm, meta={"name": "foo"}, + config={"options": {"foo": {"type": "string"}}}, ) assert out.status.unit == ActiveStatus("foo test") assert out.status.app == WaitingStatus("foo barz") assert out.status.app_version == "" - out.juju_log = [] # ignore logging output in the delta - out.stored_state = state.stored_state # ignore stored state in delta. - assert out.jsonpatch_delta(state) == sort_patch( + # ignore logging output and stored state in the delta + out_purged = out.replace(juju_log=[], stored_state=state.stored_state) + assert out_purged.jsonpatch_delta(state) == sort_patch( [ {"op": "replace", "path": "/status/app/message", "value": "foo barz"}, {"op": "replace", "path": "/status/app/name", "value": "waiting"}, @@ -123,7 +123,7 @@ def pre_event(charm: CharmBase): assert container.name == "foo" assert container.can_connect() is connect - State(containers=(Container(name="foo", can_connect=connect),)).trigger( + State(containers=[Container(name="foo", can_connect=connect)]).trigger( "start", mycharm, meta={ diff --git a/tox.ini b/tox.ini index a25b8f0f..ecd71dd1 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ [tox] envlist = - {py36,py37,py38} + {py36,py37,py38,py311} unit, lint isolated_build = True skip_missing_interpreters = True