diff --git a/src/db_client.py b/src/db_client.py index 77265c4..31a986a 100644 --- a/src/db_client.py +++ b/src/db_client.py @@ -1,12 +1,14 @@ import sqlite3 +from contextlib import closing + class DBClient: def __init__(self, path): self.path = path - self.con = sqlite3.connect(path) - self.con.row_factory = sqlite3.Row def select(self, table, rows=['*'], where='1'): - cursor = self.con.cursor() - cursor.execute(f'SELECT {", ".join(rows)} FROM {table} WHERE {where};') - return cursor.fetchall() + with closing(sqlite3.connect(self.path)) as conn: + conn.row_factory = sqlite3.Row + with closing(conn.cursor()) as cursor: + cursor.execute(f'SELECT {", ".join(rows)} FROM {table} WHERE {where};') + return cursor.fetchall()