Skip to content

Commit

Permalink
Order tags by name
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Dec 13, 2024
1 parent c9fbae3 commit d74e50e
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 39 deletions.
2 changes: 2 additions & 0 deletions src/zenml/zen_stores/schemas/artifact_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ArtifactSchema(NamedSchema, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)

Expand Down Expand Up @@ -194,6 +195,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)
save_type: str = Field(sa_column=Column(TEXT, nullable=False))
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/zen_stores/schemas/model_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ModelSchema(NamedSchema, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)
model_versions: List["ModelVersionSchema"] = Relationship(
Expand Down Expand Up @@ -301,6 +302,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/schemas/pipeline_run_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/schemas/pipeline_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class PipelineSchema(NamedSchema, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)

Expand Down
1 change: 1 addition & 0 deletions src/zenml/zen_stores/schemas/run_template_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class RunTemplateSchema(BaseSchema, table=True):
primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)",
secondary="tag_resource",
secondaryjoin="TagSchema.id == foreign(TagResourceSchema.tag_id)",
order_by="TagSchema.name",
),
)

Expand Down
40 changes: 1 addition & 39 deletions src/zenml/zen_stores/schemas/tag_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""SQLModel implementation of tag tables."""

from datetime import datetime
from typing import TYPE_CHECKING, Any, List
from typing import Any, List
from uuid import UUID

from sqlalchemy import VARCHAR, Column
Expand All @@ -33,16 +33,6 @@
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field

if TYPE_CHECKING:
from zenml.zen_stores.schemas.artifact_schemas import (
ArtifactSchema,
ArtifactVersionSchema,
)
from zenml.zen_stores.schemas.model_schemas import (
ModelSchema,
ModelVersionSchema,
)


class TagSchema(NamedSchema, table=True):
"""SQL Model for tag."""
Expand Down Expand Up @@ -133,34 +123,6 @@ class TagResourceSchema(BaseSchema, table=True):
tag: "TagSchema" = Relationship(back_populates="links")
resource_id: UUID
resource_type: str = Field(sa_column=Column(VARCHAR(255), nullable=False))
artifact: List["ArtifactSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)",
overlaps="tags,model,artifact_version,model_version",
),
)
artifact_version: List["ArtifactVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)",
overlaps="tags,model,artifact,model_version",
),
)
model: List["ModelSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)",
overlaps="tags,artifact,artifact_version,model_version",
),
)
model_version: List["ModelVersionSchema"] = Relationship(
back_populates="tags",
sa_relationship_kwargs=dict(
primaryjoin=f"and_(TagResourceSchema.resource_type=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)",
overlaps="tags,model,artifact,artifact_version",
),
)

@classmethod
def from_request(cls, request: TagResourceRequest) -> "TagResourceSchema":
Expand Down

0 comments on commit d74e50e

Please sign in to comment.