From 65b9e5387c3ebb415fd10d01c747bf4c638611af Mon Sep 17 00:00:00 2001 From: kshitijrajsharma Date: Wed, 13 Dec 2023 09:28:31 +0545 Subject: [PATCH] Fixes leaked db connection issue on auth --- API/auth/__init__.py | 3 +-- API/auth/routers.py | 8 +++++--- src/app.py | 6 +++++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/API/auth/__init__.py b/API/auth/__init__.py index 836ae170..0ff71d79 100644 --- a/API/auth/__init__.py +++ b/API/auth/__init__.py @@ -24,10 +24,9 @@ class AuthUser(BaseModel): osm_auth = Auth(*get_oauth_credentials()) -auth = Users() - def get_user_from_db(osm_id: int): + auth = Users() user = auth.read_user(osm_id) return user diff --git a/API/auth/routers.py b/API/auth/routers.py index c572f420..dd66ffb4 100644 --- a/API/auth/routers.py +++ b/API/auth/routers.py @@ -59,9 +59,6 @@ class User(BaseModel): role: int -auth = Users() - - # Create user @router.post("/users/", response_model=dict) async def create_user(params: User, user_data: AuthUser = Depends(admin_required)): @@ -77,6 +74,7 @@ async def create_user(params: User, user_data: AuthUser = Depends(admin_required Raises: - HTTPException: If the user creation fails. """ + auth = Users() return auth.create_user(params.osm_id, params.role) @@ -95,6 +93,8 @@ async def read_user(osm_id: int, user_data: AuthUser = Depends(staff_required)): Raises: - HTTPException: If the user with the given osm_id is not found. """ + auth = Users() + return auth.read_user(osm_id) @@ -134,6 +134,7 @@ async def delete_user(osm_id: int, user_data: AuthUser = Depends(admin_required) Raises: - HTTPException: If the user with the given osm_id is not found. """ + auth = Users() return auth.delete_user(osm_id) @@ -152,4 +153,5 @@ async def read_users( Returns: - List[Dict[str, Any]]: A list of dictionaries containing user information. """ + auth = Users() return auth.read_users(skip, limit) diff --git a/src/app.py b/src/app.py index 085fecfd..ecd518b0 100644 --- a/src/app.py +++ b/src/app.py @@ -249,6 +249,7 @@ def create_user(self, osm_id, role): self.cur.execute(self.cur.mogrify(query, params).decode("utf-8")) new_osm_id = self.cur.fetchall()[0][0] self.con.commit() + self.d_b.close_conn() return {"osm_id": new_osm_id} def read_user(self, osm_id): @@ -269,7 +270,7 @@ def read_user(self, osm_id): params = (osm_id,) self.cur.execute(self.cur.mogrify(query, params).decode("utf-8")) result = self.cur.fetchall() - + self.d_b.close_conn() if result: return dict(result[0]) else: @@ -295,6 +296,7 @@ def update_user(self, osm_id, update_data): self.cur.execute(self.cur.mogrify(query, params).decode("utf-8")) updated_user = self.cur.fetchall() self.con.commit() + self.d_b.close_conn() if updated_user: return dict(updated_user[0]) raise HTTPException(status_code=404, detail="User not found") @@ -317,6 +319,7 @@ def delete_user(self, osm_id): self.cur.execute(self.cur.mogrify(query, params).decode("utf-8")) deleted_user = self.cur.fetchall() self.con.commit() + self.d_b.close_conn() if deleted_user: return dict(deleted_user[0]) raise HTTPException(status_code=404, detail="User not found") @@ -336,6 +339,7 @@ def read_users(self, skip=0, limit=10): params = (skip, limit) self.cur.execute(self.cur.mogrify(query, params).decode("utf-8")) users_list = self.cur.fetchall() + self.d_b.close_conn() return [dict(user) for user in users_list]