Decouple Ray Resources: Construct ray k8spods from Resources #2943

Open · wants to merge 14 commits into master
55 changes: 52 additions & 3 deletions flytekit/core/resources.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import List, Optional, Union
from dataclasses import dataclass, fields
from typing import Any, List, Optional, Union

from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.models import task as task_models
@@ -73,7 +74,10 @@ def _convert_resources_to_resource_entries(resources: Resources) -> List[_Resour
resource_entries.append(_ResourceEntry(name=_ResourceName.GPU, value=str(resources.gpu)))
if resources.ephemeral_storage is not None:
resource_entries.append(
_ResourceEntry(name=_ResourceName.EPHEMERAL_STORAGE, value=str(resources.ephemeral_storage))
_ResourceEntry(
name=_ResourceName.EPHEMERAL_STORAGE,
value=str(resources.ephemeral_storage),
)
)
return resource_entries

@@ -96,3 +100,48 @@ def convert_resources_to_resource_model(
if limits is not None:
limit_entries = _convert_resources_to_resource_entries(limits)
return task_models.Resources(requests=request_entries, limits=limit_entries)


def construct_k8s_pod_spec_from_resources(
k8s_pod_name: str,
requests: Optional[Resources],
limits: Optional[Resources],
) -> dict[str, Any]:
def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str = "nvidia.com/gpu"):
if resources is None:
Review comment (Collaborator): Using other GPUs is going to be hard, even if we push this parameter up to the outer function (i.e. construct_k8s_pod_spec_from_resources).
return None

resources_map = {
"cpu": "cpu",
"mem": "memory",
"gpu": k8s_gpu_resource_key,
"ephemeral_storage": "ephemeral-storage",
}

k8s_pod_resources = {}

for resource in fields(resources):
resource_value = getattr(resources, resource.name)
if resource_value is not None:
k8s_pod_resources[resources_map[resource.name]] = resource_value

return k8s_pod_resources

requests = _construct_k8s_pods_resources(resources=requests)
limits = _construct_k8s_pods_resources(resources=limits)
requests = requests or limits
limits = limits or requests

k8s_pod = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests=requests,
limits=limits,
),
)
]
)

return k8s_pod.to_dict()
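
For reference, a minimal sketch of how the new helper could be invoked; the resource values and the "ray-head" container name below are illustrative, and the GPU key is the hard-coded "nvidia.com/gpu" discussed above:

# Sketch only: intended call pattern for the new helper.
from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources

requests = Resources(cpu="1", mem="2Gi", gpu="1")   # illustrative values
limits = Resources(cpu="2", mem="4Gi", gpu="1")

pod_spec = construct_k8s_pod_spec_from_resources(
    k8s_pod_name="ray-head",  # becomes the container name in the generated V1PodSpec
    requests=requests,
    limits=limits,
)
# pod_spec is a plain dict (V1PodSpec.to_dict()); the container's resources map
# cpu/mem/gpu/ephemeral_storage to "cpu", "memory", "nvidia.com/gpu", "ephemeral-storage".
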
28 changes: 23 additions & 5 deletions plugins/flytekit-ray/flytekitplugins/ray/models.py
@@ -2,8 +2,9 @@

from flyteidl.plugins import ray_pb2 as _ray_pb2

from flytekit.core.resources import Resources, construct_k8s_pod_spec_from_resources
from flytekit.models import common as _common
from flytekit.models.task import K8sPod
from flytekit.models.task import K8sObjectMetadata, K8sPod


class WorkerGroupSpec(_common.FlyteIdlEntity):
Expand All @@ -14,14 +15,22 @@ def __init__(
min_replicas: typing.Optional[int] = None,
max_replicas: typing.Optional[int] = None,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
Review comment (Member): Yup, we should keep it. If someone specifies both k8s_pod and requests, then we should merge them.
Review comment (Contributor Author): Talked with Eduardo about this today and agreed to only expose k8s_pod and let the user construct it with helper functions like construct_k8s_pod_spec_from_resources() in this PR.
requests: typing.Optional[Resources] = None,
limits: typing.Optional[Resources] = None,
):
self._group_name = group_name
self._replicas = replicas
self._max_replicas = max(replicas, max_replicas) if max_replicas is not None else replicas
self._min_replicas = min(replicas, min_replicas) if min_replicas is not None else replicas
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod
Review comment (Collaborator): We should keep this as part of the interface and build helper functions that construct valid pod specs instead (as mentioned in the original flyte PR). This will also help with the other problem we're having with passing the GPU resource name around (in other words, gpu can be an argument of one of the helper functions that build pod specs).
Review comment (Contributor Author): I get what you are saying. So we want users to construct the pod specs themselves, e.g. by calling construct_k8s_pod_spec_from_resources() or specifying pod templates in user code?
Review comment (Contributor): I would keep the method name simple, maybe something like pod_from_resources.
self._requests = requests
self._limits = limits
self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-worker", requests=self._requests, limits=self._limits
),
)

@property
def group_name(self):
@@ -104,10 +113,19 @@ class HeadGroupSpec(_common.FlyteIdlEntity):
def __init__(
self,
ray_start_params: typing.Optional[typing.Dict[str, str]] = None,
k8s_pod: typing.Optional[K8sPod] = None,
requests: typing.Optional[Resources] = None,
limits: typing.Optional[Resources] = None,
Review comment on lines +116 to +117 (Collaborator): Ditto.
):
self._ray_start_params = ray_start_params
self._k8s_pod = k8s_pod
self._requests = requests
self._limits = limits

self._k8s_pod = K8sPod(
metadata=K8sObjectMetadata(),
pod_spec=construct_k8s_pod_spec_from_resources(
k8s_pod_name="ray-head", requests=self._requests, limits=self._limits
),
)

@property
def ray_start_params(self):
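With this change, callers of the model classes pass Resources directly and the K8sPod is built internally. A hedged sketch of the resulting construction (resource values and ray_start_params are illustrative):

# Sketch under the new interface: the pod spec is derived from Resources internally.
from flytekit.core.resources import Resources
from flytekitplugins.ray.models import HeadGroupSpec, WorkerGroupSpec

head = HeadGroupSpec(
    ray_start_params={"num-cpus": "1"},  # illustrative
    requests=Resources(cpu="1", mem="2Gi"),
    limits=Resources(cpu="2", mem="4Gi"),
)
workers = WorkerGroupSpec(
    group_name="workers",
    replicas=2,
    requests=Resources(cpu="2", mem="4Gi", gpu="1"),
)
# Both specs now carry a K8sPod (container named "ray-head" / "ray-worker") whose
# pod_spec comes from construct_k8s_pod_spec_from_resources().
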
17 changes: 12 additions & 5 deletions plugins/flytekit-ray/flytekitplugins/ray/task.py
@@ -18,16 +18,17 @@
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.resources import Resources
from flytekit.extend import TaskPlugins
from flytekit.models.task import K8sPod

ray = lazy_module("ray")


@dataclass
class HeadNodeConfig:
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
requests: typing.Optional[Resources] = None
limits: typing.Optional[Resources] = None


@dataclass
@@ -37,7 +38,8 @@ class WorkerNodeConfig:
min_replicas: typing.Optional[int] = None
max_replicas: typing.Optional[int] = None
ray_start_params: typing.Optional[typing.Dict[str, str]] = None
k8s_pod: typing.Optional[K8sPod] = None
requests: typing.Optional[Resources] = None
limits: typing.Optional[Resources] = None


@dataclass
@@ -92,7 +94,11 @@ def get_custom(self, settings: SerializationSettings) -> Optional[Dict[str, Any]
ray_job = RayJob(
ray_cluster=RayCluster(
head_group_spec=(
HeadGroupSpec(cfg.head_node_config.ray_start_params, cfg.head_node_config.k8s_pod)
HeadGroupSpec(
cfg.head_node_config.ray_start_params,
cfg.head_node_config.requests,
cfg.head_node_config.limits,
)
if cfg.head_node_config
else None
),
@@ -103,7 +109,8 @@
c.min_replicas,
c.max_replicas,
c.ray_start_params,
c.k8s_pod,
c.requests,
c.limits,
)
for c in cfg.worker_node_config
],
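End to end, a task author now expresses pod resources through the config dataclasses instead of building a K8sPod by hand. A sketch under the new interface (task body and resource values are illustrative):

# Sketch: configuring a Ray task via requests/limits on the node configs.
from flytekit import task
from flytekit.core.resources import Resources
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(requests=Resources(cpu="1", mem="2Gi")),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="workers",
            replicas=2,
            requests=Resources(cpu="2", mem="4Gi"),
            limits=Resources(cpu="2", mem="8Gi"),
        )
    ],
)

@task(task_config=ray_config)
def train(epochs: int) -> str:
    # The plugin wires up the Ray cluster; the body is a placeholder.
    return f"trained for {epochs} epochs"
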
37 changes: 32 additions & 5 deletions plugins/flytekit-ray/tests/test_ray.py
@@ -4,17 +4,31 @@
import ray
import yaml
from flytekitplugins.ray import HeadNodeConfig
from flytekitplugins.ray.models import RayCluster, RayJob, WorkerGroupSpec, HeadGroupSpec
from flytekitplugins.ray.models import (
RayCluster,
RayJob,
WorkerGroupSpec,
HeadGroupSpec,
)
from flytekitplugins.ray.task import RayJobConfig, WorkerNodeConfig
from google.protobuf.json_format import MessageToDict
from flytekit.models.task import K8sPod
from flytekit.core.resources import Resources

from flytekit import PythonFunctionTask, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings

config = RayJobConfig(
worker_node_config=[WorkerNodeConfig(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))],
head_node_config=HeadNodeConfig(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})),
worker_node_config=[
WorkerNodeConfig(
group_name="test_group",
replicas=3,
min_replicas=0,
max_replicas=10,
requests=Resources(cpu=2, mem="2Gi"),
limits=Resources(cpu=2, mem="4Gi"),
)
],
head_node_config=HeadNodeConfig(requests=Resources(cpu=2)),
runtime_env={"pip": ["numpy"]},
enable_autoscaling=True,
shutdown_after_job_finishes=True,
@@ -44,7 +58,20 @@ def t1(a: int) -> str:
)

ray_job_pb = RayJob(
ray_cluster=RayCluster(worker_group_spec=[WorkerGroupSpec(group_name="test_group", replicas=3, min_replicas=0, max_replicas=10, k8s_pod=K8sPod(pod_spec={"str": "worker", "int": 1}))], head_group_spec=HeadGroupSpec(k8s_pod=K8sPod(pod_spec={"str": "head", "int": 2})), enable_autoscaling=True),
ray_cluster=RayCluster(
worker_group_spec=[
WorkerGroupSpec(
group_name="test_group",
replicas=3,
min_replicas=0,
max_replicas=10,
requests=Resources(cpu=2, mem="2Gi"),
limits=Resources(cpu=2, mem="4Gi"),
)
],
head_group_spec=HeadGroupSpec(requests=Resources(cpu=2)),
enable_autoscaling=True,
),
runtime_env=base64.b64encode(json.dumps({"pip": ["numpy"]}).encode()).decode(),
runtime_env_yaml=yaml.dump({"pip": ["numpy"]}),
shutdown_after_job_finishes=True,
1 change: 1 addition & 0 deletions pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
"jsonlines",
"jsonpickle",
"keyring>=18.0.1",
"kubernetes>=12.0.1",
"markdown-it-py",
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
56 changes: 55 additions & 1 deletion tests/flytekit/unit/core/test_resources.py
@@ -1,10 +1,14 @@
from typing import Dict

import pytest
from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements

import flytekit.models.task as _task_models
from flytekit import Resources
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.core.resources import (
construct_k8s_pod_spec_from_resources,
convert_resources_to_resource_model,
)

_ResourceName = _task_models.Resources.ResourceName

@@ -101,3 +105,53 @@ def test_resources_round_trip():
json_str = original.to_json()
result = Resources.from_json(json_str)
assert original == result


def test_construct_k8s_pod_spec_from_resources_requests_limits_set():
requests = Resources(cpu="1", mem="1Gi", gpu="1", ephemeral_storage="1Gi")
limits = Resources(cpu="4", mem="2Gi", gpu="1", ephemeral_storage="1Gi")
k8s_pod_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests={
"cpu": "1",
"memory": "1Gi",
"nvidia.com/gpu": "1",
"ephemeral-storage": "1Gi",
},
limits={
"cpu": "4",
"memory": "2Gi",
"nvidia.com/gpu": "1",
"ephemeral-storage": "1Gi",
},
),
)
]
)
pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)


def test_construct_k8s_pod_spec_from_resources_requests_set():
requests = Resources(cpu="1", mem="1Gi")
limits = None
k8s_pod_name = "foo"

expected_pod_spec = V1PodSpec(
containers=[
V1Container(
name=k8s_pod_name,
resources=V1ResourceRequirements(
requests={"cpu": "1", "memory": "1Gi"},
limits={"cpu": "1", "memory": "1Gi"},
),
)
]
)
pod_spec = construct_k8s_pod_spec_from_resources(k8s_pod_name=k8s_pod_name, requests=requests, limits=limits)
assert expected_pod_spec == V1PodSpec(**pod_spec)