diff --git a/aiogoogle/auth/managers.py b/aiogoogle/auth/managers.py index 078c0fe..fc6291a 100644 --- a/aiogoogle/auth/managers.py +++ b/aiogoogle/auth/managers.py @@ -617,6 +617,7 @@ async def refresh(self, user_creds, client_creds=None): Returns: + bool: If the token is refreshed or not aiogoogle.creds.UserCreds: Refreshed user credentials Raises: @@ -624,12 +625,16 @@ async def refresh(self, user_creds, client_creds=None): aiogoogle.excs.AuthError: Auth Error """ client_creds = client_creds or self.client_creds + + if not self.is_expired(user_creds): + return False, user_creds + request = self._build_refresh_request(user_creds, client_creds) json_res = await self._send_request(request) final_user_creds = self._build_user_creds_from_res(json_res) if not final_user_creds.get('refresh_token'): final_user_creds['refresh_token'] = user_creds.get('refresh_token') - return final_user_creds + return True, final_user_creds def _build_refresh_request(self, user_creds, client_creds): data = dict( @@ -1325,11 +1330,11 @@ async def refresh(self): Ensures that there's an unexpired access token. Returns: + bool: If the token is refreshed or not. - None ''' if self._access_token and not _is_expired(self._expires_at): - return + return False if self._creds_source == 'key_file': await self._get_oauth2_authorization_grant() @@ -1340,3 +1345,4 @@ async def refresh(self): 'No service account creds found.' 'Please provide service account credentials or call self.detect_default_creds first' ) + return True diff --git a/aiogoogle/client.py b/aiogoogle/client.py index 057f458..9b551f7 100644 --- a/aiogoogle/client.py +++ b/aiogoogle/client.py @@ -241,16 +241,12 @@ async def as_user(self, *requests, timeout=None, full_res=False, user_creds=None raise TypeError("No user credentials were found") # Refresh credentials - if ( - user_creds.get("expires_at") is None - or (user_creds.get("expires_at") and self.oauth2.is_expired(user_creds) is True) - ): - user_creds = await self.oauth2.refresh( + if user_creds.get("expires_at") is None: + is_refreshed, user_creds = await self.oauth2.refresh( user_creds, client_creds=self.client_creds ) - # Set refreshed user_creds if ones were already existing - if self.user_creds is not None: + if is_refreshed and self.user_creds is not None: self.user_creds = user_creds authorized_requests = [ @@ -263,7 +259,8 @@ async def as_user(self, *requests, timeout=None, full_res=False, user_creds=None full_res=full_res, raise_for_status=raise_for_status, session_factory=self.session_factory, - auth_manager=self.oauth2 + auth_manager=self.oauth2, + user_creds=user_creds ) async def as_service_account( diff --git a/aiogoogle/models.py b/aiogoogle/models.py index 620c229..3b2f6ef 100644 --- a/aiogoogle/models.py +++ b/aiogoogle/models.py @@ -280,6 +280,8 @@ class Response: session_factory (aiogoogle.sessions.abc.AbstractSession): A callable implementation of aiogoogle's session interface auth_manager (aiogoogle.auth.managers.ServiceAccountManager): Service account authorization manager. + + user_creds (aiogoogle.auth.creds.UserCreds): user_creds to make an api call with. """ def __init__( @@ -296,7 +298,8 @@ def __init__( upload_file=None, pipe_from=None, session_factory=None, - auth_manager=None + auth_manager=None, + user_creds=None ): if json and data: raise TypeError("Pass either json or data, not both.") @@ -314,6 +317,8 @@ def __init__( self.pipe_from = pipe_from self.session_factory = session_factory self.auth_manager = auth_manager + # Used for refreshing tokens for the Oauth2 authentication workflow. + self.user_creds = user_creds @staticmethod async def _next_page_generator( @@ -323,6 +328,7 @@ async def _next_page_generator( res_token_name=None, json_req=False, ): + from .auth.managers import ServiceAccountManager, Oauth2Manager prev_url = None while prev_res is not None: @@ -342,8 +348,16 @@ async def _next_page_generator( ) if next_req is not None: async with session_factory() as sess: - await prev_res.auth_manager.refresh() - prev_res.auth_manager.authorize(next_req) + if isinstance(prev_res.auth_manager, (ServiceAccountManager, Oauth2Manager)): + authorize_params = [next_req] + if isinstance(prev_res.auth_manager, ServiceAccountManager): + is_refreshed = await prev_res.auth_manager.refresh() + else: + is_refreshed, user_creds = await prev_res.auth_manager.refresh(prev_res.user_creds) + authorize_params.append(user_creds) + + if is_refreshed is True: + prev_res.auth_manager.authorize(*authorize_params) prev_res = await sess.send(next_req, full_res=True, auth_manager=prev_res.auth_manager) else: prev_res = None diff --git a/aiogoogle/sessions/aiohttp_session.py b/aiogoogle/sessions/aiohttp_session.py index 5a7a66c..bf01f8e 100644 --- a/aiogoogle/sessions/aiohttp_session.py +++ b/aiogoogle/sessions/aiohttp_session.py @@ -40,7 +40,8 @@ async def send( full_res=False, raise_for_status=True, session_factory=None, - auth_manager=None + auth_manager=None, + **kwargs ): async def resolve_response(request, response): data = None @@ -93,7 +94,8 @@ async def resolve_response(request, response): upload_file=upload_file, pipe_from=pipe_from, session_factory=session_factory, - auth_manager=auth_manager + auth_manager=auth_manager, + user_creds=kwargs.get("user_creds") ) async def fire_request(request):