Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Docker and singularity login at runtime #725

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions WDL/runtime/backend/cli_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,13 @@ def cli_exe(self) -> List[str]:
return [self.cli_name]

def _pull_invocation(self, logger: logging.Logger, cleanup: ExitStack) -> Tuple[str, List[str]]:
image = self.runtime_values.get(
"docker", self.cfg.get_dict("task_runtime", "defaults")["docker"]
)
image = self._get_runtime_image()
return (image, self.cli_exe + ["pull", image])

@abstractmethod
def _login_invocation(self, logger: logging.Logger) -> Optional[List[str]]:
pass

@abstractmethod
def _run_invocation(self, logger: logging.Logger, cleanup: ExitStack, image: str) -> List[str]:
pass
Expand Down Expand Up @@ -183,7 +185,7 @@ def _pull(self, logger: logging.Logger, cleanup: ExitStack) -> str:
Pull the image under a global lock, ensuring we'll only download it once even if used by
many parallel tasks all starting at the same time.
"""
image, invocation = self._pull_invocation(logger, cleanup)
image, pull_invocation = self._pull_invocation(logger, cleanup)
with self._pulled_images_lock:
if image in self._pulled_images:
logger.info(_(f"{self.cli_name} image already pulled", image=image))
Expand All @@ -197,13 +199,13 @@ def _pull(self, logger: logging.Logger, cleanup: ExitStack) -> str:
logger.info(_(f"{self.cli_name} image already pulled", image=image))
return image

if not invocation:
if not pull_invocation:
# No action required, image could be cached externally.
return image
logger.info(_(f"begin {self.cli_name} pull", command=" ".join(invocation)))
logger.info(_(f"begin {self.cli_name} pull", command=" ".join(pull_invocation)))
try:
subprocess.run(
invocation,
pull_invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
Expand All @@ -217,7 +219,45 @@ def _pull(self, logger: logging.Logger, cleanup: ExitStack) -> str:
stdout=cpe.stdout.strip().split("\n"),
)
)
raise DownloadFailed(image) from None
raise_error = True

logger.info("Pull failed, try to login")
# I didn't find a way to test in advance if registry login is needed for a specific registry
# using singularity cli, therefore we try to login if pull fails and pull again
login_invocation = self._login_invocation(logger)
if login_invocation:
try:
subprocess.run(
login_invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
check=True,
)

logger.info(
_(f"retry {self.cli_name} pull", command=" ".join(pull_invocation))
)
subprocess.run(
pull_invocation,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
universal_newlines=True,
check=True,
)
raise_error = False
except subprocess.CalledProcessError as cpe:
logger.error(
_(
f"Retry {self.cli_name} pull failed",
stderr=cpe.stderr.strip().split("\n"),
stdout=cpe.stdout.strip().split("\n"),
)
)

if raise_error:
raise DownloadFailed(image) from None

with self._pulled_images_lock:
self._pulled_images.add(image)

Expand Down
26 changes: 25 additions & 1 deletion WDL/runtime/backend/docker_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import json
import stat
import time
import shlex
import uuid
import shlex
import base64
import random
import hashlib
Expand All @@ -26,6 +26,9 @@
from ..task_container import TaskContainer


logging.getLogger("botocore").setLevel(logging.WARNING)


class SwarmContainer(TaskContainer):
"""
TaskContainer docker (swarm) runtime
Expand Down Expand Up @@ -325,6 +328,16 @@ def resolve_tag(
try:
image_attrs = client.images.get(image_tag).attrs
except docker.errors.ImageNotFound:
try:
# docker.errors.APIError is thrown if permissions are missing
client.images.get_registry_data(image_tag) # type: ignore[attr-defined]
except docker.errors.APIError:
user, password, registry_name = super().get_image_registry_credentials(
logger, image_tag, client
)
if all((user, password, registry_name)):
self.docker_login(logger, client, user, password, registry_name) # type: ignore[arg-type]

try:
logger.info(_("docker pull", tag=image_tag))
client.images.pull(image_tag)
Expand All @@ -342,6 +355,17 @@ def resolve_tag(
logger.notice(_("docker image", **image_log))
return image_tag

def docker_login(
self,
logger: logging.Logger,
client: docker.DockerClient,
username: str,
password: str,
registry_name: str,
) -> None:
logger.debug(f"Login to {registry_name} registry")
client.login(username, password, registry=registry_name, reauth=True) # type: ignore[attr-defined]

def prepare_mounts(self, logger: logging.Logger) -> List[docker.types.Mount]:
def escape(s):
# docker processes {{ interpolations }}
Expand Down
19 changes: 18 additions & 1 deletion WDL/runtime/backend/singularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def cli_exe(self) -> List[str]:
return self.cfg.get_list("singularity", "exe")

def _pull_invocation(self, logger: logging.Logger, cleanup: ExitStack) -> Tuple[str, List[str]]:
image, invocation = super()._pull_invocation(logger, cleanup)
image = super()._get_runtime_image()
docker_uri = "docker://" + image
pulldir = self.image_cache_dir or cleanup.enter_context(
tempfile.TemporaryDirectory(prefix="miniwdl_sif_")
Expand All @@ -72,6 +72,23 @@ def _pull_invocation(self, logger: logging.Logger, cleanup: ExitStack) -> Tuple[
logger.info(_("Singularity SIF found in image cache directory", sif=image_path))
return image_path, []

def _login_invocation(self, logger: logging.Logger) -> Optional[List[str]]:
login_invocation = None
image = super()._get_runtime_image()
user, password, registry_name = super().get_image_registry_credentials(logger, image)
if all((user, password, registry_name)):
registry_name = registry_name.split("/")[0] # type: ignore[union-attr]
login_invocation = self.cli_exe + [
"registry",
"login",
"--username",
user,
"--password",
password,
"docker://" + registry_name,
]
return login_invocation # type: ignore[return-value]

def _run_invocation(self, logger: logging.Logger, cleanup: ExitStack, image: str) -> List[str]:
"""
Formulate `singularity run` command-line invocation
Expand Down
96 changes: 95 additions & 1 deletion WDL/runtime/task_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import shutil
import threading
import typing
from typing import Callable, Iterable, Any, Dict, Optional, ContextManager, Set
from typing import Callable, Iterable, Any, Dict, Optional, ContextManager, Set, Tuple
from abc import ABC, abstractmethod
from contextlib import suppress
from .. import Error, Value, Type
Expand All @@ -18,11 +18,25 @@
PygtailLogger,
parse_byte_size,
)
import base64
from enum import Enum
import re
import warnings
import boto3
import docker
import google.auth
import google.auth.transport.requests
from .._util import StructuredLogMessage as _
from . import config, _statusbar
from .error import OutputError, Terminated, CommandFailed


class SupportedProviders(Enum):
AWS = "aws"
GCP = "gcp"
UNKNOWN = None


class TaskContainer(ABC):
"""
Base class for task containers, subclassed by runtime-specific backends (e.g. Docker).
Expand Down Expand Up @@ -293,6 +307,86 @@ def process_runtime(self, logger: logging.Logger, runtime_eval: Dict[str, Value.
raise Error.RuntimeError("invalid setting of runtime.gpu")
ans["gpu"] = runtime_eval["gpu"].value

def _get_runtime_image(self):
image = self.runtime_values.get(
"docker", self.cfg.get_dict("task_runtime", "defaults")["docker"]
)
return image

def get_image_registry_credentials(
self,
logger: logging.Logger,
image_tag: str,
docker_client: Optional[docker.DockerClient] = None,
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
close_docker_client = False
if not docker_client:
docker_client = docker.from_env(version="auto")
close_docker_client = True
logger.debug(f"Need to login to {image_tag} registry")
registry_name, provider = self._get_registry_name_and_provider(logger, image_tag)
if registry_name and provider is SupportedProviders.AWS:
user, password = self._aws_ecr_login_args(logger, registry_name)
if registry_name and provider is SupportedProviders.GCP:
user, password = self._gcp_docker_registry_login_args()
if provider is SupportedProviders.UNKNOWN:
logger.warning(
f"{image_tag} registry pattern unrecognized. If login is needed do it before running the workflow"
)
user, password = None, None # type: ignore[assignment]
# close the docker client in case it was instantiated in the scope of the function
if close_docker_client:
docker_client.close()
return user, password, registry_name

def _get_registry_name_and_provider(
self, logger: logging.Logger, image_tag: str
) -> Tuple[str | None, SupportedProviders]:
logger.debug(f"Get registry name and provider for {image_tag}")
# GCP:
# - <LOCATION>-docker.pkg.dev/<PROJECT-ID>/<REPOSITORY>
# - <LOCATION>.gcr.io/<PROJECT-ID> (legacy)
gcp_registry_pattern = (
r"^(?P<gcp>[a-z-]+[0-9]+-docker\.pkg\.dev/[a-z0-9-]+/[a-z0-9-]+|[a-z\.]*gcr\.io)/.*$"
)
# AWS:
# - <AWS_ACCOUNT_ID>.dkr.ecr.<REGION>.amazonaws.com
aws_registry_pattern = r"^(?P<aws>[0-9]{12}\.dkr\.ecr\.[a-z-]+[0-9]+\.amazonaws\.com)/.*$"

pattern_match = re.match(gcp_registry_pattern, image_tag) or re.match(
aws_registry_pattern, image_tag
)
registry_name = pattern_match.group(1) if pattern_match else None
provider = SupportedProviders(
list(pattern_match.groupdict().keys())[0] if pattern_match else None
)
logger.debug(f"Registry: {registry_name}. Provider: {provider}")
return registry_name, provider

def _aws_ecr_login_args(self, logger: logging.Logger, registry_name: str) -> Tuple[str, str]:
logger.debug(f"Get region and account ID from registry name {registry_name}")
aws_account_id, _, _, aws_region, _, _ = registry_name.split(".")
logger.debug(f"AWS account: {aws_account_id}. Region: {aws_region}")
ecr_client = boto3.client("ecr", region_name=aws_region)
logger.debug(f"Get ECR token for {registry_name}")
response = ecr_client.get_authorization_token(registryIds=[aws_account_id])
ecr_password = (
base64.b64decode(response["authorizationData"][0]["authorizationToken"])
.replace(b"AWS:", b"")
.decode("utf-8")
)
return "AWS", ecr_password

def _gcp_docker_registry_login_args(self) -> Tuple[str, str]:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
creds, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
return "oauth2accesstoken", creds.token

def run(self, logger: logging.Logger, command: str) -> None:
"""
1. Container is instantiated with the configured mounts and resources
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ dependencies = [
"python-json-logger>=2,<3",
"lark~=1.1",
"bullet>=2,<3",
"psutil>=5,<7"
"psutil>=5,<7",
"google-auth>=2.32.0",
"boto3>=1.34.153",
"boto3-stubs>=1.34.153",
]

[project.optional-dependencies]
Expand Down