diff --git a/packages/atproto_client/client/async_client.py b/packages/atproto_client/client/async_client.py index 228aff3d..93b9366d 100644 --- a/packages/atproto_client/client/async_client.py +++ b/packages/atproto_client/client/async_client.py @@ -50,7 +50,7 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response return await super()._invoke(invoke_type, **kwargs) async def _set_session(self, event: SessionEvent, session: SessionResponse) -> None: - session = self._set_session_common(session) + session = self._set_session_common(session, self._base_url) await self._call_on_session_change_callbacks(event, session.copy()) async def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response': diff --git a/packages/atproto_client/client/client.py b/packages/atproto_client/client/client.py index 269746fb..078ba0a7 100644 --- a/packages/atproto_client/client/client.py +++ b/packages/atproto_client/client/client.py @@ -41,7 +41,7 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs: t.Any) -> 'Response': return super()._invoke(invoke_type, **kwargs) def _set_session(self, event: SessionEvent, session: SessionResponse) -> None: - session = self._set_session_common(session) + session = self._set_session_common(session, self._base_url) self._call_on_session_change_callbacks(event, session.copy()) def _get_and_set_session(self, login: str, password: str) -> 'models.ComAtprotoServerCreateSession.Response': diff --git a/packages/atproto_client/client/methods_mixin/session.py b/packages/atproto_client/client/methods_mixin/session.py index 91175f38..51be92cf 100644 --- a/packages/atproto_client/client/methods_mixin/session.py +++ b/packages/atproto_client/client/methods_mixin/session.py @@ -117,7 +117,7 @@ def _should_refresh_session(self) -> bool: return self.get_current_time() > expired_at - def _set_session_common(self, session: SessionResponse) -> Session: + def _set_session_common(self, session: SessionResponse, current_pds: str) -> Session: self._access_jwt = session.access_jwt self._access_jwt_payload = get_jwt_payload(session.access_jwt) @@ -125,6 +125,10 @@ def _set_session_common(self, session: SessionResponse) -> Session: self._refresh_jwt_payload = get_jwt_payload(session.refresh_jwt) pds_endpoint = get_session_pds_endpoint(session) + if not pds_endpoint: + # current_pds ends with xrpc endpoint, but this is not a problem + # overhead is only 4-5 symbols in the exported session string + pds_endpoint = current_pds self._session = Session( access_jwt=session.access_jwt, diff --git a/packages/atproto_client/client/session.py b/packages/atproto_client/client/session.py index e8c00f21..824842f9 100644 --- a/packages/atproto_client/client/session.py +++ b/packages/atproto_client/client/session.py @@ -72,11 +72,15 @@ def copy(self) -> 'Session': def get_session_pds_endpoint(session: SessionResponse) -> t.Optional[str]: - """Return the PDS endpoint of the given session.""" + """Return the PDS endpoint of the given session. + + Note: + Return :obj:`None` for self-hosted PDSs. + """ if isinstance(session, Session): return session.pds_endpoint - if is_valid_did_doc(session.did_doc): + if session.did_doc and is_valid_did_doc(session.did_doc): doc = DidDocument.from_dict(session.did_doc) return doc.get_pds_endpoint() diff --git a/packages/atproto_identity/did/resolvers/base_resolver.py b/packages/atproto_identity/did/resolvers/base_resolver.py index e542419b..2bd8e5ad 100644 --- a/packages/atproto_identity/did/resolvers/base_resolver.py +++ b/packages/atproto_identity/did/resolvers/base_resolver.py @@ -36,7 +36,7 @@ def resolve_without_validation(self, did: str) -> t.Optional[t.Dict[str, t.Any]] raise NotImplementedError def resolve_no_cache(self, did: str) -> t.Optional['DidDocument']: - """Resolve DID without cache. + """Resolve DID without a cache. Args: did: DID. @@ -103,7 +103,7 @@ def ensure_resolve(self, did: str, force_refresh: bool = False) -> 'DidDocument' :obj:`DidDocument`: DID document. Raises: - :obj:`DidNotFoundError`: DID not found. + :obj:`DidNotFoundError`: DID not find. """ did_doc = self.resolve(did, force_refresh) if did_doc is None: @@ -158,7 +158,7 @@ async def resolve_without_validation(self, did: str) -> t.Optional[t.Dict[str, t raise NotImplementedError async def resolve_no_cache(self, did: str) -> t.Optional['DidDocument']: - """Resolve DID without cache. + """Resolve DID without a cache. Args: did: DID. @@ -225,7 +225,7 @@ async def ensure_resolve(self, did: str, force_refresh: bool = False) -> 'DidDoc :obj:`DidDocument`: DID document. Raises: - :obj:`DidNotFoundError`: DID not found. + :obj:`DidNotFoundError`: DID not find. """ did_doc = await self.resolve(did, force_refresh) if did_doc is None: