MLCOMPUTE-2001 | Skip UI port selection, pod tpl generation, and logs for driver on k8s tron (#151)

* Skip ui port sel and pod tpl and logs for driver on k8s

* Fix tests

* Add comments
chi-yelp authored Oct 10, 2024
1 parent ef95386 commit f79b9e0
Showing 2 changed files with 53 additions and 34 deletions.
77 changes: 45 additions & 32 deletions service_configuration_lib/spark_config.py
@@ -22,6 +22,8 @@

from service_configuration_lib import utils
from service_configuration_lib.text_colors import TextColors
+from service_configuration_lib.utils import EPHEMERAL_PORT_END
+from service_configuration_lib.utils import EPHEMERAL_PORT_START

AWS_CREDENTIALS_DIR = '/etc/boto_cfg/'
AWS_ENV_CREDENTIALS_PROVIDER = 'com.amazonaws.auth.EnvironmentVariableCredentialsProvider'
@@ -32,6 +34,7 @@
GPUS_HARD_LIMIT = 15
CLUSTERMAN_METRICS_YAML_FILE_PATH = '/nail/srv/configs/clusterman_metrics.yaml'
CLUSTERMAN_YAML_FILE_PATH = '/nail/srv/configs/clusterman.yaml'
+SPARK_TRON_JOB_USER = 'TRON'

NON_CONFIGURABLE_SPARK_OPTS = {
'spark.master',
@@ -295,7 +298,7 @@ def _get_k8s_spark_env(
paasta_service: str,
paasta_instance: str,
docker_img: str,
-    pod_template_path: str,
+    pod_template_path: Optional[str],
volumes: Optional[List[Mapping[str, str]]],
paasta_pool: str,
driver_ui_port: int,
@@ -335,9 +338,12 @@
'spark.kubernetes.executor.label.yelp.com/pool': paasta_pool,
'spark.kubernetes.executor.label.paasta.yelp.com/pool': paasta_pool,
'spark.kubernetes.executor.label.yelp.com/owner': 'core_ml',
-        'spark.kubernetes.executor.podTemplateFile': pod_template_path,
**_get_k8s_docker_volumes_conf(volumes),
}

+    if pod_template_path is not None:
+        spark_env['spark.kubernetes.executor.podTemplateFile'] = pod_template_path

if service_account_name is not None:
spark_env.update(
{
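With pod_template_path now Optional, the executor conf only carries the template key when a template was actually generated. A tiny sketch of the effect (values are hypothetical, not from this diff):

```python
# Hypothetical values; illustrates the conditional podTemplateFile assignment above.
spark_env = {'spark.kubernetes.executor.label.yelp.com/owner': 'core_ml'}
pod_template_path = None  # Tron-launched drivers skip template generation

# Previously the key was always set, so a failed generation leaked '' into the conf.
if pod_template_path is not None:
    spark_env['spark.kubernetes.executor.podTemplateFile'] = pod_template_path

assert 'spark.kubernetes.executor.podTemplateFile' not in spark_env
```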
@@ -419,12 +425,13 @@ def get_total_driver_memory_mb(spark_conf: Dict[str, str]) -> int:

class SparkConfBuilder:

-    def __init__(self):
-        self.spark_srv_conf = dict()
-        self.spark_constants = dict()
-        self.default_spark_srv_conf = dict()
-        self.mandatory_default_spark_srv_conf = dict()
-        self.spark_costs = dict()
+    def __init__(self, is_driver_on_k8s_tron: bool = False):
+        self.is_driver_on_k8s_tron = is_driver_on_k8s_tron
+        self.spark_srv_conf: Dict[str, Any] = dict()
+        self.spark_constants: Dict[str, Any] = dict()
+        self.default_spark_srv_conf: Dict[str, Any] = dict()
+        self.mandatory_default_spark_srv_conf: Dict[str, Any] = dict()
+        self.spark_costs: Dict[str, Dict[str, float]] = dict()

try:
(
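A quick usage sketch of the new constructor flag; the call sites below are assumptions for illustration, not part of this diff:

```python
# Hypothetical call sites: Tron's k8s launcher opts in, existing callers are unchanged.
from service_configuration_lib.spark_config import SparkConfBuilder

tron_builder = SparkConfBuilder(is_driver_on_k8s_tron=True)  # static UI port, no pod template file
adhoc_builder = SparkConfBuilder()                           # default: existing behavior
```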
@@ -628,7 +635,7 @@ def compute_executor_instances_k8s(self, user_spark_opts: Dict[str, str]) -> int:
)

# Deprecation message
-        if 'spark.cores.max' in user_spark_opts:
+        if not self.is_driver_on_k8s_tron and 'spark.cores.max' in user_spark_opts:
log.warning(
f'spark.cores.max is DEPRECATED. Replace with '
f'spark.executor.instances={executor_instances} in --spark-args and in your service code '
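The replacement the warning asks for amounts to dividing the old core cap by per-executor cores. A sketch of that arithmetic (the default of 2 executor cores is an assumption, not taken from this diff):

```python
# Hedged sketch of the spark.cores.max -> spark.executor.instances substitution.
user_spark_opts = {'spark.cores.max': '32', 'spark.executor.cores': '4'}

executor_cores = int(user_spark_opts.get('spark.executor.cores', '2'))  # default is an assumption
executor_instances = int(user_spark_opts['spark.cores.max']) // executor_cores
assert executor_instances == 8  # i.e. pass spark.executor.instances=8 in --spark-args instead
```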
@@ -1102,23 +1109,27 @@ def get_spark_conf(
spark_app_base_name
)

-        # Pick a port from a pre-defined port range, which will then be used by our Jupyter
-        # server metric aggregator API. The aggregator API collects Prometheus metrics from multiple
-        # Spark sessions and exposes them through a single endpoint.
-        try:
-            ui_port = int(
-                (spark_opts_from_env or {}).get('spark.ui.port') or
-                utils.ephemeral_port_reserve_range(
-                    self.spark_constants.get('preferred_spark_ui_port_start'),
-                    self.spark_constants.get('preferred_spark_ui_port_end'),
-                ),
-            )
-        except Exception as e:
-            log.warning(
-                f'Could not get an available port using srv-config port range: {e}. '
-                'Using default port range to get an available port.',
-            )
-            ui_port = utils.ephemeral_port_reserve_range()
+        if self.is_driver_on_k8s_tron:
+            # For Tron-launched driver on k8s, we use a static Spark UI port
+            ui_port: int = self.spark_constants.get('preferred_spark_ui_port_start', EPHEMERAL_PORT_START)
+        else:
+            # Pick a port from a pre-defined port range, which will then be used by our Jupyter
+            # server metric aggregator API. The aggregator API collects Prometheus metrics from multiple
+            # Spark sessions and exposes them through a single endpoint.
+            try:
+                ui_port = int(
+                    (spark_opts_from_env or {}).get('spark.ui.port') or
+                    utils.ephemeral_port_reserve_range(
+                        self.spark_constants.get('preferred_spark_ui_port_start', EPHEMERAL_PORT_START),
+                        self.spark_constants.get('preferred_spark_ui_port_end', EPHEMERAL_PORT_END),
+                    ),
+                )
+            except Exception as e:
+                log.warning(
+                    f'Could not get an available port using srv-config port range: {e}. '
+                    'Using default port range to get an available port.',
+                )
+                ui_port = utils.ephemeral_port_reserve_range()
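For the non-Tron branch, utils.ephemeral_port_reserve_range picks a bindable port from a range. A minimal sketch of that style of reservation, assuming the helper binds the first free port it finds (the real implementation in utils.py may differ):

```python
# Minimal sketch of range-based port reservation; details of the real
# utils.ephemeral_port_reserve_range may differ.
from socket import SO_REUSEADDR, SOL_SOCKET, socket

def reserve_port_in_range(start: int, end: int) -> int:
    """Return the first port in [start, end] that can be bound locally."""
    for port in range(start, end + 1):
        s = socket()
        s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
        try:
            s.bind(('127.0.0.1', port))
        except OSError:
            s.close()
            continue  # port in use, try the next one
        s.close()  # SO_REUSEADDR lets the caller re-bind this port immediately
        return port
    raise RuntimeError(f'no free port in range [{start}, {end}]')
```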

spark_conf = {**(spark_opts_from_env or {}), **_filter_user_spark_opts(user_spark_opts)}
random_postfix = utils.get_random_string(4)
@@ -1157,12 +1168,14 @@
)

# Add pod template file
-        pod_template_path = utils.generate_pod_template_path()
-        try:
-            utils.create_pod_template(pod_template_path, app_base_name)
-        except Exception as e:
-            log.error(f'Failed to generate Spark executor pod template: {e}')
-            pod_template_path = ''
+        pod_template_path: Optional[str] = None
+        if not self.is_driver_on_k8s_tron:
+            pod_template_path = utils.generate_pod_template_path()
+            try:
+                utils.create_pod_template(pod_template_path, app_base_name)
+            except Exception as e:
+                log.error(f'Failed to generate Spark executor pod template: {e}')
+                pod_template_path = None
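Taken together, a Tron-launched driver should end up with a fixed UI port and no executor pod template entry. A hedged end-to-end sketch; get_spark_conf's argument list here is abridged and partly guessed from names in the surrounding diff, not the full signature:

```python
# Hedged sketch; argument names are partly assumptions.
builder = SparkConfBuilder(is_driver_on_k8s_tron=True)
conf = builder.get_spark_conf(
    cluster_manager='kubernetes',
    spark_app_base_name='example_batch',
    user_spark_opts={},
    paasta_pool='batch',
    paasta_service='example_service',
    paasta_instance='main',
    docker_img='example-registry/example:latest',
)

# No pod template was generated, so the key should be absent, and
# spark.ui.port should be the static preferred_spark_ui_port_start value.
assert 'spark.kubernetes.executor.podTemplateFile' not in conf
```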

if cluster_manager == 'kubernetes':
spark_conf.update(_get_k8s_spark_env(
10 changes: 8 additions & 2 deletions service_configuration_lib/utils.py
@@ -11,8 +11,8 @@
from socket import SO_REUSEADDR
from socket import socket
from socket import SOL_SOCKET
from typing import Any
from typing import Dict
from typing import Mapping
from typing import Tuple

import yaml
@@ -36,7 +36,13 @@
log.setLevel(logging.INFO)


-def load_spark_srv_conf(preset_values=None) -> Tuple[Mapping, Mapping, Mapping, Mapping, Mapping]:
+def load_spark_srv_conf(preset_values=None) -> Tuple[
+    Dict[str, Any],
+    Dict[str, Any],
+    Dict[str, Any],
+    Dict[str, Any],
+    Dict[str, Dict[str, float]],
+]:
if preset_values is None:
preset_values = dict()
try:
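The widened return annotation mirrors how SparkConfBuilder.__init__ unpacks the result. A sketch of the call site; the unpacking order is inferred from the attribute names in the diff above:

```python
# Inferred unpacking order; the real call site is in SparkConfBuilder.__init__.
(
    spark_srv_conf,
    spark_constants,
    default_spark_srv_conf,
    mandatory_default_spark_srv_conf,
    spark_costs,
) = load_spark_srv_conf()

# spark_costs is now typed Dict[str, Dict[str, float]], i.e. nested str -> float maps.
```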
