diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 95d4013e3a8f27..30ea1f1de4dbce 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -537,6 +537,10 @@ def get(self): .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) .all() ) + + # Only return keys created by current user or created_by none (for old keys do not have created_by) + keys = [key for key in keys if key.created_by == current_user.id or key.created_by is None] + return {"items": keys} @setup_required @@ -548,12 +552,15 @@ def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = ( + keys = ( db.session.query(ApiToken) .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) - .count() + .all() ) + # Only count keys created by current user or created_by none (for old keys do not have created_by) + current_key_count = len([key for key in keys if key.created_by == current_user.id or key.created_by is None]) + if current_key_count >= self.max_keys: flask_restful.abort( 400, @@ -566,6 +573,7 @@ def post(self): api_token.tenant_id = current_user.current_tenant_id api_token.token = key api_token.type = self.resource_type + api_token.created_by = current_user.id db.session.add(api_token) db.session.commit() return api_token, 200 diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 2128c4c53f9909..62c8788ca603b3 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -140,14 +140,29 @@ def decorator(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token("dataset") - tenant_account_join = ( + + # Build base query + query = ( db.session.query(Tenant, TenantAccountJoin) .filter(Tenant.id == api_token.tenant_id) .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role.in_(["owner"])) .filter(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ) + + if api_token.created_by: + # Only apply account_id filter if created_by exists + query = query.filter( + db.and_( + TenantAccountJoin.role.in_(["owner", "admin"]), + TenantAccountJoin.account_id == api_token.created_by, + ) + ) + else: + query = query.filter(TenantAccountJoin.role.in_(["owner"])) + + tenant_account_join = query.one_or_none() + # TODO: only owner information is required, so only one is returned. + if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() diff --git a/api/migrations/versions/2024_12_07_1936-8db4e7683504_add_created_by_to_api_tokens.py b/api/migrations/versions/2024_12_07_1936-8db4e7683504_add_created_by_to_api_tokens.py new file mode 100644 index 00000000000000..a7c833b8ff2ded --- /dev/null +++ b/api/migrations/versions/2024_12_07_1936-8db4e7683504_add_created_by_to_api_tokens.py @@ -0,0 +1,35 @@ +"""add created_by to api_tokens + +Revision ID: 8db4e7683504 +Revises: 01d6889832f7 +Create Date: 2024-12-07 19:36:49.632151 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '8db4e7683504' +down_revision = '01d6889832f7' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.add_column(sa.Column('created_by', models.types.StringUUID(), nullable=True)) + batch_op.create_index('api_token_created_by_idx', ['created_by'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('api_tokens', schema=None) as batch_op: + batch_op.drop_index('api_token_created_by_idx') + batch_op.drop_column('created_by') + + # ### end Alembic commands ### diff --git a/api/models/model.py b/api/models/model.py index 03b8e0bea553aa..8dfa0dc162a759 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1332,6 +1332,7 @@ class ApiToken(db.Model): db.Index("api_token_app_id_type_idx", "app_id", "type"), db.Index("api_token_token_idx", "token", "type"), db.Index("api_token_tenant_idx", "tenant_id", "type"), + db.Index("api_token_created_by_idx", "created_by"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) @@ -1340,6 +1341,7 @@ class ApiToken(db.Model): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) + created_by = db.Column(StringUUID, nullable=True) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) @staticmethod