From 915175db75b3e1f86a7bfaa5907fae263934d3a7 Mon Sep 17 00:00:00 2001 From: Marco Salazar Date: Mon, 29 Aug 2022 22:51:58 +0200 Subject: [PATCH] #56 #121 #124: Initializes boto3 session globally to support configured AWS profile when calling boto3 --- dbt/adapters/athena/connections.py | 4 ++-- dbt/adapters/athena/impl.py | 14 ++++++-------- dbt/adapters/athena/session.py | 19 +++++++++++++++++++ 3 files changed, 27 insertions(+), 10 deletions(-) create mode 100644 dbt/adapters/athena/session.py diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index dab4bf95..3575381e 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -21,6 +21,7 @@ from dbt.adapters.sql import SQLConnectionManager from dbt.exceptions import RuntimeException, FailedToConnectException from dbt.events import AdapterLogger +from dbt.adapters.athena.session import get_boto3_session import tenacity from tenacity.retry import retry_if_exception @@ -140,13 +141,12 @@ def open(cls, connection: Connection) -> Connection: handle = AthenaConnection( s3_staging_dir=creds.s3_staging_dir, endpoint_url=creds.endpoint_url, - region_name=creds.region_name, schema_name=creds.schema, work_group=creds.work_group, cursor_class=AthenaCursor, formatter=AthenaParameterFormatter(), poll_interval=creds.poll_interval, - profile_name=creds.aws_profile_name, + session=get_boto3_session(connection), retry_config=RetryConfig( attempt=creds.num_retries, exceptions=( diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 165c77b2..92e696aa 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,6 +1,5 @@ import agate import re -import boto3 from botocore.exceptions import ClientError from itertools import chain from threading import Lock @@ -59,10 +58,9 @@ def clean_up_partitions( # Look up Glue partitions & clean up conn = self.connections.get_thread_connection() client = conn.handle - with boto3_client_lock: - glue_client = boto3.client('glue', region_name=client.region_name) - s3_resource = boto3.resource('s3', region_name=client.region_name) + glue_client = client.session.client('glue') + s3_resource = client.session.resource('s3') partitions = glue_client.get_partitions( # CatalogId='123456789012', # Need to make this configurable if it is different from default AWS Account ID DatabaseName=database_name, @@ -87,7 +85,7 @@ def clean_up_table( conn = self.connections.get_thread_connection() client = conn.handle with boto3_client_lock: - glue_client = boto3.client('glue', region_name=client.region_name) + glue_client = client.session.client('glue') try: table = glue_client.get_table( DatabaseName=database_name, @@ -105,7 +103,7 @@ def clean_up_table( if m is not None: bucket_name = m.group(1) prefix = m.group(2) - s3_resource = boto3.resource('s3', region_name=client.region_name) + s3_resource = client.session.resource('s3') s3_bucket = s3_resource.Bucket(bucket_name) s3_bucket.objects.filter(Prefix=prefix).delete() @@ -152,7 +150,7 @@ def _get_data_catalog(self, catalog_name): conn = self.connections.get_thread_connection() client = conn.handle with boto3_client_lock: - athena_client = boto3.client('athena', region_name=client.region_name) + athena_client = client.session.client('athena') response = athena_client.get_data_catalog(Name=catalog_name) return response['DataCatalog'] @@ -172,7 +170,7 @@ def list_relations_without_caching( conn = self.connections.get_thread_connection() client = conn.handle with boto3_client_lock: - glue_client = boto3.client('glue', region_name=client.region_name) + glue_client = client.session.client('glue') paginator = glue_client.get_paginator('get_tables') kwargs = { diff --git a/dbt/adapters/athena/session.py b/dbt/adapters/athena/session.py new file mode 100644 index 00000000..dd26620f --- /dev/null +++ b/dbt/adapters/athena/session.py @@ -0,0 +1,19 @@ +import boto3.session +from dbt.contracts.connection import Connection + + +__BOTO3_SESSION__: boto3.session.Session = None + + +def get_boto3_session(connection: Connection) -> boto3.session.Session: + def init_session(): + global __BOTO3_SESSION__ + __BOTO3_SESSION__ = boto3.session.Session( + region_name=connection.credentials.region_name, + profile_name=connection.credentials.aws_profile_name, + ) + + if not __BOTO3_SESSION__: + init_session() + + return __BOTO3_SESSION__