Skip to content

Commit

Permalink
Make AutomationCondition evaluate optionally async
Browse files Browse the repository at this point in the history
  • Loading branch information
briantu committed Oct 11, 2024
1 parent 26284dd commit ea234f1
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def _evaluate_entity_async(entity_key: EntityKey) -> int:
)

try:
self.evaluate_entity(entity_key)
await self.evaluate_entity(entity_key)
except Exception as e:
raise Exception(
f"Error while evaluating conditions for {entity_key.to_user_string()}"
Expand Down Expand Up @@ -161,10 +161,9 @@ async def _evaluate_entity_async(entity_key: EntityKey) -> int:

return list(self.current_results_by_key.values()), list(self._get_entity_subsets())

def evaluate_entity(self, key: EntityKey) -> None:
async def evaluate_entity(self, key: EntityKey) -> None:
# evaluate the condition of this asset
context = AutomationContext.create(key=key, evaluator=self)
result = context.condition.evaluate(context)
result = await AutomationContext.create(key=key, evaluator=self).evaluate_async()

# update dictionaries to keep track of this result
self.current_results_by_key[key] = result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Mapping, Optional, Type, TypeVar
Expand Down Expand Up @@ -116,6 +117,11 @@ def for_child_condition(
_root_log=self._root_log,
)

async def evaluate_async(self) -> AutomationResult[T_EntityKey]:
if inspect.iscoroutinefunction(self.condition.evaluate):
return await self.condition.evaluate(self)
return self.condition.evaluate(self)

@property
def log(self) -> logging.Logger:
"""The logger for the current condition evaluation."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
def requires_cursor(self) -> bool:
return False

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
child_results: List[AutomationResult] = []
true_subset = context.candidate_subset
for i, child in enumerate(self.children):
child_context = context.for_child_condition(
child_condition=child, child_index=i, candidate_subset=true_subset
)
child_result = child.evaluate(child_context)
child_result = await child_context.evaluate_async()
child_results.append(child_result)
true_subset = true_subset.compute_intersection(child_result.true_subset)
return AutomationResult(context, true_subset, child_results=child_results)
Expand Down Expand Up @@ -83,14 +85,16 @@ def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
def requires_cursor(self) -> bool:
return False

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
child_results: List[AutomationResult] = []
true_subset = context.get_empty_subset()
for i, child in enumerate(self.children):
child_context = context.for_child_condition(
child_condition=child, child_index=i, candidate_subset=context.candidate_subset
)
child_result = child.evaluate(child_context)
child_result = await child_context.evaluate_async()
child_results.append(child_result)
true_subset = true_subset.compute_union(child_result.true_subset)

Expand All @@ -116,11 +120,13 @@ def name(self) -> str:
def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
return [self.operand]

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
child_context = context.for_child_condition(
child_condition=self.operand, child_index=0, candidate_subset=context.candidate_subset
)
child_result = self.operand.evaluate(child_context)
child_result = await child_context.evaluate_async()
true_subset = context.candidate_subset.compute_difference(child_result.true_subset)

return AutomationResult(context, true_subset, child_results=[child_result])
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ def base_description(self) -> str:
def name(self) -> str:
return "ANY_CHECKS_MATCH"

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
check_results = []
true_subset = context.get_empty_subset()

for i, check_key in enumerate(
sorted(self._get_check_keys(context.key, context.asset_graph))
):
check_condition = EntityMatchesCondition(key=check_key, operand=self.operand)
check_result = check_condition.evaluate(
check_result = await check_condition.evaluate(
context.for_child_condition(
child_condition=check_condition,
child_index=i,
Expand All @@ -94,15 +94,15 @@ def base_description(self) -> str:
def name(self) -> str:
return "ALL_CHECKS_MATCH"

def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
async def evaluate(self, context: AutomationContext[AssetKey]) -> AutomationResult[AssetKey]:
check_results = []
true_subset = context.candidate_subset

for i, check_key in enumerate(
sorted(self._get_check_keys(context.key, context.asset_graph))
):
check_condition = EntityMatchesCondition(key=check_key, operand=self.operand)
check_result = check_condition.evaluate(
check_result = await check_condition.evaluate(
context.for_child_condition(
child_condition=check_condition,
child_index=i,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ class EntityMatchesCondition(
key: U_EntityKey
operand: AutomationCondition[U_EntityKey]

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
to_candidate_subset = context.candidate_subset.compute_mapped_subset(self.key)
to_context = context.for_child_condition(
child_condition=self.operand, child_index=0, candidate_subset=to_candidate_subset
)

to_result = self.operand.evaluate(to_context)
to_result = await to_context.evaluate_async()

true_subset = to_result.true_subset.compute_mapped_subset(context.key)
return AutomationResult(context=context, true_subset=true_subset, child_results=[to_result])
Expand Down Expand Up @@ -108,13 +110,15 @@ def base_description(self) -> str:
def name(self) -> str:
return "ANY_DEPS_MATCH"

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
dep_results = []
true_subset = context.get_empty_subset()

for i, dep_key in enumerate(sorted(self._get_dep_keys(context.key, context.asset_graph))):
dep_condition = EntityMatchesCondition(key=dep_key, operand=self.operand)
dep_result = dep_condition.evaluate(
dep_result = await dep_condition.evaluate(
context.for_child_condition(
child_condition=dep_condition,
child_index=i,
Expand All @@ -138,13 +142,15 @@ def base_description(self) -> str:
def name(self) -> str:
return "ALL_DEPS_MATCH"

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
dep_results = []
true_subset = context.candidate_subset

for i, dep_key in enumerate(sorted(self._get_dep_keys(context.key, context.asset_graph))):
dep_condition = EntityMatchesCondition(key=dep_key, operand=self.operand)
dep_result = dep_condition.evaluate(
dep_result = await dep_condition.evaluate(
context.for_child_condition(
child_condition=dep_condition,
child_index=i,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def _get_previous_child_true_subset(
return None
return context.asset_graph_view.get_subset_from_serializable_subset(true_subset)

def evaluate(self, context: AutomationContext) -> AutomationResult:
async def evaluate(self, context: AutomationContext) -> AutomationResult:
# evaluate child condition
child_context = context.for_child_condition(
self.operand,
child_index=0,
# must evaluate child condition over the entire subset to avoid missing state transitions
candidate_subset=context.asset_graph_view.get_full_subset(key=context.key),
)
child_result = self.operand.evaluate(child_context)
child_result = await child_context.evaluate_async()

# get the set of asset partitions of the child which newly became true
newly_true_child_subset = child_result.true_subset.compute_difference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,23 @@ def name(self) -> str:
def children(self) -> Sequence[AutomationCondition[T_EntityKey]]:
return [self.trigger_condition, self.reset_condition]

def evaluate(self, context: AutomationContext[T_EntityKey]) -> AutomationResult[T_EntityKey]:
async def evaluate(
self, context: AutomationContext[T_EntityKey]
) -> AutomationResult[T_EntityKey]:
# must evaluate child condition over the entire subset to avoid missing state transitions
child_candidate_subset = context.asset_graph_view.get_full_subset(key=context.key)

# compute result for trigger condition
trigger_context = context.for_child_condition(
self.trigger_condition, child_index=0, candidate_subset=child_candidate_subset
)
trigger_result = self.trigger_condition.evaluate(trigger_context)
trigger_result = await trigger_context.evaluate_async()

# compute result for reset condition
reset_context = context.for_child_condition(
self.reset_condition, child_index=1, candidate_subset=child_candidate_subset
)
reset_result = self.reset_condition.evaluate(reset_context)
reset_result = await reset_context.evaluate_async()

# take the previous subset that this was true for
true_subset = context.previous_true_subset or context.get_empty_subset()
Expand Down

0 comments on commit ea234f1

Please sign in to comment.