Skip to content

Commit

Permalink
add concurrency config schema
Browse files Browse the repository at this point in the history
  • Loading branch information
prha committed Jan 9, 2025
1 parent 2ce911a commit 8cfee9c
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 24 deletions.
10 changes: 9 additions & 1 deletion python_modules/dagster/dagster/_core/instance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ def get_run_queue_config(self) -> Optional["RunQueueConfig"]:
if not isinstance(self.run_coordinator, QueuedRunCoordinator):
return None

return self.run_coordinator.get_run_queue_config()
run_coordinator_run_queue_config = self.run_coordinator.get_run_queue_config()
return run_coordinator_run_queue_config.with_concurrency_settings(
self.get_settings("concurrency")
)

@property
def run_launcher(self) -> "RunLauncher":
Expand Down Expand Up @@ -972,6 +975,11 @@ def auto_materialize_use_sensors(self) -> int:

@property
def global_op_concurrency_default_limit(self) -> Optional[int]:
default_limit = self.get_settings("concurrency").get("pools", {}).get("default_limit")
if default_limit is not None:
return default_limit

# fallback to the old settings
return self.get_settings("concurrency").get("default_op_concurrency_limit")

# python logs
Expand Down
191 changes: 169 additions & 22 deletions python_modules/dagster/dagster/_core/instance/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple, Type, cast
from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Tuple, Type, cast

from dagster import (
Array,
Bool,
String,
_check as check,
)
from dagster._config import (
Field,
IntSource,
Noneable,
Permissive,
ScalarUnion,
Selector,
Shape,
StringSource,
validate_config,
)
Expand Down Expand Up @@ -119,18 +122,7 @@ def dagster_instance_config(

# validate default op concurrency limits
if "concurrency" in dagster_config_dict:
default_concurrency_limit = dagster_config_dict["concurrency"].get(
"default_op_concurrency_limit"
)
if default_concurrency_limit is not None:
max_limit = get_max_concurrency_limit_value()
if default_concurrency_limit < 0 or default_concurrency_limit > max_limit:
raise DagsterInvalidConfigError(
f"Found value `{default_concurrency_limit}` for `default_op_concurrency_limit`, "
f"Expected value between 0-{max_limit}.",
[],
None,
)
validate_concurrency_config(dagster_config_dict)

dagster_config = validate_config(schema, dagster_config_dict)
if not dagster_config.success:
Expand All @@ -156,6 +148,91 @@ def run_queue_config_schema() -> Field:
)


def validate_concurrency_config(dagster_config_dict: Mapping[str, Any]):
concurrency_config = dagster_config_dict["concurrency"]
default_concurrency_limit = concurrency_config.get("default_op_concurrency_limit")
if default_concurrency_limit is not None:
max_limit = get_max_concurrency_limit_value()
if default_concurrency_limit < 0 or default_concurrency_limit > max_limit:
raise DagsterInvalidConfigError(
f"Found value `{default_concurrency_limit}` for `default_op_concurrency_limit`, "
f"Expected value between 0-{max_limit}.",
[],
None,
)

if "run_queue" in dagster_config_dict:
verify_config_match(
dagster_config_dict,
["concurrency", "runs", "max_concurrent_runs"],
["run_queue", "max_concurrent_runs"],
)
verify_config_match(
dagster_config_dict,
["concurrency", "runs", "tag_concurrency_limits"],
["run_queue", "tag_concurrency_limits"],
)
verify_config_match(
dagster_config_dict,
["concurrency", "pools", "op_run_buffer"],
["run_queue", "block_op_concurrency_limited_runs", "op_concurrency_slot_buffer"],
)
if (
"run_coordinator" in dagster_config_dict
and pluck_config_value(dagster_config_dict, ["run_coordinator", "class"])
== "QueuedRunCoordinator"
):
verify_config_match(
dagster_config_dict,
["concurrency", "runs", "max_concurrent_runs"],
["run_coordinator", "config", "max_concurrent_runs"],
)
verify_config_match(
dagster_config_dict,
["concurrency", "runs", "tag_concurrency_limits"],
["run_coordinator", "config", "tag_concurrency_limits"],
)
verify_config_match(
dagster_config_dict,
["concurrency", "pools", "op_run_buffer"],
[
"run_coordinator",
"config",
"block_op_concurrency_limited_runs",
"op_concurrency_slot_buffer",
],
)


def pluck_config_value(config: Mapping[str, Any], path: Sequence[str]):
value = config
for part in path:
if not isinstance(value, dict):
return None

value = value.get(part)
if value is None:
return value

return value


def verify_config_match(config: Mapping[str, Any], path_a: Sequence[str], path_b: Sequence[str]):
value_a = pluck_config_value(config, path_a)
value_b = pluck_config_value(config, path_b)
if value_a is None or value_b is None:
return

if value_a != value_b:
path_a_str = " > ".join(path_a)
path_b_str = " > ".join(path_b)
raise DagsterInvalidConfigError(
f"Found `{value_a}` for `{path_a_str}` that conflicts with `{value_b}` for `{path_b_str}`.",
[],
None,
)


def storage_config_schema() -> Field:
return Field(
Selector(
Expand Down Expand Up @@ -344,6 +421,84 @@ def secrets_loader_config_schema() -> Field:
)


def get_concurrency_config() -> Field:
return Field(
{
"pools": Field(
{
"default_limit": Field(
int,
is_required=False,
description="The default maximum number of concurrent operations for an unconfigured pool",
),
"granularity": Field(
str,
is_required=False,
description="The granularity of the concurrency enforcement of the pool. One of `run` or `op`.",
default_value="run",
),
"op_run_buffer": Field(
int,
is_required=False,
description=(
"When the pool scope is set to `op`, this determines the number of runs "
"that can be launched with all of its steps blocked waiting for pool slots "
"to be freed."
),
),
}
),
"runs": Field(
{
"max_concurrent_runs": Field(
int,
is_required=False,
description=(
"The maximum number of runs that are allowed to be in progress at once."
" Defaults to 10. Set to -1 to disable the limit. Set to 0 to stop any runs"
" from launching. Any other negative values are disallowed."
),
),
"tag_concurrency_limits": Field(
config=Noneable(
Array(
Shape(
{
"key": String,
"value": Field(
ScalarUnion(
scalar_type=String,
non_scalar_schema=Shape(
{"applyLimitPerUniqueValue": Bool}
),
),
is_required=False,
),
"limit": Field(int),
}
)
)
),
is_required=False,
description=(
"A set of limits that are applied to runs with particular tags. If a value is"
" set, the limit is applied to only that key-value pair. If no value is set,"
" the limit is applied across all values of that key. If the value is set to a"
" dict with `applyLimitPerUniqueValue: true`, the limit will apply to the"
" number of unique values for that key."
),
),
}
),
"default_op_concurrency_limit": Field(
int,
is_required=False,
description="[Deprecated] The default maximum number of concurrent operations for an unconfigured concurrency key",
),
}
)


def dagster_instance_config_schema() -> Mapping[str, Field]:
return {
"local_artifact_storage": config_field_for_configurable_class(),
Expand Down Expand Up @@ -431,13 +586,5 @@ def dagster_instance_config_schema() -> Mapping[str, Field]:
),
}
),
"concurrency": Field(
{
"default_op_concurrency_limit": Field(
int,
is_required=False,
description="The default maximum number of concurrent operations for an unconfigured concurrency key",
),
}
),
"concurrency": get_concurrency_config(),
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ def __new__(
check.int_param(op_concurrency_slot_buffer, "op_concurrency_slot_buffer"),
)

def with_concurrency_settings(
self, concurrency_settings: Mapping[str, Any]
) -> "RunQueueConfig":
run_settings = concurrency_settings.get("runs", {})
pool_settings = concurrency_settings.get("pools", {})
return RunQueueConfig(
max_concurrent_runs=run_settings.get("max_concurrent_runs", self.max_concurrent_runs),
tag_concurrency_limits=run_settings.get(
"tag_concurrency_limits", self.tag_concurrency_limits
),
max_user_code_failure_retries=self.max_user_code_failure_retries,
user_code_failure_retry_delay=self.user_code_failure_retry_delay,
should_block_op_concurrency_limited_runs=self.should_block_op_concurrency_limited_runs,
op_concurrency_slot_buffer=pool_settings.get(
"op_run_buffer", self.op_concurrency_slot_buffer
),
)


class QueuedRunCoordinator(RunCoordinator[T_DagsterInstance], ConfigurableClass):
"""Enqueues runs via the run storage, to be deqeueued by the Dagster Daemon process. Requires
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
concurrency:
runs:
max_concurrent_runs: 5
run_coordinator:
module: dagster.core.run_coordinator
class: QueuedRunCoordinator
config:
max_concurrent_runs: 6
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
concurrency:
runs:
max_concurrent_runs: 5
run_queue:
max_concurrent_runs: 6
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
concurrency:
pools:
scope: op
op_run_buffer: 1
runs:
max_concurrent_runs: 5
tag_concurrency_limits:
- key: "dagster/solid_selection"
limit: 2
run_coordinator:
module: dagster.core.run_coordinator
class: QueuedRunCoordinator
config:
max_concurrent_runs: 5
tag_concurrency_limits:
- key: "dagster/solid_selection"
limit: 2
max_user_code_failure_retries: 3
user_code_failure_retry_delay: 10
block_op_concurrency_limited_runs:
enabled: true
op_concurrency_slot_buffer: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
concurrency:
pools:
scope: op
op_run_buffer: 1
runs:
max_concurrent_runs: 5
tag_concurrency_limits:
- key: "dagster/solid_selection"
limit: 2
run_queue:
max_concurrent_runs: 5
tag_concurrency_limits:
- key: "dagster/solid_selection"
limit: 2
max_user_code_failure_retries: 3
user_code_failure_retry_delay: 10
block_op_concurrency_limited_runs:
enabled: true
op_concurrency_slot_buffer: 1
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest
from dagster import file_relative_path
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.instance.config import dagster_instance_config
from dagster._core.test_utils import environ
from dagster._core.test_utils import environ, instance_for_test


@pytest.mark.parametrize("config_filename", ("dagster.yaml", "something.yaml"))
Expand All @@ -10,3 +11,44 @@ def test_instance_yaml_config_not_set(config_filename, caplog):
with environ({"DAGSTER_HOME": base_dir}):
dagster_instance_config(base_dir, config_filename)
assert "No dagster instance configuration file" in caplog.text


@pytest.mark.parametrize(
"config_filename",
(
"merged_run_coordinator_concurrency.yaml",
"merged_run_queue_concurrency.yaml",
),
)
def test_concurrency_config(config_filename, caplog):
base_dir = file_relative_path(__file__, "./test_config")
with environ({"DAGSTER_HOME": base_dir}):
instance_config, _ = dagster_instance_config(base_dir, config_filename)
with instance_for_test(overrides=instance_config) as instance:
run_queue_config = instance.get_run_queue_config()
assert run_queue_config
assert run_queue_config.max_concurrent_runs == 5
assert run_queue_config.tag_concurrency_limits == [
{
"key": "dagster/solid_selection",
"limit": 2,
}
]
assert run_queue_config.max_user_code_failure_retries == 3
assert run_queue_config.user_code_failure_retry_delay == 10
assert run_queue_config.should_block_op_concurrency_limited_runs
assert run_queue_config.op_concurrency_slot_buffer == 1


@pytest.mark.parametrize(
"config_filename",
(
"error_run_coordinator_concurrency_mismatch.yaml",
"error_run_queue_concurrency_mismatch.yaml",
),
)
def test_concurrency_config_mismatch(config_filename, caplog):
base_dir = file_relative_path(__file__, "./test_config")
with environ({"DAGSTER_HOME": base_dir}):
with pytest.raises(DagsterInvalidConfigError, match="for `concurrency > "):
dagster_instance_config(base_dir, config_filename)

0 comments on commit 8cfee9c

Please sign in to comment.