Skip to content

Commit

Permalink
partition api 4
Browse files Browse the repository at this point in the history
  • Loading branch information
jamiedemaria committed Jan 26, 2024
1 parent 04fd627 commit 5711594
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 131 deletions.
10 changes: 5 additions & 5 deletions examples/partition_example/partition_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def relativedelta(*args, **kwargs):
metadata={"partition_expr": "LastModifiedDate"},
)
def salesforce_customers(context: AssetExecutionContext) -> pd.DataFrame:
start_date_str = context.asset_partition_key_for_output()
start_date_str = context.partition_key

timezone = pytz.timezone("GMT") # Replace 'Your_Timezone' with the desired timezone
start_obj = datetime.datetime.strptime(start_date_str, "%Y-%m-%d").replace(tzinfo=timezone)
Expand Down Expand Up @@ -65,7 +65,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram
The volatility is calculated using various methods such as close-to-close, Parkinson, Hodges-Tompkins, and Yang-Zhang.
The function returns a DataFrame with the calculated volatilities.
"""
trade_date = context.asset_partition_key_for_output()
trade_date = context.partition_key
ticker_id = 1

df = all_realvols(orats_daily_prices, ticker_id, trade_date)
Expand All @@ -80,7 +80,7 @@ def realized_vol(context: AssetExecutionContext, orats_daily_prices: pd.DataFram

@asset(io_manager_def="parquet_io_manager", partitions_def=hourly_partitions)
def my_custom_df(context: AssetExecutionContext) -> pd.DataFrame:
start, end = context.asset_partitions_time_window_for_output()
start, end = context.partition_time_window

df = pd.DataFrame({"timestamp": pd.date_range(start, end, freq="5T")})
df["count"] = df["timestamp"].map(lambda a: random.randint(1, 1000))
Expand All @@ -93,7 +93,7 @@ def fetch_blog_posts_from_external_api(*args, **kwargs):

@asset(partitions_def=HourlyPartitionsDefinition(start_date="2022-01-01-00:00"))
def blog_posts(context: AssetExecutionContext) -> List[Dict]:
partition_datetime_str = context.asset_partition_key_for_output()
partition_datetime_str = context.partition_key
hour = datetime.datetime.fromisoformat(partition_datetime_str)
posts = fetch_blog_posts_from_external_api(hour_when_posted=hour)
return posts
Expand All @@ -106,7 +106,7 @@ def blog_posts(context: AssetExecutionContext) -> List[Dict]:
key_prefix=["snowflake", "eldermark_proxy"],
)
def resident(context: AssetExecutionContext) -> Output[pd.DataFrame]:
start, end = context.asset_partitions_time_window_for_output()
start, end = context.partition_time_window
filter_str = f"LastMod_Stamp >= {start.timestamp()} AND LastMod_Stamp < {end.timestamp()}"

records = context.resources.eldermark.fetch_obj(obj="Resident", filter=filter_str)
Expand Down
290 changes: 180 additions & 110 deletions python_modules/dagster/dagster/_core/execution/context/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,6 +1374,17 @@ def _copy_docs_from_op_execution_context(obj):
"dagster_run": "run",
"run_config": "run.run_config",
"run_tags": "run.tags",
"asset_partition_key_for_output": "partition_key",
"asset_partitions_time_window_for_output": "partition_time_window",
"asset_partition_key_range_for_output": "partition_key_range",
"asset_partition_key_range_for_input": "dep_context(asset_key).partition_key_range",
"asset_partition_key_for_input": "dep_context(asset_key).partition_key",
"asset_partitions_def_for_output": "assets_def.partitions_def",
"asset_partitions_def_for_input": "dep_context(asset_key).partitions_def",
"asset_partition_keys_for_output": "partition_keys",
"asset_partition_keys_for_input": "dep_context(asset_key).partition_keys",
"asset_partitions_time_window_for_input": "dep_context(asset_key).partition_time_window",
"has_partition_key": "is_partitioned_materialization",
}

ALTERNATE_EXPRESSIONS = {
Expand All @@ -1400,11 +1411,67 @@ def _get_deprecation_kwargs(attr: str):
return deprecation_kwargs


class DepContext:
def __init__(
self, asset_key: CoercibleToAssetKey, step_execution_context: StepExecutionContext
):
self._key = AssetKey.from_coercible(asset_key)
self._step_execution_context = step_execution_context

@property
def key(self) -> AssetKey:
return self._key

@property
def latest_materialization(self) -> Optional[AssetMaterialization]:
materialization_events = self._step_execution_context.upstream_asset_materialization_events
if self.key in materialization_events.keys():
return materialization_events.get(self.key)

raise DagsterInvariantViolationError(
f"Cannot fetch AssetMaterialization for asset {self.key}. {self.key} must be an upstream dependency"
"in order to call latest_materialization."
)

@property
def partition_key(self) -> str:
return self._step_execution_context.asset_partition_key_for_upstream_asset(self.key)

@property
def partition_keys(self) -> Sequence[str]:
return list(
self._step_execution_context.asset_partitions_subset_for_upstream_asset(
self.key
).get_partition_keys()
)

@property
def partition_time_window(self) -> TimeWindow:
return self._step_execution_context.asset_partitions_time_window_for_upstream_asset(
self.key
)

@property
def partition_key_range(self) -> PartitionKeyRange:
return self._step_execution_context.asset_partition_key_range_for_upstream_asset(self.key)

@property
def partitions_def(self) -> PartitionsDefinition:
result = self._step_execution_context.job_def.asset_layer.partitions_def_for_asset(self.key)
if result is None:
raise DagsterInvariantViolationError(
f"Attempting to access partitions def for asset {self.key}, but it is not partitioned"
)

return result


class AssetExecutionContext(OpExecutionContext):
def __init__(self, op_execution_context: OpExecutionContext) -> None:
self._op_execution_context = check.inst_param(
op_execution_context, "op_execution_context", OpExecutionContext
)
self._step_execution_context = op_execution_context._step_execution_context # noqa: SLF001

@staticmethod
def get() -> "AssetExecutionContext":
Expand Down Expand Up @@ -1469,30 +1536,55 @@ def job_def(self) -> JobDefinition:
return self.op_execution_context.job_def

@public
def latest_materialization_for_upstream_asset(
self, key: CoercibleToAssetKey
) -> Optional[AssetMaterialization]:
"""Get the most recent AssetMaterialization event for the key. The key must be an upstream
asset for the currently materializing asset. Information like metadata and tags can be found
on the AssetMaterialization. If the key is not an upstream asset of the currently
materializing asset, an error will be raised. If no AssetMaterialization exists for key, None
will be returned.
def dep_context(self, asset_key: CoercibleToAssetKey):
if (
self.job_def.asset_layer.input_for_asset_key(
self.op_execution_context.node_handle, AssetKey.from_coercible(asset_key)
)
is not None
):
return DepContext(
asset_key=asset_key, step_execution_context=self._step_execution_context
)
else:
raise DagsterInvariantViolationError(
f"Cannot access DepContext for asset {asset_key} since it is not an upstream dependency of {self.asset_key}."
)

Returns: Optional[AssetMaterialization]
"""
materialization_events = (
self.op_execution_context._step_execution_context.upstream_asset_materialization_events # noqa: SLF001
)
if AssetKey.from_coercible(key) in materialization_events.keys():
return materialization_events.get(AssetKey.from_coercible(key))
@public
@property
def is_partitioned_materialization(self) -> bool:
return self.op_execution_context.has_partition_key

raise DagsterInvariantViolationError(
f"Cannot fetch AssetMaterialization for asset {key}. {key} must be an upstream dependency"
"in order to call latest_materialization_for_upstream_asset."
)
@public
@property
def partition_key(self) -> str:
return self.op_execution_context.partition_key

@public
@property
def partition_keys(self) -> Sequence[str]:
return self.op_execution_context.partition_keys

@public
@property
def partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.partition_key_range

@public
@property
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

######## Deprecated methods

@deprecated(**_get_deprecation_kwargs("has_partition_key"))
@public
@property
@_copy_docs_from_op_execution_context
def has_partition_key(self) -> bool:
return self.op_execution_context.has_partition_key

@deprecated(**_get_deprecation_kwargs("dagster_run"))
@property
@_copy_docs_from_op_execution_context
Expand Down Expand Up @@ -1529,6 +1621,75 @@ def has_tag(self, key: str) -> bool:
def get_tag(self, key: str) -> Optional[str]:
return self.op_execution_context.get_tag(key)

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@public
@property
@_copy_docs_from_op_execution_context
def asset_partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_range_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_input(input_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_key_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_def_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_output"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@deprecated(**_get_deprecation_kwargs("asset_partition_keys_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)

@deprecated(**_get_deprecation_kwargs("asset_partitions_time_window_for_input"))
@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_input(input_name)

########## pass-through to op context

#### op related
Expand Down Expand Up @@ -1646,97 +1807,6 @@ def step_launcher(self) -> Optional[StepLauncher]:
def get_step_execution_context(self) -> StepExecutionContext:
return self.op_execution_context.get_step_execution_context()

#### partition_related

@public
@property
@_copy_docs_from_op_execution_context
def has_partition_key(self) -> bool:
return self.op_execution_context.has_partition_key

@public
@property
@_copy_docs_from_op_execution_context
def partition_key(self) -> str:
return self.op_execution_context.partition_key

@public
@property
@_copy_docs_from_op_execution_context
def partition_keys(self) -> Sequence[str]:
return self.op_execution_context.partition_keys

@deprecated(breaking_version="2.0", additional_warn_text="Use `partition_key_range` instead.")
@public
@property
@_copy_docs_from_op_execution_context
def asset_partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range

@public
@property
@_copy_docs_from_op_execution_context
def partition_key_range(self) -> PartitionKeyRange:
return self.op_execution_context.partition_key_range

@public
@property
@_copy_docs_from_op_execution_context
def partition_time_window(self) -> TimeWindow:
return self.op_execution_context.partition_time_window

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_output(self, output_name: str = "result") -> str:
return self.op_execution_context.asset_partition_key_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_output(self, output_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_output(
self, output_name: str = "result"
) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_output(output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_range_for_input(self, input_name: str) -> PartitionKeyRange:
return self.op_execution_context.asset_partition_key_range_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_key_for_input(self, input_name: str) -> str:
return self.op_execution_context.asset_partition_key_for_input(input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_output(self, output_name: str = "result") -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_def_for_input(self, input_name: str) -> PartitionsDefinition:
return self.op_execution_context.asset_partitions_def_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_output(self, output_name: str = "result") -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_output(output_name=output_name)

@public
@_copy_docs_from_op_execution_context
def asset_partition_keys_for_input(self, input_name: str) -> Sequence[str]:
return self.op_execution_context.asset_partition_keys_for_input(input_name=input_name)

@public
@_copy_docs_from_op_execution_context
def asset_partitions_time_window_for_input(self, input_name: str = "result") -> TimeWindow:
return self.op_execution_context.asset_partitions_time_window_for_input(input_name)

#### Event log related

@_copy_docs_from_op_execution_context
Expand Down
Loading

0 comments on commit 5711594

Please sign in to comment.