Skip to content

Commit

Permalink
only validate needed image labels
Browse files Browse the repository at this point in the history
  • Loading branch information
bisgaard-itis committed Oct 5, 2023
1 parent 39f4697 commit bd4b3e9
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import logging
import os
import re
import socket
from dataclasses import dataclass
from pathlib import Path
Expand Down Expand Up @@ -47,7 +46,7 @@
pull_image,
)
from .errors import ServiceBadFormattedOutputError
from .models import LEGACY_INTEGRATION_VERSION, PROGRESS_REGEXP
from .models import LEGACY_INTEGRATION_VERSION, ImageLabels
from .task_shared_volume import TaskSharedVolumes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -195,15 +194,9 @@ async def run(self, command: list[str]) -> TaskOutputData:
self._publish_sidecar_log,
)

integration_version: version.Version = LEGACY_INTEGRATION_VERSION
progress_regexp: re.Pattern[str] = PROGRESS_REGEXP
if image_labels := await get_image_labels(
image_labels: ImageLabels = await get_image_labels(
docker_client, self.docker_auth, self.service_key, self.service_version
):
if iversion := image_labels.integration_version:
integration_version = version.Version(iversion)
if pregexp := image_labels.progress_regexp:
progress_regexp = re.compile(pregexp)
)
computational_shared_data_mount_point = (
await get_computational_shared_data_mount_point(docker_client)
)
Expand All @@ -218,7 +211,9 @@ async def run(self, command: list[str]) -> TaskOutputData:
envs=self.task_envs,
labels=self.task_labels,
)
await self._write_input_data(task_volumes, integration_version)
await self._write_input_data(
task_volumes, image_labels.get_integration_version()
)

# PROCESSING
async with managed_container(
Expand All @@ -227,11 +222,11 @@ async def run(self, command: list[str]) -> TaskOutputData:
name=f"{self.service_key.split(sep='/')[-1]}_{run_id}",
) as container, managed_monitor_container_log_task(
container=container,
progress_regexp=progress_regexp,
progress_regexp=image_labels.get_progress_regexp(),
service_key=self.service_key,
service_version=self.service_version,
task_publishers=self.task_publishers,
integration_version=integration_version,
integration_version=image_labels.get_integration_version(),
task_volumes=task_volumes,
log_file_url=self.log_file_url,
log_publishing_cb=self._publish_sidecar_log,
Expand Down Expand Up @@ -261,7 +256,7 @@ async def run(self, command: list[str]) -> TaskOutputData:

# POST-PROCESSING
results = await self._retrieve_output_data(
task_volumes, integration_version
task_volumes, image_labels.get_integration_version()
)
await self._publish_sidecar_log("Task completed successfully.")
return results
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
ContainerTag,
LogFileUploadURL,
)
from models_library.services import ServiceDockerData
from models_library.services_resources import BootMode
from models_library.utils.labels_annotations import OSPARC_LABEL_PREFIXES, from_labels
from packaging import version
Expand All @@ -44,6 +43,7 @@
LEGACY_INTEGRATION_VERSION,
ContainerHostConfig,
DockerContainerConfig,
ImageLabels,
)
from .task_shared_volume import TaskSharedVolumes

Expand Down Expand Up @@ -414,7 +414,7 @@ async def get_image_labels(
docker_auth: DockerBasicAuth,
service_key: ContainerImage,
service_version: ContainerTag,
) -> ServiceDockerData | None:
) -> ImageLabels:
image_cfg = await docker_client.images.inspect(
f"{docker_auth.server_address}/{service_key}:{service_version}"
)
Expand All @@ -425,8 +425,8 @@ async def get_image_labels(
data = from_labels(
image_labels, prefix_key=OSPARC_LABEL_PREFIXES[0], trim_key_head=False
)
return parse_obj_as(ServiceDockerData, data)
return None
return parse_obj_as(ImageLabels, data)
return ImageLabels()


async def get_computational_shared_data_mount_point(docker_client: Docker) -> Path:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import re

from models_library.basic_types import VERSION_RE
from models_library.services import ServiceDockerData
from packaging import version
from pydantic import BaseModel, ByteSize, Field, validator
from pydantic import BaseModel, ByteSize, Extra, Field, validator

LEGACY_INTEGRATION_VERSION = version.Version("0")
PROGRESS_REGEXP: re.Pattern[str] = re.compile(
Expand Down Expand Up @@ -62,3 +64,46 @@ class DockerContainerConfig(BaseModel):
image: str = Field(..., alias="Image")
labels: dict[str, str] = Field(..., alias="Labels")
host_config: ContainerHostConfig = Field(..., alias="HostConfig")


class ImageLabels(BaseModel):
integration_version: str = Field(
default=LEGACY_INTEGRATION_VERSION,
alias="integration-version",
description="integration version number",
regex=VERSION_RE,
examples=["1.0.0"],
)
progress_regexp: str = Field(
default=PROGRESS_REGEXP,
alias="progress_regexp",
description="regexp pattern for detecting computational service's progress",
)

class Config:
extra = Extra.ignore

@validator("integration_version", pre=True)
@classmethod
def default_integration_version(cls, v):
if v is None:
return ImageLabels().integration_version
return v

@validator("progress_regexp", pre=True)
@classmethod
def default_progress_regexp(cls, v):
if v is None:
return ImageLabels().progress_regexp
return v

def get_integration_version(self) -> version.Version:
return version.Version(self.integration_version)

def get_progress_regexp(self) -> re.Pattern[str]:
return re.compile(self.progress_regexp)


assert set(ImageLabels.__fields__).issubset(
ServiceDockerData.__fields__
), "ImageLabels must be compatible with ServiceDockerData"

0 comments on commit bd4b3e9

Please sign in to comment.