diff --git a/space2stats_api/app/utils/db_utils.py b/space2stats_api/app/utils/db_utils.py index 5457fc0..cc2a843 100644 --- a/space2stats_api/app/utils/db_utils.py +++ b/space2stats_api/app/utils/db_utils.py @@ -10,12 +10,18 @@ def get_summaries(fields, h3_ids): - h3_ids_str = ", ".join(f"'{h3_id}'" for h3_id in h3_ids) - sql_query = f""" - SELECT hex_id, {', '.join(fields)} - FROM {DB_TABLE_NAME} - WHERE hex_id IN ({h3_ids_str}) - """ + colnames = ['hex_id'] + fields + cols = [pg.sql.Identifier(c) for c in colnames] + sql_query = pg.sql.SQL( + """ + SELECT {0} + FROM {1} + WHERE hex_id = ANY (%s) + """ + ).format( + pg.sql.SQL(', ').join(cols), + pg.sql.Identifier(DB_TABLE_NAME) + ) try: conn = pg.connect( host=DB_HOST, @@ -25,7 +31,7 @@ def get_summaries(fields, h3_ids): password=DB_PASSWORD, ) cur = conn.cursor() - cur.execute(sql_query) + cur.execute(sql_query, [h3_ids,]) rows = cur.fetchall() colnames = [desc[0] for desc in cur.description] cur.close() @@ -37,10 +43,10 @@ def get_summaries(fields, h3_ids): def get_available_fields(): - sql_query = f""" + sql_query = """ SELECT column_name FROM information_schema.columns - WHERE table_name = '{DB_TABLE_NAME}' + WHERE table_name = %s """ try: conn = pg.connect( @@ -51,7 +57,7 @@ def get_available_fields(): password=DB_PASSWORD, ) cur = conn.cursor() - cur.execute(sql_query) + cur.execute(sql_query, [DB_TABLE_NAME,]) columns = [row[0] for row in cur.fetchall() if row[0] != "hex_id"] cur.close() conn.close() diff --git a/space2stats_api/tests/test_db_utils.py b/space2stats_api/tests/test_db_utils.py index 0047b05..a3d0b97 100644 --- a/space2stats_api/tests/test_db_utils.py +++ b/space2stats_api/tests/test_db_utils.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch, Mock from app.utils.db_utils import get_summaries, get_available_fields - +from psycopg.sql import SQL, Identifier @patch("psycopg.connect") def test_get_summaries(mock_connect): @@ -18,18 +18,21 @@ def test_get_summaries(mock_connect): rows, colnames = get_summaries(fields, h3_ids) mock_connect.assert_called_once() - mock_cursor.execute.assert_called_once_with( - f""" - SELECT hex_id, {', '.join(fields)} - FROM space2stats - WHERE hex_id IN ('hex_1') - """ + sql_query = SQL( + """ + SELECT {0} + FROM {1} + WHERE hex_id = ANY (%s) + """ + ).format( + SQL(', ').join([Identifier(c) for c in ['hex_id'] + fields]), + Identifier("space2stats") ) + mock_cursor.execute.assert_called_once_with(sql_query, [h3_ids]) assert rows == [("hex_1", 100, 200)] assert colnames == ["hex_id", "field1", "field2"] - @patch("psycopg.connect") def test_get_available_fields(mock_connect): mock_conn = Mock() @@ -43,15 +46,15 @@ def test_get_available_fields(mock_connect): mock_connect.assert_called_once() mock_cursor.execute.assert_called_once_with( - f""" + """ SELECT column_name FROM information_schema.columns - WHERE table_name = 'space2stats' - """ + WHERE table_name = %s + """, + ["space2stats"] ) assert columns == ["field1", "field2", "field3"] - if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file