Skip to content

Commit

Permalink
♻️ Remove deprecated from_orm method (Pydantic v2) (ITISFoundation#…
Browse files Browse the repository at this point in the history
  • Loading branch information
giancarloromeo authored Dec 2, 2024
1 parent 3dded08 commit 994c575
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def get(connection: SAConnection, project_uuid: uuid.UUID) -> ProjectMetad
row: RowProxy | None = await result.first()
if row is None:
raise DBProjectNotFoundError(project_uuid=project_uuid)
return ProjectMetadata.from_orm(row)
return ProjectMetadata.model_validate(row)


def _check_valid_ancestors_combination(
Expand Down Expand Up @@ -202,7 +202,7 @@ async def set_project_ancestors(
result: ResultProxy = await connection.execute(upsert_stmt)
row: RowProxy | None = await result.first()
assert row # nosec
return ProjectMetadata.from_orm(row)
return ProjectMetadata.model_validate(row)

except ForeignKeyViolation as err:
assert err.pgerror is not None # nosec # noqa: PT017
Expand Down Expand Up @@ -234,7 +234,7 @@ async def set_project_custom_metadata(
result: ResultProxy = await connection.execute(upsert_stmt)
row: RowProxy | None = await result.first()
assert row # nosec
return ProjectMetadata.from_orm(row)
return ProjectMetadata.model_validate(row)

except ForeignKeyViolation as err:
raise DBProjectNotFoundError(project_uuid=project_uuid) from err
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def add(
assert result # nosec
rows = await result.fetchall()
assert rows is not None # nosec
return [ProjectNode.from_orm(r) for r in rows]
return [ProjectNode.model_validate(r) for r in rows]
except ForeignKeyViolation as exc:
# this happens when the project does not exist, as we first check the node exists
raise ProjectNodesProjectNotFoundError(
Expand All @@ -128,7 +128,7 @@ async def list(self, connection: SAConnection) -> list[ProjectNode]:
assert result # nosec
rows = await result.fetchall()
assert rows is not None # nosec
return [ProjectNode.from_orm(row) for row in rows]
return [ProjectNode.model_validate(row) for row in rows]

async def get(self, connection: SAConnection, *, node_id: uuid.UUID) -> ProjectNode:
"""get a node in the current project
Expand All @@ -154,7 +154,7 @@ async def get(self, connection: SAConnection, *, node_id: uuid.UUID) -> ProjectN
project_uuid=self.project_uuid, node_id=node_id
)
assert row # nosec
return ProjectNode.from_orm(row)
return ProjectNode.model_validate(row)

async def update(
self, connection: SAConnection, *, node_id: uuid.UUID, **values
Expand Down Expand Up @@ -184,7 +184,7 @@ async def update(
project_uuid=self.project_uuid, node_id=node_id
)
assert row # nosec
return ProjectNode.from_orm(row)
return ProjectNode.model_validate(row)

async def delete(self, connection: SAConnection, *, node_id: uuid.UUID) -> None:
"""delete a node in the current project
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def creator(
.returning(sa.literal_column("*"))
)

inserted_project = ProjectAtDB.from_orm(await result.first())
inserted_project = ProjectAtDB.model_validate(await result.first())
project_nodes_repo = ProjectNodesRepo(project_uuid=project_uuid)
# NOTE: currently no resources is passed until it becomes necessary
default_node_config = {"required_resources": {}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def get_everyone_group(self) -> GroupAtDB:
raise UninitializedGroupError(
group=GroupType.EVERYONE, repo_cls=GroupsRepository
)
return GroupAtDB.from_orm(row)
return GroupAtDB.model_validate(row)

async def get_user_gid_from_email(
self, user_email: LowerCaseEmailStr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def list_services(

async with self.db_engine.connect() as conn:
return [
ServiceMetaDataAtDB.from_orm(row)
ServiceMetaDataAtDB.model_validate(row)
async for row in await conn.stream(
list_services_stmt(
gids=gids,
Expand Down Expand Up @@ -134,7 +134,7 @@ async def list_service_releases(

async with self.db_engine.connect() as conn:
releases = [
ServiceMetaDataAtDB.from_orm(row)
ServiceMetaDataAtDB.model_validate(row)
async for row in await conn.stream(query)
]

Expand Down Expand Up @@ -163,7 +163,7 @@ async def get_latest_release(self, key: str) -> ServiceMetaDataAtDB | None:
result = await conn.execute(query)
row = result.first()
if row:
return ServiceMetaDataAtDB.from_orm(row)
return ServiceMetaDataAtDB.model_validate(row)
return None # mypy

async def get_service(
Expand Down Expand Up @@ -208,7 +208,7 @@ async def get_service(
result = await conn.execute(query)
row = result.first()
if row:
return ServiceMetaDataAtDB.from_orm(row)
return ServiceMetaDataAtDB.model_validate(row)
return None # mypy

async def create_or_update_service(
Expand All @@ -234,7 +234,7 @@ async def create_or_update_service(
)
row = result.first()
assert row # nosec
created_service = ServiceMetaDataAtDB.from_orm(row)
created_service = ServiceMetaDataAtDB.model_validate(row)

for access_rights in new_service_access_rights:
insert_stmt = pg_insert(services_access_rights).values(
Expand Down Expand Up @@ -468,7 +468,7 @@ async def get_service_access_rights(

async with self.db_engine.connect() as conn:
return [
ServiceAccessRightsAtDB.from_orm(row)
ServiceAccessRightsAtDB.model_validate(row)
async for row in await conn.stream(query)
]

Expand All @@ -494,7 +494,7 @@ async def list_services_access_rights(
async with self.db_engine.connect() as conn:
async for row in await conn.stream(query):
service_to_access_rights[(row.key, row.version)].append(
ServiceAccessRightsAtDB.from_orm(row)
ServiceAccessRightsAtDB.model_validate(row)
)
return service_to_access_rights

Expand Down Expand Up @@ -585,7 +585,7 @@ async def get_service_specifications(
try:
_logger.debug("found following %s", f"{row=}")
# validate the specs first
db_service_spec = ServiceSpecificationsAtDB.from_orm(row)
db_service_spec = ServiceSpecificationsAtDB.model_validate(row)
db_spec_version = packaging.version.parse(
db_service_spec.service_version
)
Expand Down

0 comments on commit 994c575

Please sign in to comment.