Skip to content

Commit

Permalink
Fixed view update and errors
Browse files Browse the repository at this point in the history
  • Loading branch information
khoroshevskyi committed Feb 5, 2024
1 parent e785fd4 commit 9b489ce
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 63 deletions.
5 changes: 5 additions & 0 deletions pepdbagent/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def __init__(self, msg=""):
super().__init__(f"""View does not exist. {msg}""")


class SampleNotInViewError(PEPDatabaseAgentError):
def __init__(self, msg=""):
super().__init__(f"""Sample is not in the view. {msg}""")


class SampleAlreadyInView(PEPDatabaseAgentError):
"""
Sample is already in the view exception
Expand Down
121 changes: 58 additions & 63 deletions pepdbagent/modules/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ProjectNotFoundError,
SampleNotFoundError,
ViewAlreadyExistsError,
SampleNotInViewError,
)

from pepdbagent.db_utils import BaseEngine, Samples, Projects, Views, ViewSampleAssociation
Expand Down Expand Up @@ -163,41 +164,40 @@ def create(
Projects.tag == view_dict.project_tag,
)
)

with Session(self._sa_engine) as sa_session:
project = sa_session.scalar(project_statement)
if not project:
raise ProjectNotFoundError(
f"Project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} does not exist"
)
view = Views(
name=view_name,
description=description,
project_mapping=project,
)
sa_session.add(view)

for sample_name in view_dict.sample_list:
sample_statement = select(Samples.id).where(
and_(
Samples.project_id == project.id,
Samples.sample_name == sample_name,
try:
with Session(self._sa_engine) as sa_session:
project = sa_session.scalar(project_statement)
if not project:
raise ProjectNotFoundError(
f"Project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} does not exist"
)
view = Views(
name=view_name,
description=description,
project_mapping=project,
)
sample_id = sa_session.execute(sample_statement).one()[0]
if not sample_id:
raise SampleNotFoundError(
f"Sample {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag}:{sample_name} does not exist"
sa_session.add(view)

for sample_name in view_dict.sample_list:
sample_statement = select(Samples.id).where(
and_(
Samples.project_id == project.id,
Samples.sample_name == sample_name,
)
)
try:
sa_session.add(ViewSampleAssociation(sample_id=sample_id, view=view))
sample_id = sa_session.execute(sample_statement).one()[0]
if not sample_id:
raise SampleNotFoundError(
f"Sample {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag}:{sample_name} does not exist"
)

except IntegrityError:
raise ViewAlreadyExistsError(
f"View {view_name} of the project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} already exists"
)
sa_session.add(ViewSampleAssociation(sample_id=sample_id, view=view))

sa_session.commit()
sa_session.commit()
except IntegrityError:
raise ViewAlreadyExistsError(
f"View {view_name} of the project {view_dict.project_namespace}/{view_dict.project_name}:{view_dict.project_tag} already exists"
)

def delete(
self,
Expand Down Expand Up @@ -265,34 +265,32 @@ def add_sample(
Views.name == view_name,
)
)

with Session(self._sa_engine) as sa_session:
view = sa_session.scalar(view_statement)
if not view:
raise ViewNotFoundError(
f"View {view_name} of the project {namespace}/{name}:{tag} does not exist"
)
for sample_name_one in sample_name:
sample_statement = select(Samples).where(
and_(
Samples.project_id == view.project_mapping.id,
Samples.sample_name == sample_name_one,
try:
with Session(self._sa_engine) as sa_session:
view = sa_session.scalar(view_statement)
if not view:
raise ViewNotFoundError(
f"View {view_name} of the project {namespace}/{name}:{tag} does not exist"
)
)
sample = sa_session.scalar(sample_statement)
if not sample:
raise SampleNotFoundError(
f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist"
for sample_name_one in sample_name:
sample_statement = select(Samples).where(
and_(
Samples.project_id == view.project_mapping.id,
Samples.sample_name == sample_name_one,
)
)
try:
sample = sa_session.scalar(sample_statement)
if not sample:
raise SampleNotFoundError(
f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist"
)

sa_session.add(ViewSampleAssociation(sample=sample, view=view))
sa_session.commit()
except IntegrityError:
raise SampleAlreadyInView(
f"Sample {namespace}/{name}:{tag}:{sample_name} already in view {view_name}"
)

return None
except IntegrityError:
raise SampleAlreadyInView(
f"Sample {namespace}/{name}:{tag}:{sample_name} already in view {view_name}"
)

def remove_sample(
self,
Expand Down Expand Up @@ -335,21 +333,18 @@ def remove_sample(
)
)
sample = sa_session.scalar(sample_statement)
if sample.id not in [view_sample.sample_id for view_sample in view.samples]:
raise SampleNotInViewError(
f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist in view {view_name}"
)
delete_statement = delete(ViewSampleAssociation).where(
and_(
ViewSampleAssociation.sample_id == sample.id,
ViewSampleAssociation.view_id == view.id,
)
)
try:
sa_session.execute(delete_statement)
sa_session.commit()
except IntegrityError:
raise SampleNotFoundError(
f"Sample {namespace}/{name}:{tag}:{sample_name} does not exist in view {view_name}"
)

return None
sa_session.execute(delete_statement)
sa_session.commit()

def get_snap_view(
self, namespace: str, name: str, tag: str, sample_name_list: List[str], raw: bool = False
Expand Down
4 changes: 4 additions & 0 deletions tests/test_pepagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SampleNotFoundError,
ViewNotFoundError,
SampleAlreadyInView,
SampleNotInViewError,
)
from .conftest import DNS

Expand Down Expand Up @@ -1189,6 +1190,9 @@ def test_remove_sample_from_view(self, initiate_pepdb_con, namespace, name, samp
assert len(initiate_pepdb_con.view.get(namespace, name, "default", "view1").samples) == 1
assert len(initiate_pepdb_con.project.get(namespace, name).samples) == 4

with pytest.raises(SampleNotInViewError):
initiate_pepdb_con.view.remove_sample(namespace, name, "default", "view1", sample_name)

@pytest.mark.parametrize(
"namespace, name, sample_name",
[
Expand Down

0 comments on commit 9b489ce

Please sign in to comment.