Skip to content

Commit

Permalink
chore: improve the drop_relation default macro (Tomme#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
nicor88 authored Jan 20, 2023
1 parent ef31df4 commit eac24b8
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 5 deletions.
37 changes: 37 additions & 0 deletions dbt/adapters/athena/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ class AthenaAdapter(SQLAdapter):
ConnectionManager = AthenaConnectionManager
Relation = AthenaRelation

relation_type_map = {
"EXTERNAL_TABLE": "table",
"MANAGED_TABLE": "table",
"VIRTUAL_VIEW": "view",
"table": "table",
"view": "view",
"cte": "cte",
"materializedview": "materializedview",
}

@classmethod
def date_function(cls) -> str:
return "now()"
Expand Down Expand Up @@ -309,3 +319,30 @@ def list_relations_without_caching(
)

return relations

@available
def get_table_type(self, db_name, table_name):
conn = self.connections.get_thread_connection()
client = conn.handle

with boto3_client_lock:
glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config())

try:
response = glue_client.get_table(DatabaseName=db_name, Name=table_name)
_type = self.relation_type_map.get(response.get("Table", {}).get("TableType", "Table"))
_specific_type = response.get("Table", {}).get("Parameters", {}).get("table_type", "")

if _specific_type.lower() == "iceberg":
_type = "iceberg_table"

if _type is None:
raise ValueError("Table type cannot be None")

logger.debug("table_name : " + table_name)
logger.debug("table type : " + _type)

return _type

except glue_client.exceptions.EntityNotFoundException as e:
logger.debug(e)
9 changes: 5 additions & 4 deletions dbt/include/athena/macros/adapters/relation.sql
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
{% macro drop_relation(relation) -%}
{% if config.get('table_type') != 'iceberg' %}
{% macro athena__drop_relation(relation) -%}
{% set rel_type = adapter.get_table_type(relation.schema, relation.table) %}
{%- if rel_type is not none and rel_type == 'table' %}
{%- do adapter.clean_up_table(relation.schema, relation.table) -%}
{% endif %}
{% call statement('drop_relation', auto_begin=False) -%}
{%- endif %}
{% call statement('drop_relation', auto_begin=False) -%}
drop {{ relation.type }} if exists {{ relation }}
{%- endcall %}
{% endmacro %}
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,51 @@ def test__get_data_catalog(self, aws_credentials):
res = self.adapter._get_data_catalog(DATA_CATALOG_NAME)
assert {"Name": "awsdatacatalog", "Type": "GLUE", "Parameters": {"catalog-id": "catalog_id"}} == res

@mock_glue
@mock_s3
@mock_athena
def test__get_relation_type_table(self, aws_credentials):
self.mock_aws_service.create_data_catalog()
self.mock_aws_service.create_database()
self.mock_aws_service.create_table("test_table")
self.adapter.acquire_connection("dummy")
table_type = self.adapter.get_table_type(DATABASE_NAME, "test_table")
assert table_type == "table"

@mock_glue
@mock_s3
@mock_athena
def test__get_relation_type_with_no_type(self, aws_credentials):
self.mock_aws_service.create_data_catalog()
self.mock_aws_service.create_database()
self.mock_aws_service.create_table_without_table_type("test_table")
self.adapter.acquire_connection("dummy")

with pytest.raises(ValueError):
self.adapter.get_table_type(DATABASE_NAME, "test_table")

@mock_glue
@mock_s3
@mock_athena
def test__get_relation_type_view(self, aws_credentials):
self.mock_aws_service.create_data_catalog()
self.mock_aws_service.create_database()
self.mock_aws_service.create_view("test_view")
self.adapter.acquire_connection("dummy")
table_type = self.adapter.get_table_type(DATABASE_NAME, "test_view")
assert table_type == "view"

@mock_glue
@mock_s3
@mock_athena
def test__get_relation_type_iceberg(self, aws_credentials):
self.mock_aws_service.create_data_catalog()
self.mock_aws_service.create_database()
self.mock_aws_service.create_iceberg_table("test_iceberg")
self.adapter.acquire_connection("dummy")
table_type = self.adapter.get_table_type(DATABASE_NAME, "test_iceberg")
assert table_type == "iceberg_table"

def _test_list_relations_without_caching(self, schema_relation):
self.adapter.acquire_connection("dummy")
relations = self.adapter.list_relations_without_caching(schema_relation)
Expand Down
39 changes: 38 additions & 1 deletion tests/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,44 @@ def create_table(self, table_name: str):
"Type": "date",
},
],
"TableType": "Table",
"TableType": "table",
},
)

def create_iceberg_table(self, table_name: str):
glue = boto3.client("glue", region_name=AWS_REGION)
glue.create_table(
DatabaseName=DATABASE_NAME,
TableInput={
"Name": table_name,
"StorageDescriptor": {
"Columns": [
{
"Name": "id",
"Type": "string",
},
{
"Name": "country",
"Type": "string",
},
{
"Name": "dt",
"Type": "date",
},
],
"Location": f"s3://{BUCKET}/tables/data/{table_name}",
},
"PartitionKeys": [
{
"Name": "dt",
"Type": "date",
},
],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"metadata_location": f"s3://{BUCKET}/tables/metadata/{table_name}/123.json",
"table_type": "iceberg",
},
},
)

Expand Down

0 comments on commit eac24b8

Please sign in to comment.