Skip to content

Commit

Permalink
Fixes leaked db connection issue on auth
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Dec 13, 2023
1 parent 4cd0e34 commit 65b9e53
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
3 changes: 1 addition & 2 deletions API/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ class AuthUser(BaseModel):

osm_auth = Auth(*get_oauth_credentials())

auth = Users()


def get_user_from_db(osm_id: int):
auth = Users()
user = auth.read_user(osm_id)
return user

Expand Down
8 changes: 5 additions & 3 deletions API/auth/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ class User(BaseModel):
role: int


auth = Users()


# Create user
@router.post("/users/", response_model=dict)
async def create_user(params: User, user_data: AuthUser = Depends(admin_required)):
Expand All @@ -77,6 +74,7 @@ async def create_user(params: User, user_data: AuthUser = Depends(admin_required
Raises:
- HTTPException: If the user creation fails.
"""
auth = Users()
return auth.create_user(params.osm_id, params.role)


Expand All @@ -95,6 +93,8 @@ async def read_user(osm_id: int, user_data: AuthUser = Depends(staff_required)):
Raises:
- HTTPException: If the user with the given osm_id is not found.
"""
auth = Users()

return auth.read_user(osm_id)


Expand Down Expand Up @@ -134,6 +134,7 @@ async def delete_user(osm_id: int, user_data: AuthUser = Depends(admin_required)
Raises:
- HTTPException: If the user with the given osm_id is not found.
"""
auth = Users()
return auth.delete_user(osm_id)


Expand All @@ -152,4 +153,5 @@ async def read_users(
Returns:
- List[Dict[str, Any]]: A list of dictionaries containing user information.
"""
auth = Users()
return auth.read_users(skip, limit)
6 changes: 5 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def create_user(self, osm_id, role):
self.cur.execute(self.cur.mogrify(query, params).decode("utf-8"))
new_osm_id = self.cur.fetchall()[0][0]
self.con.commit()
self.d_b.close_conn()
return {"osm_id": new_osm_id}

def read_user(self, osm_id):
Expand All @@ -269,7 +270,7 @@ def read_user(self, osm_id):
params = (osm_id,)
self.cur.execute(self.cur.mogrify(query, params).decode("utf-8"))
result = self.cur.fetchall()

self.d_b.close_conn()
if result:
return dict(result[0])
else:
Expand All @@ -295,6 +296,7 @@ def update_user(self, osm_id, update_data):
self.cur.execute(self.cur.mogrify(query, params).decode("utf-8"))
updated_user = self.cur.fetchall()
self.con.commit()
self.d_b.close_conn()
if updated_user:
return dict(updated_user[0])
raise HTTPException(status_code=404, detail="User not found")
Expand All @@ -317,6 +319,7 @@ def delete_user(self, osm_id):
self.cur.execute(self.cur.mogrify(query, params).decode("utf-8"))
deleted_user = self.cur.fetchall()
self.con.commit()
self.d_b.close_conn()
if deleted_user:
return dict(deleted_user[0])
raise HTTPException(status_code=404, detail="User not found")
Expand All @@ -336,6 +339,7 @@ def read_users(self, skip=0, limit=10):
params = (skip, limit)
self.cur.execute(self.cur.mogrify(query, params).decode("utf-8"))
users_list = self.cur.fetchall()
self.d_b.close_conn()
return [dict(user) for user in users_list]


Expand Down

0 comments on commit 65b9e53

Please sign in to comment.