diff --git a/python_modules/dagster-ext/dagster_ext/__init__.py b/python_modules/dagster-ext/dagster_ext/__init__.py index 38366a039aeaf..e95ca7a14650f 100644 --- a/python_modules/dagster-ext/dagster_ext/__init__.py +++ b/python_modules/dagster-ext/dagster_ext/__init__.py @@ -39,6 +39,9 @@ # ##### PROTOCOL # ######################## +# This represents the version of the protocol, rather than the version of the package. It must be +# manually updated whenever there are changes to the protocol. +EXT_PROTOCOL_VERSION = "0.1" ExtExtras = Mapping[str, Any] ExtParams = Mapping[str, Any] @@ -62,8 +65,12 @@ def _param_name_to_env_key(key: str) -> str: # ##### MESSAGE +# Can't use a constant for TypedDict key so this value is repeated in `ExtMessage` defn. +EXT_PROTOCOL_VERSION_FIELD = "__dagster_ext_version" + class ExtMessage(TypedDict): + __dagster_ext_version: str method: str params: Optional[Mapping[str, Any]] @@ -677,7 +684,9 @@ def __init__( self._materialized_assets: set[str] = set() def _write_message(self, method: str, params: Optional[Mapping[str, Any]] = None) -> None: - message = ExtMessage(method=method, params=params) + message = ExtMessage( + EXT_PROTOCOL_VERSION_FIELD=EXT_PROTOCOL_VERSION, method=method, params=params + ) self._message_channel.write_message(message) # ######################## diff --git a/python_modules/dagster/dagster/_core/ext/utils.py b/python_modules/dagster/dagster/_core/ext/utils.py index 9ea033ee4ea73..b3a149e5b272c 100644 --- a/python_modules/dagster/dagster/_core/ext/utils.py +++ b/python_modules/dagster/dagster/_core/ext/utils.py @@ -10,6 +10,7 @@ from typing import Iterator, Optional from dagster_ext import ( + EXT_PROTOCOL_VERSION_FIELD, ExtContextData, ExtDefaultContextLoader, ExtDefaultMessageWriter, @@ -177,8 +178,7 @@ def extract_message_or_forward_to_stdout(handler: "ExtMessageHandler", log_line: # exceptions as control flow, you love to see it try: message = json.loads(log_line) - # need better message check - if message.keys() == {"method", "params"}: + if EXT_PROTOCOL_VERSION_FIELD in message.keys(): handler.handle_message(message) except Exception: # move non-message logs in to stdout for compute log capture