diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py
index 8a99dbf2ea..9575995ce4 100644
--- a/flytekit/core/resources.py
+++ b/flytekit/core/resources.py
@@ -1,6 +1,7 @@
-from dataclasses import dataclass
-from typing import List, Optional, Union
+from dataclasses import dataclass, fields
+from typing import Any, List, Optional, Union
 
+from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
 from mashumaro.mixins.json import DataClassJSONMixin
 
 from flytekit.models import task as task_models
@@ -73,7 +74,10 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour
         resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
     if resources.ephemeral_storage is not None:
         resource_entries.append(
-            _ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
+            _ResourceEntry(
+                name=_ResourceName.EPHEMERAL_STORAGE,
+                value=str(resources.ephemeral_storage),
+            )
         )
     return resource_entries
 
@@ -96,3 +100,48 @@ def convert_resources_to_resource_model(
     if limits is not None:
         limit_entries = _convert_resources_to_resource_entries(limits)
     return task_models.Resources(requests=request_entries, limits=limit_entries)
+
+
+def construct_k8s_pod_spec_from_resources(
+    k8s_pod_name: str,
+    requests: Optional[Resources],
+    limits: Optional[Resources],
+) -> dict[str, Any]:
+    def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str = "nvidia.com/gpu"):
+        if resources is None:
+            return None
+
+        resources_map = {
+            "cpu": "cpu",
+            "mem": "memory",
+            "gpu": k8s_gpu_resource_key,
+            "ephemeral_storage": "ephemeral-storage",
+        }
+
+        k8s_pod_resources = {}
+
+        for resource in fields(resources):
+            resource_value = getattr(resources, resource.name)
+            if resource_value is not None:
+                k8s_pod_resources[resources_map[resource.name]] = resource_value
+
+        return k8s_pod_resources
+
+    requests = _construct_k8s_pods_resources(resources=requests)
+    limits = _construct_k8s_pods_resources(resources=limits)
+    requests = requests or limits
+    limits = limits or requests
+
+    k8s_pod = V1PodSpec(
+        containers=[
+            V1Container(
+                name=k8s_pod_name,
+                resources=V1ResourceRequirements(
+                    requests=requests,
+                    limits=limits,
+                ),
+            )
+        ]
+    )
+
+    return k8s_pod.to_dict()
diff --git a/plugins/flytekit-ray/flytekitplugins/ray/models.py b/plugins/flytekit-ray/flytekitplugins/ray/models.py
index 1f3a830f16..295b119efb 100644
--- a/plugins/flytekit-ray/flytekitplugins/ray/models.py
+++ b/plugins/flytekit-ray/flytekitplugins/ray/models.py
@@ -2,8 +2,9 @@
 
 from flyteidl.plugins import ray_pb2 as _ray_pb2
 
+from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources
 from flytekit.models import common as _common
-from flytekit.models.task import K8sPod
+from flytekit.models.task import K8sObjectMetadata, K8sPod
 
 
 class WorkerGroupSpec(_common.FlyteIdlEntity):
@@ -14,14 +15,22 @@ def __init__(
         min_replicas: typing.Optional[int] = None,
         max_replicas: typing.Optional[int] = None,
         ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
-        k8s_pod: typing.Optional[K8sPod] = None,
+        requests: typing.Optional[Resources] = None,
+        limits: typing.Optional[Resources] = None,
     ):
         self._group_name = group_name
         self._replicas = replicas
         self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
         self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
         self._ray_start_params = ray_start_params
-        self._k8s_pod = k8s_pod
+        self._requests = requests
+        self._limits = limits
+        self._k8s_pod = K8sPod(
+            metadata=K8sObjectMetadata(),
+            pod_spec=construct_k8s_pod_spec_from_resources(
+                k8s_pod_name="ray-worker", requests=self._requests, limits=self._limits
+            ),
+        )
 
     @property
     def group_name(self):
@@ -104,10 +113,19 @@ class HeadGroupSpec(_common.FlyteIdlEntity):
     def __init__(
         self,
         ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
-        k8s_pod: typing.Optional[K8sPod] = None,
+        requests: typing.Optional[Resources] = None,
+        limits: typing.Optional[Resources] = None,
     ):
         self._ray_start_params = ray_start_params
-        self._k8s_pod = k8s_pod
+        self._requests = requests
+        self._limits = limits
+
+        self._k8s_pod = K8sPod(
+            metadata=K8sObjectMetadata(),
+            pod_spec=construct_k8s_pod_spec_from_resources(
+                k8s_pod_name="ray-head", requests=self._requests, limits=self._limits
+            ),
+        )
 
     @property
     def ray_start_params(self):
diff --git a/plugins/flytekit-ray/flytekitplugins/ray/task.py b/plugins/flytekit-ray/flytekitplugins/ray/task.py
index 3620a0494c..57be2a0b4b 100644
--- a/plugins/flytekit-ray/flytekitplugins/ray/task.py
+++ b/plugins/flytekit-ray/flytekitplugins/ray/task.py
@@ -18,8 +18,8 @@
 from flytekit.configuration import SerializationSettings
 from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
 from flytekit.core.python_function_task import PythonFunctionTask
+from flytekit.core.resources import Resources
 from flytekit.extend import TaskPlugins
-from flytekit.models.task import K8sPod
 
 ray = lazy_module("ray")
 
@@ -27,7 +27,8 @@
 @dataclass
 class HeadNodeConfig:
     ray_start_params: typing.Optional[typing.Dict[str, str]] = None
-    k8s_pod: typing.Optional[K8sPod] = None
+    requests: typing.Optional[Resources] = None
+    limits: typing.Optional[Resources] = None
 
 
 @dataclass
@@ -37,7 +38,8 @@ class WorkerNodeConfig:
     min_replicas: typing.Optional[int] = None
     max_replicas: typing.Optional[int] = None
     ray_start_params: typing.Optional[typing.Dict[str, str]] = None
-    k8s_pod: typing.Optional[K8sPod] = None
+    requests: typing.Optional[Resources] = None
+    limits: typing.Optional[Resources] = None
 
 
 @dataclass
@@ -92,7 +94,11 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
         ray_job = RayJob(
             ray_cluster=RayCluster(
                 head_group_spec=(
-                    HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
+                    HeadGroupSpec(
+                        cfg.head_node_config.ray_start_params,
+                        cfg.head_node_config.requests,
+                        cfg.head_node_config.limits,
+                    )
                     if cfg.head_node_config
                     else None
                 ),
@@ -103,7 +109,8 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
                         c.min_replicas,
                         c.max_replicas,
                         c.ray_start_params,
-                        c.k8s_pod,
+                        c.requests,
+                        c.limits,
                     )
                     for c in cfg.worker_node_config
                 ],
diff --git a/plugins/flytekit-ray/tests/test_ray.py b/plugins/flytekit-ray/tests/test_ray.py
index 737cdf6f4a..206c8ea04a 100644
--- a/plugins/flytekit-ray/tests/test_ray.py
+++ b/plugins/flytekit-ray/tests/test_ray.py
@@ -4,17 +4,31 @@
 import ray
 import yaml
 from flytekitplugins.ray import HeadNodeConfig
-from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec
+from flytekitplugins.ray.models import (
+    RayCluster,
+    RayJob,
+    WorkerGroupSpec,
+    HeadGroupSpec,
+)
 from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
 from google.protobuf.json_format import MessageToDict
 
-from flytekit.models.task import K8sPod
+from flytekit.core.resources import Resources
 from flytekit import PythonFunctionTask, task
 from flytekit.configuration import Image, ImageConfig, SerializationSettings
 
 config = RayJobConfig(
-    worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
-    head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
+    worker_node_config=[
+        WorkerNodeConfig(
+            group_name="test_group",
+            replicas=3,
+            min_replicas=0,
+            max_replicas=10,
+            requests=Resources(cpu=2, mem="2Gi"),
+            limits=Resources(cpu=2, mem="4Gi"),
+        )
+    ],
+    head_node_config=HeadNodeConfig(requests=Resources(cpu=2)),
     runtime_env={"pip": ["numpy"]},
     enable_autoscaling=True,
     shutdown_after_job_finishes=True,
@@ -44,7 +58,20 @@ def t1(a: int) -> str:
     )
 
     ray_job_pb = RayJob(
-        ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
+        ray_cluster=RayCluster(
+            worker_group_spec=[
+                WorkerGroupSpec(
+                    group_name="test_group",
+                    replicas=3,
+                    min_replicas=0,
+                    max_replicas=10,
+                    requests=Resources(cpu=2, mem="2Gi"),
+                    limits=Resources(cpu=2, mem="4Gi"),
+                )
+            ],
+            head_group_spec=HeadGroupSpec(requests=Resources(cpu=2)),
+            enable_autoscaling=True,
+        ),
         runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
         runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
         shutdown_after_job_finishes=True,
diff --git a/pyproject.toml b/pyproject.toml
index 4cba669364..862b516204 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     "jsonlines",
     "jsonpickle",
     "keyring>=18.0.1",
+    "kubernetes>=12.0.1",
     "markdown-it-py",
     "marshmallow-enum",
     "marshmallow-jsonschema>=0.12.0",
diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py
index 5dd9926039..57885527ed 100644
--- a/tests/flytekit/unit/core/test_resources.py
+++ b/tests/flytekit/unit/core/test_resources.py
@@ -1,10 +1,14 @@
 from typing import Dict
 
 import pytest
+from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
 
 import flytekit.models.task as _task_models
 from flytekit import Resources
-from flytekit.core.resources import convert_resources_to_resource_model
+from flytekit.core.resources import (
+    construct_k8s_pod_spec_from_resources,
+    convert_resources_to_resource_model,
+)
 
 _ResourceName = _task_models.Resources.ResourceName
 
@@ -101,3 +105,53 @@ def test_resources_round_trip():
     json_str = original.to_json()
     result = Resources.from_json(json_str)
     assert original == result
+
+
+def test_construct_k8s_pod_spec_from_resources_requests_limits_set():
+    requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi")
+    limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi")
+    k8s_pod_name = "foo"
+
+    expected_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name=k8s_pod_name,
+                resources=V1ResourceRequirements(
+                    requests={
+                        "cpu": "1",
+                        "memory": "1Gi",
+                        "nvidia.com/gpu": "1",
+                        "ephemeral-storage": "1Gi",
+                    },
+                    limits={
+                        "cpu": "4",
+                        "memory": "2Gi",
+                        "nvidia.com/gpu": "1",
+                        "ephemeral-storage": "1Gi",
+                    },
+                ),
+            )
+        ]
+    )
+    pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
+    assert expected_pod_spec == V1PodSpec(**pod_spec)
+
+
+def test_construct_k8s_pod_spec_from_resources_requests_set():
+    requests = Resources(cpu="1", mem="1Gi")
+    limits = None
+    k8s_pod_name = "foo"
+
+    expected_pod_spec = V1PodSpec(
+        containers=[
+            V1Container(
+                name=k8s_pod_name,
+                resources=V1ResourceRequirements(
+                    requests={"cpu": "1", "memory": "1Gi"},
+                    limits={"cpu": "1", "memory": "1Gi"},
+                ),
+            )
+        ]
+    )
+    pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
+    assert expected_pod_spec == V1PodSpec(**pod_spec)
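
Usage sketch (not part of the patch): with this change, Ray head and worker resources are declared with `flytekit.Resources` instead of a hand-built `K8sPod`, and the plugin converts them into a pod spec via `construct_k8s_pod_spec_from_resources`. The task name and resource values below are illustrative only; the imports and fields mirror what `tests/test_ray.py` already exercises.

```python
from flytekit import Resources, task
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig

# Illustrative configuration: modest head node, GPU-backed worker group.
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(requests=Resources(cpu="2", mem="4Gi")),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="workers",
            replicas=2,
            requests=Resources(cpu="4", mem="8Gi", gpu="1"),
            limits=Resources(cpu="4", mem="8Gi", gpu="1"),
        )
    ],
    runtime_env={"pip": ["numpy"]},
)


@task(task_config=ray_config)
def train(epochs: int) -> float:
    # Runs on the Ray cluster; each group's pod template now carries the
    # requests/limits above, with GPUs mapped to "nvidia.com/gpu".
    return 0.0
```

Note that when only one of `requests`/`limits` is provided, `construct_k8s_pod_spec_from_resources` mirrors it into the other (`requests = requests or limits`, `limits = limits or requests`), which is what the new `test_construct_k8s_pod_spec_from_resources_requests_set` test asserts.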