diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index 5bcb7df311..b65d702e1b 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -19,7 +19,7 @@ from traceback import format_exc from typing import TYPE_CHECKING -from sqlalchemy import and_ +from sqlalchemy import select, and_ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import exc @@ -64,17 +64,20 @@ def add_account(account, type_, email, *, session: "Session"): @read_session def account_exists(account, *, session: "Session"): - """ Checks to see if account exists. This procedure does not check it's status. + """ Checks to see if account exists and is active. :param account: Name of the account. :param session: the database session in use. :returns: True if found, otherwise false. """ - - query = session.query(models.Account).filter_by(account=account, status=AccountStatus.ACTIVE) - - return True if query.first() else False + query = select( + models.Account + ).where( + models.Account.account == account, + models.Account.status == AccountStatus.ACTIVE + ) + return session.execute(query).scalar() is not None @read_session @@ -86,10 +89,13 @@ def get_account(account, *, session: "Session"): :returns: a dict with all information for the account. """ + query = select( + models.Account + ).where( + models.Account.account == account + ) - query = session.query(models.Account).filter_by(account=account) - - result = query.first() + result = session.execute(query).scalar() if result is None: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) return result @@ -102,9 +108,14 @@ def del_account(account, *, session: "Session"): :param account: the account name. :param session: the database session in use. """ - query = session.query(models.Account).filter_by(account=account).filter_by(status=AccountStatus.ACTIVE) + query = select( + models.Account + ).where( + models.Account.account == account, + models.Account.status == AccountStatus.ACTIVE + ) try: - account = query.one() + account = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) @@ -120,20 +131,24 @@ def update_account(account, key, value, *, session: "Session"): :param value: Property value. :param session: the database session in use. """ - query = session.query(models.Account).filter_by(account=account) + query = select( + models.Account + ).where( + models.Account.account == account + ) try: - account = query.one() + account = session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) if key == 'status': if isinstance(value, str): value = AccountStatus[value] if value == AccountStatus.SUSPENDED: - query.update({'status': value, 'suspended_at': datetime.utcnow()}) + account.update({'status': value, 'suspended_at': datetime.utcnow()}) elif value == AccountStatus.ACTIVE: - query.update({'status': value, 'suspended_at': None}) + account.update({'status': value, 'suspended_at': None}) else: - query.update({key: value}) + account.update({key: value}) @stream_session @@ -147,31 +162,53 @@ def list_accounts(filter_=None, *, session: "Session"): """ if filter_ is None: filter_ = {} - query = session.query(models.Account.account, models.Account.account_type, - models.Account.email).filter_by(status=AccountStatus.ACTIVE) + query = select( + models.Account.account, + models.Account.account_type, + models.Account.email + ).where( + models.Account.status == AccountStatus.ACTIVE + ) for filter_type in filter_: if filter_type == 'account_type': if isinstance(filter_['account_type'], str): - query = query.filter_by(account_type=AccountType[filter_['account_type']]) + query = query.where( + models.Account.account_type == AccountType[filter_['account_type']] + ) elif isinstance(filter_['account_type'], Enum): - query = query.filter_by(account_type=filter_['account_type']) + query = query.where( + models.Account.account_type == filter_['account_type'] + ) elif filter_type == 'identity': - query = query.join(models.IdentityAccountAssociation, models.Account.account == models.IdentityAccountAssociation.account).\ - filter(models.IdentityAccountAssociation.identity == filter_['identity']) + query = query.join( + models.IdentityAccountAssociation, + models.Account.account == models.IdentityAccountAssociation.account + ).where( + models.IdentityAccountAssociation.identity == filter_['identity'] + ) elif filter_type == 'account': if '*' in filter_['account'].internal: account_str = filter_['account'].internal.replace('*', '%') - query = query.filter(models.Account.account.like(account_str)) + query = query.where( + models.Account.account.like(account_str) + ) else: - query = query.filter_by(account=filter_['account']) + query = query.where( + models.Account.account == filter_['account'] + ) else: - query = query.join(models.AccountAttrAssociation, models.Account.account == models.AccountAttrAssociation.account).\ - filter(models.AccountAttrAssociation.key == filter_type).\ - filter(models.AccountAttrAssociation.value == filter_[filter_type]) - - for account, account_type, email in query.order_by(models.Account.account).yield_per(25): + query = query.join( + models.AccountAttrAssociation, + models.Account.account == models.AccountAttrAssociation.account + ).where( + models.AccountAttrAssociation.key == filter_type, + models.AccountAttrAssociation.value == filter_[filter_type] + ) + query = query.order_by(models.Account.account) + + for account, account_type, email in session.execute(query).yield_per(25): yield {'account': account, 'type': account_type, 'email': email} @@ -183,22 +220,31 @@ def list_identities(account, *, session: "Session"): :param account: The account name. :param session: the database session in use. """ - identity_list = list() - - query = session.query(models.Account).filter_by(account=account).filter_by(status=AccountStatus.ACTIVE) + query = select( + models.Account + ).where( + models.Account.account == account, + models.Account.status == AccountStatus.ACTIVE + ) try: - query.one() + session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - query = session.query(models.IdentityAccountAssociation, models.Identity)\ - .join(models.Identity, and_(models.Identity.identity == models.IdentityAccountAssociation.identity, - models.Identity.identity_type == models.IdentityAccountAssociation.identity_type))\ - .filter(models.IdentityAccountAssociation.account == account) - for identity in query: - identity_list.append({'type': identity[0].identity_type, 'identity': identity[0].identity, 'email': identity[1].email}) - - return identity_list + query = select( + models.IdentityAccountAssociation.identity_type.label('type'), + models.IdentityAccountAssociation.identity, + models.Identity.email + ).join( + models.Identity, + and_( + models.Identity.identity == models.IdentityAccountAssociation.identity, + models.Identity.identity_type == models.IdentityAccountAssociation.identity_type + ) + ).where( + models.IdentityAccountAssociation.account == account + ) + return [row._asdict() for row in session.execute(query)] @read_session @@ -211,18 +257,24 @@ def list_account_attributes(account, *, session: "Session"): :returns: a list of all key, value pairs for this account. """ - attr_list = [] - query = session.query(models.Account).filter_by(account=account).filter_by(status=AccountStatus.ACTIVE) + query = select( + models.Account + ).where( + models.Account.account == account, + models.Account.status == AccountStatus.ACTIVE + ) try: - query.one() + session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound("Account ID '{0}' does not exist".format(account)) - query = session.query(models.AccountAttrAssociation).filter_by(account=account) - for attr in query: - attr_list.append({'key': attr.key, 'value': attr.value}) - - return attr_list + query = select( + models.AccountAttrAssociation.key, + models.AccountAttrAssociation.value + ).where( + models.AccountAttrAssociation.account == account + ) + return [row._asdict() for row in session.execute(query)] @read_session @@ -236,9 +288,13 @@ def has_account_attribute(account, key, *, session: "Session"): :returns: True or False """ - if session.query(models.AccountAttrAssociation.value).filter_by(account=account, key=key).first(): - return True - return False + query = select( + models.AccountAttrAssociation.value + ).where( + models.AccountAttrAssociation.account == account, + models.AccountAttrAssociation.key == key + ) + return session.execute(query).scalar() is not None @transactional_session @@ -251,11 +307,14 @@ def add_account_attribute(account, key, value, *, session: "Session"): :param account: the account to add the attribute to. :param session: The database session in use. """ - - query = session.query(models.Account).filter_by(account=account, status=AccountStatus.ACTIVE) - + query = select( + models.Account + ).where( + models.Account.account == account, + models.Account.status == AccountStatus.ACTIVE + ) try: - query.one() + session.execute(query).scalar_one() except exc.NoResultFound: raise exception.AccountNotFound("Account ID '{0}' does not exist".format(account)) @@ -283,7 +342,13 @@ def del_account_attribute(account, key, *, session: "Session"): :param key: the key for the new attribute. :param session: The database session in use. """ - aid = session.query(models.AccountAttrAssociation).filter_by(key=key, account=account).first() + query = select( + models.AccountAttrAssociation + ).where( + models.AccountAttrAssociation.account == account, + models.AccountAttrAssociation.key == key + ) + aid = session.execute(query).scalar() if aid is None: raise exception.AccountNotFound('Attribute ({0}) does not exist for the account {1}!'.format(key, account)) aid.delete(session=session) @@ -299,10 +364,16 @@ def get_usage(rse_id, account, *, session: "Session"): :param session: The database session in use. :returns: A dictionary {'bytes', 'files', 'updated_at'} """ - + query = select( + models.AccountUsage.bytes, + models.AccountUsage.files, + models.AccountUsage.updated_at + ).where( + models.AccountUsage.rse_id == rse_id, + models.AccountUsage.account == account + ) try: - counter = session.query(models.AccountUsage).filter_by(rse_id=rse_id, account=account).one() - return {'bytes': counter.bytes, 'files': counter.files, 'updated_at': counter.updated_at} + return session.execute(query).one()._asdict() except exc.NoResultFound: return {'bytes': 0, 'files': 0, 'updated_at': None} @@ -317,9 +388,13 @@ def get_all_rse_usages_per_account(account, *, session: "Session"): :param session: The database session in use. :returns: List of dicts with :py:class:`models.AccountUsage` items """ - + query = select( + models.AccountUsage + ).where( + models.AccountUsage.account == account + ) try: - return [result.to_dict() for result in session.query(models.AccountUsage).filter_by(account=account).all()] + return [result.to_dict() for result in session.execute(query).scalars()] except exc.NoResultFound: return [] @@ -334,13 +409,18 @@ def get_usage_history(rse_id, account, *, session: "Session"): :param session: The database session in use. :returns: A dictionary {'bytes', 'files', 'updated_at'} """ - - result = [] - AccountUsageHistory = models.AccountUsageHistory + query = select( + models.AccountUsageHistory.bytes, + models.AccountUsageHistory.files, + models.AccountUsageHistory.updated_at + ).where( + models.AccountUsageHistory.rse_id == rse_id, + models.AccountUsageHistory.account == account + ).order_by( + models.AccountUsageHistory.updated_at + ) try: - query = session.query(AccountUsageHistory).filter_by(rse_id=rse_id, account=account).order_by(AccountUsageHistory.updated_at) - for row in query.all(): - result.append({'bytes': row.bytes, 'files': row.files, 'updated_at': row.updated_at}) + return [row._asdict() for row in session.execute(query)] except exc.NoResultFound: raise exception.CounterNotFound('No usage can be found for account %s on RSE %s' % (account, rucio.core.rse.get_rse_name(rse_id=rse_id, session=session))) - return result + return [] diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index 7df67a2756..cd5775ea39 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -25,7 +25,7 @@ import paramiko from dogpile.cache import make_region from dogpile.cache.api import NO_VALUE -from sqlalchemy import and_, or_, select, delete +from sqlalchemy import delete, null, or_, select from rucio.common.cache import make_region_memcached from rucio.common.config import config_get_bool @@ -110,8 +110,13 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses if not account_exists(account, session=session): return None - result = session.query(models.Identity).filter_by(identity=username, - identity_type=IdentityType.USERPASS).first() + query = select( + models.Identity + ).where( + models.Identity.identity == username, + models.Identity.identity_type == IdentityType.USERPASS + ) + result = session.execute(query).scalar() db_salt = result['salt'] db_password = result['password'] @@ -121,9 +126,15 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses return None # get account identifier - result = session.query(models.IdentityAccountAssociation).filter_by(identity=username, - identity_type=IdentityType.USERPASS, - account=account).first() + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity == username, + models.IdentityAccountAssociation.identity_type == IdentityType.USERPASS, + models.IdentityAccountAssociation.account == account + ) + result = session.execute(query).scalar() + db_account = result['account'] # remove expired tokens @@ -225,13 +236,23 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session" return None # get all active challenge tokens for the requested account - active_challenge_tokens = session.query(models.Token).filter(models.Token.expired_at >= datetime.datetime.utcnow(), - models.Token.account == account, - models.Token.token.like('challenge-%')).all() + query = select( + models.Token + ).where( + models.Token.expired_at >= datetime.datetime.utcnow(), + models.Token.account == account, + models.Token.token.like('challenge-%') + ) + active_challenge_tokens = session.execute(query).scalars().all() # get all identities for the requested account - identities = session.query(models.IdentityAccountAssociation).filter_by(identity_type=IdentityType.SSH, - account=account).all() + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity_type == IdentityType.SSH, + models.IdentityAccountAssociation.account == account + ) + identities = session.execute(query).scalars().all() # no challenge tokens found if not active_challenge_tokens: @@ -351,20 +372,25 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): """ try: - redirect_result = session.query(models.OAuthRequest.redirect_msg).filter_by(access_msg=auth_code).first() + query = select( + models.OAuthRequest.redirect_msg + ).where( + models.OAuthRequest.access_msg == auth_code + ) + redirect_result = session.execute(query).scalar() if not redirect_result: return None - if 'http' not in redirect_result[0] and fetchtoken: + if 'http' not in redirect_result and fetchtoken: # in this case the function check if the value is a valid token - vdict = validate_auth_token(redirect_result[0], session=session) + vdict = validate_auth_token(redirect_result, session=session) if vdict: - return redirect_result[0] + return redirect_result return None - elif 'http' in redirect_result[0] and not fetchtoken: + elif 'http' in redirect_result and not fetchtoken: # return redirection URL - return redirect_result[0] + return redirect_result except: raise CannotAuthenticate(traceback.format_exc()) @@ -385,23 +411,43 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: # get expired tokens try: # delete all expired tokens except tokens which have refresh token that is still valid - query = session.query(models.Token.token).filter(and_(models.Token.expired_at <= datetime.datetime.utcnow()))\ - .filter(or_(models.Token.refresh_expired_at.__eq__(None), - models.Token.refresh_expired_at <= datetime.datetime.utcnow()))\ - .order_by(models.Token.expired_at) + query = select( + models.Token.token + ).where( + models.Token.expired_at <= datetime.datetime.utcnow(), + or_( + models.Token.refresh_expired_at == null(), + models.Token.refresh_expired_at <= datetime.datetime.utcnow() + ) + ).order_by( + models.Token.expired_at + ) query = filter_thread_work(session=session, query=query, total_threads=total_workers, thread_id=worker_number, hash_variable='token') # limiting the number of tokens deleted at once query = query.limit(limit) + # Oracle does not support chaining order_by(), limit(), and + # with_for_update(). Use a nested query to overcome this. + if session.bind.dialect.name == 'oracle': + query = select( + models.Token.token + ).where( + models.Token.token.in_(query) + ).with_for_update( + skip_locked=True + ) + else: + query = query.with_for_update(skip_locked=True) # remove expired tokens deleted_tokens = 0 - for items in session.execute(query).partitions(10): - tokens = tuple(map(lambda row: row.token, items)) - deleted_tokens += session.query(models.Token) \ - .filter(models.Token.token.in_(tokens)) \ - .with_for_update(skip_locked=True) \ - .delete(synchronize_session='fetch') + for tokens in session.execute(query).scalars().partitions(10): + query = delete( + models.Token + ).where( + models.Token.token.in_(tokens) + ) + deleted_tokens += session.execute(query).rowcount except Exception as error: raise RucioException(error.args) @@ -426,20 +472,19 @@ def query_token(token, *, session: "Session"): if successful, None otherwise. """ # Query the DB to validate token - ret = session.query(models.Token.account, - models.Token.identity, - models.Token.expired_at, - models.Token.audience, - models.Token.oidc_scope).\ - filter(models.Token.token == token, - models.Token.expired_at > datetime.datetime.utcnow()).\ - all() - if ret: - return {'account': ret[0][0], - 'identity': ret[0][1], - 'lifetime': ret[0][2], - 'audience': ret[0][3], - 'authz_scope': ret[0][4]} + query = select( + models.Token.account, + models.Token.identity, + models.Token.expired_at.label('lifetime'), + models.Token.audience, + models.Token.oidc_scope.label('authz_scope') + ).where( + models.Token.token == token, + models.Token.expired_at > datetime.datetime.utcnow() + ) + result = session.execute(query).first() + if result: + return result._asdict() return None @@ -497,14 +542,22 @@ def __delete_expired_tokens_account(account, *, session: "Session"): :param account: Account to delete expired tokens. :param session: The database session in use. """ - stmt_select = select(models.Token.token) \ - .where(and_(models.Token.expired_at < datetime.datetime.utcnow(), - models.Token.account == account)) \ - .with_for_update(skip_locked=True) - tokens = session.execute(stmt_select).scalars().all() - - for t in chunks(tokens, 100): - stmt_delete = delete(models.Token) \ - .where(models.Token.token.in_(t)) \ - .prefix_with("/*+ INDEX(TOKENS_ACCOUNT_EXPIRED_AT_IDX) */") - session.execute(stmt_delete) + select_query = select( + models.Token.token + ).where( + models.Token.expired_at < datetime.datetime.utcnow(), + models.Token.account == account + ).with_for_update( + skip_locked=True + ) + tokens = session.execute(select_query).scalars().all() + + for chunk in chunks(tokens, 100): + delete_query = delete( + models.Token + ).prefix_with( + "/*+ INDEX(TOKENS_ACCOUNT_EXPIRED_AT_IDX) */" + ).where( + models.Token.token.in_(chunk) + ) + session.execute(delete_query) diff --git a/lib/rucio/core/identity.py b/lib/rucio/core/identity.py index f2288d6002..895dac216f 100644 --- a/lib/rucio/core/identity.py +++ b/lib/rucio/core/identity.py @@ -19,7 +19,7 @@ from typing import Optional from typing import TYPE_CHECKING -from sqlalchemy import asc +from sqlalchemy import select, true from sqlalchemy.exc import IntegrityError from rucio.common import exception @@ -81,7 +81,13 @@ def verify_identity(identity: str, type_: IdentityType, password: Optional[str] if type_ == IdentityType.USERPASS and password is None: raise exception.IdentityError('You must provide a password!') - id_ = session.query(models.Identity).filter_by(identity=identity, identity_type=type_).first() + query = select( + models.Identity + ).where( + models.Identity.identity == identity, + models.Identity.identity_type == type_ + ) + id_ = session.execute(query).scalar() if id_ is None: raise exception.IdentityError('Identity \'%s\' of type \'%s\' does not exist!' % (identity, type_)) if type_ == IdentityType.X509: @@ -106,7 +112,13 @@ def del_identity(identity: str, type_: IdentityType, *, session: "Session"): :param session: The database session in use. """ - id_ = session.query(models.Identity).filter_by(identity=identity, identity_type=type_).first() + query = select( + models.Identity + ).where( + models.Identity.identity == identity, + models.Identity.identity_type == type_ + ) + id_ = session.execute(query).scalar() if id_ is None: raise exception.IdentityError('Identity (\'%s\',\'%s\') does not exist!' % (identity, type_)) id_.delete(session=session) @@ -128,10 +140,16 @@ def add_account_identity(identity: str, type_: IdentityType, account: InternalAc if not account_exists(account, session=session): raise exception.AccountNotFound('Account \'%s\' does not exist.' % account) - id_ = session.query(models.Identity).filter_by(identity=identity, identity_type=type_).first() + query = select( + models.Identity + ).where( + models.Identity.identity == identity, + models.Identity.identity_type == type_ + ) + id_ = session.execute(query).scalar() if id_ is None: add_identity(identity=identity, type_=type_, email=email, password=password, session=session) - id_ = session.query(models.Identity).filter_by(identity=identity, identity_type=type_).first() + id_ = session.execute(query).scalar() iaa = models.IdentityAccountAssociation(identity=id_.identity, identity_type=id_.identity_type, account=account, is_default=default) @@ -160,9 +178,14 @@ def exist_identity_account(identity: str, type_: IdentityType, account: Internal :returns: True if identity is mapped to account, otherwise False """ - return session.query(models.IdentityAccountAssociation).filter_by(identity=identity, - identity_type=type_, - account=account).first() is not None + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity == identity, + models.IdentityAccountAssociation.identity_type == type_, + models.IdentityAccountAssociation.account == account + ) + return session.execute(query).scalar() is not None @read_session @@ -179,14 +202,25 @@ def get_default_account(identity: str, type_: IdentityType, oldest_if_none: bool :returns: The default account name, None otherwise. """ - tmp = session.query(models.IdentityAccountAssociation).filter_by(identity=identity, - identity_type=type_, - is_default=True).first() + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity == identity, + models.IdentityAccountAssociation.identity_type == type_, + models.IdentityAccountAssociation.is_default == true() + ) + tmp = session.execute(query).scalar() if tmp is None: if oldest_if_none: - tmp = session.query(models.IdentityAccountAssociation)\ - .filter_by(identity=identity, identity_type=type_)\ - .order_by(asc(models.IdentityAccountAssociation.created_at)).first() + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity == identity, + models.IdentityAccountAssociation.identity_type == type_ + ).order_by( + models.IdentityAccountAssociation.created_at + ) + tmp = session.execute(query).scalar() if tmp is None: raise exception.IdentityError('There is no account for identity (%s, %s)' % (identity, type_)) else: @@ -205,7 +239,14 @@ def del_account_identity(identity: str, type_: IdentityType, account: InternalAc :param account: The account name. :param session: The database session in use. """ - aid = session.query(models.IdentityAccountAssociation).filter_by(identity=identity, identity_type=type_, account=account).first() + query = select( + models.IdentityAccountAssociation + ).where( + models.IdentityAccountAssociation.identity == identity, + models.IdentityAccountAssociation.identity_type == type_, + models.IdentityAccountAssociation.account == account + ) + aid = session.execute(query).scalar() if aid is None: raise exception.IdentityError('Identity (\'%s\',\'%s\') does not exist!' % (identity, type_)) aid.delete(session=session) @@ -220,13 +261,13 @@ def list_identities(*, session: "Session", **kwargs): returns: A list of all identities. """ - - id_list = [] - - for id_ in session.query(models.Identity).order_by(models.Identity.identity): - id_list.append((id_.identity, id_.identity_type)) - - return id_list + query = select( + models.Identity.identity, + models.Identity.identity_type + ).order_by( + models.Identity.identity + ) + return session.execute(query).all() @read_session @@ -240,10 +281,10 @@ def list_accounts_for_identity(identity: str, type_: IdentityType, *, session: " returns: A list of all accounts for the identity. """ - - account_list = [] - - for account, in session.query(models.IdentityAccountAssociation.account).filter_by(identity=identity, identity_type=type_): - account_list.append(account) - - return account_list + query = select( + models.IdentityAccountAssociation.account + ).where( + models.IdentityAccountAssociation.identity == identity, + models.IdentityAccountAssociation.identity_type == type_ + ) + return session.execute(query).scalars().all() diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index df75ce0794..73c846e0f7 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -32,7 +32,7 @@ Message, RegistrationResponse) from oic.utils import time_util from oic.utils.authn.client import CLIENT_AUTHN_METHOD -from sqlalchemy import and_ +from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import true from rucio.common import types @@ -380,7 +380,12 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" state = parsed_authquery["state"][0] code = parsed_authquery["code"][0] # getting oauth request params from the oauth_requests DB Table - oauth_req_params = session.query(models.OAuthRequest).filter_by(state=state).first() + query = select( + models.OAuthRequest + ).where( + models.OAuthRequest.state == state + ) + oauth_req_params = session.execute(query).scalar() if oauth_req_params is None: raise CannotAuthenticate("User related Rucio OIDC session could not keep " + "track of responses from outstanding requests.") # NOQA: W503 @@ -475,15 +480,26 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" if 'http' not in oauth_req_params.access_msg: if '_polling' not in oauth_req_params.access_msg: fetchcode = rndstr(50) - session.query(models.OAuthRequest).filter(models.OAuthRequest.state == state)\ - .update({models.OAuthRequest.access_msg: fetchcode, - models.OAuthRequest.redirect_msg: new_token['token']}) + query = update( + models.OAuthRequest + ).where( + models.OAuthRequest.state == state + ).values({ + models.OAuthRequest.access_msg: fetchcode, + models.OAuthRequest.redirect_msg: new_token['token'] + }) # If Rucio Client was requested to poll the Rucio Auth server # for a token automatically, we save the token under a access_msg. else: - session.query(models.OAuthRequest).filter(models.OAuthRequest.state == state)\ - .update({models.OAuthRequest.access_msg: oauth_req_params.access_msg, - models.OAuthRequest.redirect_msg: new_token['token']}) + query = update( + models.OAuthRequest + ).where( + models.OAuthRequest.state == state + ).values({ + models.OAuthRequest.access_msg: oauth_req_params.access_msg, + models.OAuthRequest.redirect_msg: new_token['token'] + }) + session.execute(query) session.commit() METRICS.timer('IdP_authorization').observe(stopwatch.elapsed) if '_polling' in oauth_req_params.access_msg: @@ -578,9 +594,14 @@ def __get_admin_account_for_issuer(*, session: "Session"): issuer_account_dict = {} for issuer in OIDC_ADMIN_CLIENTS: admin_identity = oidc_identity_string(OIDC_ADMIN_CLIENTS[issuer].client_id, issuer) - admin_account = session.query(models.IdentityAccountAssociation)\ - .filter_by(identity_type=IdentityType.OIDC, identity=admin_identity).first() - issuer_account_dict[issuer] = (admin_account.account, admin_identity) + query = select( + models.IdentityAccountAssociation.account + ).where( + models.IdentityAccountAssociation.identity_type == IdentityType.OIDC, + models.IdentityAccountAssociation.identity == admin_identity + ) + admin_account = session.execute(query).scalar() + issuer_account_dict[issuer] = (admin_account, admin_identity) return issuer_account_dict @@ -606,16 +627,24 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ req_audience = EXPECTED_OIDC_AUDIENCE # get all identities for the corresponding account - identities_list = session.query(models.IdentityAccountAssociation.identity) \ - .filter(models.IdentityAccountAssociation.identity_type == IdentityType.OIDC, - models.IdentityAccountAssociation.account == account).all() - identities = [] - for identity in identities_list: - identities.append(identity[0]) + query = select( + models.IdentityAccountAssociation.identity + ).where( + models.IdentityAccountAssociation.identity_type == IdentityType.OIDC, + models.IdentityAccountAssociation.account == account + ) + identities = session.execute(query).scalars().all() # get all active/valid OIDC tokens - account_tokens = session.query(models.Token).filter(models.Token.identity.in_(identities), - models.Token.account == account, - models.Token.expired_at > datetime.utcnow()).with_for_update(skip_locked=True).all() + query = select( + models.Token + ).where( + models.Token.identity.in_(identities), + models.Token.account == account, + models.Token.expired_at > datetime.utcnow() + ).with_for_update( + skip_locked=True + ) + account_tokens = session.execute(query).scalars().all() # for Rucio Admin account we ask IdP for a token via client_credential grant # for each user account OIDC identity there is an OIDC issuer that must be, by construction, @@ -653,8 +682,13 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ if 'openid' in req_scope: req_scope = req_scope.replace("openid", "").strip() # checking if there is not already a token to use - admin_account_tokens = session.query(models.Token).filter(models.Token.account == account, - models.Token.expired_at > datetime.utcnow()).all() + query = select( + models.Token + ).where( + models.Token.account == account, + models.Token.expired_at > datetime.utcnow() + ) + admin_account_tokens = session.execute(query).scalars().all() for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): @@ -691,9 +725,14 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ admin_acc_idt_tuple = admin_iss_acc_idt_dict[admin_issuer] admin_account = admin_acc_idt_tuple[0] admin_identity = admin_acc_idt_tuple[1] - admin_account_tokens = session.query(models.Token).filter(models.Token.identity == admin_identity, - models.Token.account == admin_account, - models.Token.expired_at > datetime.utcnow()).all() + query = select( + models.Token + ).where( + models.Token.identity == admin_identity, + models.Token.account == admin_account, + models.Token.expired_at > datetime.utcnow() + ) + admin_account_tokens = session.execute(query).scalars().all() for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): @@ -831,15 +870,22 @@ def __change_refresh_state(token: str, refresh: bool = False, *, session: "Sessi :param token: the access token for which the refresh value should be changed. """ try: + query = update( + models.Token + ).where( + models.Token.token == token + ) if refresh: # update refresh column for a token to True - session.query(models.Token).filter(models.Token.token == token)\ - .update({models.Token.refresh: True}) + query = query.values({ + models.Token.refresh: True + }) else: - session.query(models.Token).filter(models.Token.token == token)\ - .update({models.Token.refresh: False, - models.Token.refresh_expired_at: datetime.utcnow()}) - session.commit() + query = query.values({ + models.Token.refresh: False, + models.Token.refresh_expired_at: datetime.utcnow() + }) + session.execute(query) except Exception as error: raise RucioException(error.args) from error @@ -856,11 +902,16 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session :return: tuple of (access token, expiration epoch), None otherswise """ # only validated tokens are in the DB, check presence of token_string - account_token = session.query(models.Token) \ - .filter(models.Token.token == token_string, - models.Token.account == account, - models.Token.expired_at > datetime.utcnow()) \ - .with_for_update(skip_locked=True).first() + query = select( + models.Token + ).where( + models.Token.token == token_string, + models.Token.account == account, + models.Token.expired_at > datetime.utcnow() + ).with_for_update( + skip_locked=True + ) + account_token = session.execute(query).scalar() # if token does not exist in the DB, return None if account_token is None: @@ -892,12 +943,17 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session else: # find account token with the same scope, # audience and has a valid refresh token - new_token = session.query(models.Token) \ - .filter(models.Token.refresh == true(), - models.Token.refresh_expired_at > datetime.utcnow(), - models.Token.account == account, - models.Token.expired_at > datetime.utcnow()) \ - .with_for_update(skip_locked=True).first() + query = select( + models.Token + ).where( + models.Token.refresh == true(), + models.Token.refresh_expired_at > datetime.utcnow(), + models.Token.account == account, + models.Token.expired_at > datetime.utcnow() + ).with_for_update( + skip_locked=True + ) + new_token = session.execute(query).scalar() if new_token is None: return None @@ -930,22 +986,31 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int try: # get tokens for refresh that expire in the next seconds expiration_future = datetime.utcnow() + timedelta(seconds=refreshrate) - query = session.query(models.Token.token) \ - .filter(and_(models.Token.refresh == true(), - models.Token.refresh_expired_at > datetime.utcnow(), - models.Token.expired_at < expiration_future))\ - .order_by(models.Token.expired_at) + query = select( + models.Token + ).where( + models.Token.refresh == true(), + models.Token.refresh_expired_at > datetime.utcnow(), + models.Token.expired_at < expiration_future + ).order_by( + models.Token.expired_at + ) query = filter_thread_work(session=session, query=query, total_threads=total_workers, thread_id=worker_number, hash_variable='token') - # limiting the number of tokens for refresh query = query.limit(limit) - filtered_tokens = [] - for items in session.execute(query).partitions(10): - tokens = tuple(map(lambda row: row.token, items)) - filtered_tokens += session.query(models.Token) \ - .filter(models.Token.token.in_(tokens)) \ - .with_for_update(skip_locked=True) \ - .all() + # Oracle does not support chaining order_by(), limit(), and + # with_for_update(). Use a nested query to overcome this. + if session.bind.dialect.name == 'oracle': + query = select( + models.Token + ).where( + models.Token.token.in_(query.with_only_columns(models.Token.token)) + ).with_for_update( + skip_locked=True + ) + else: + query = query.with_for_update(skip_locked=True) + filtered_tokens = session.execute(query).scalars().all() # refreshing these tokens for token in filtered_tokens: @@ -1052,20 +1117,37 @@ def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: try: # get expired OAuth request parameters - query = session.query(models.OAuthRequest.state).filter(models.OAuthRequest.expired_at < datetime.utcnow())\ - .order_by(models.OAuthRequest.expired_at) - + query = select( + models.OAuthRequest.state + ).where( + models.OAuthRequest.expired_at < datetime.utcnow() + ).order_by( + models.OAuthRequest.expired_at + ) query = filter_thread_work(session=session, query=query, total_threads=total_workers, thread_id=worker_number, hash_variable='state') - # limiting the number of oauth requests deleted at once query = query.limit(limit) + # Oracle does not support chaining order_by(), limit(), and + # with_for_update(). Use a nested query to overcome this. + if session.bind.dialect.name == 'oracle': + query = select( + models.OAuthRequest.state + ).where( + models.OAuthRequest.state.in_(query) + ).with_for_update( + skip_locked=True + ) + else: + query = query.with_for_update(skip_locked=True) + ndeleted = 0 - for items in session.execute(query).partitions(10): - states = tuple(map(lambda row: row.state, items)) - ndeleted += session.query(models.OAuthRequest) \ - .filter(models.OAuthRequest.state.in_(states)) \ - .with_for_update(skip_locked=True) \ - .delete(synchronize_session='fetch') + for states in session.execute(query).scalars().partitions(10): + query = delete( + models.OAuthRequest + ).where( + models.OAuthRequest.state.in_(states) + ) + ndeleted += session.execute(query).rowcount return ndeleted except Exception as error: raise RucioException(error.args) from error