Skip to content

Commit

Permalink
Connection caching (#78)
Browse files Browse the repository at this point in the history
* test and update sql connector

* lint

* add test

* pr review

* pr review

Co-authored-by: ncgl-syngenta <[email protected]>
  • Loading branch information
ncgl-syngenta and ncgl-syngenta authored Oct 8, 2021
1 parent 22a173e commit 1e6bd17
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
3 changes: 2 additions & 1 deletion syngenta_digital_dta/postgres/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from syngenta_digital_dta.common import schema_mapper
from syngenta_digital_dta.common.base_adapter import BaseAdapter
from syngenta_digital_dta.postgres.sql_connection import sql_connection
from syngenta_digital_dta.postgres.sql_connector import SQLConnector


class PostgresAdapter(BaseAdapter):
Expand All @@ -31,7 +32,7 @@ def __init__(self, **kwargs):
self.cursor = None

@sql_connection
def connect(self, connector):
def connect(self, connector: SQLConnector):
self.connection = connector.connect()
self.cursor = connector.cursor()

Expand Down
16 changes: 12 additions & 4 deletions syngenta_digital_dta/postgres/sql_connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import typing

from syngenta_digital_dta.postgres.sql_connector import SQLConnector


def sql_connection(func):
def sql_connection(func: typing.Callable) -> typing.Callable:
__connections = {}
def decorator(obj):
if not __connections.get(obj.database):
__connections[obj.database] = SQLConnector(obj)

def decorator(obj: typing.Union["PostgresAdapter", "RedshiftAdapter"]):

# reuse the existing connection if it isn't closed
if __connections.get(obj.database) and __connections[obj.database].connection and not __connections[obj.database].connection.closed:
return func(obj, __connections[obj.database])

__connections[obj.database] = SQLConnector(obj)
return func(obj, __connections[obj.database])

return decorator
8 changes: 8 additions & 0 deletions tests/syngenta_digital_dta/postgres/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,11 @@ def test_query_no_params_constraint(self):
results = self.user_adapter.query(query='SELECT * FROM users WHERE user_id =1')
except Exception as error:
self.assertEqual(str(error), 'params kwargs are required to prevent sql inject; send empty dict if not needed')

def test_sql_connector(self):
id_conn_first = id(self.user_adapter.connection)
self.user_adapter.connection.close()
self.user_adapter.connect()
id_conn_second = id(self.user_adapter.connection)

self.assertNotEqual(id_conn_first, id_conn_second)

0 comments on commit 1e6bd17

Please sign in to comment.