From 27a9dd2aad1824030b4b3774d4fe297f406ab4df Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 11:10:38 -0700 Subject: [PATCH 01/15] pushing the refactor code only --- .../public_api_missing_threshold.py | 1 + ...mn_descriptive_metrics_metric_retriever.py | 193 --------------- .../metric_repository/metric_retriever.py | 223 +++++++++++++++++- 3 files changed, 222 insertions(+), 195 deletions(-) diff --git a/docs/sphinx_api_docs_source/public_api_missing_threshold.py b/docs/sphinx_api_docs_source/public_api_missing_threshold.py index 9724b81514a1..ce7ebcd69469 100644 --- a/docs/sphinx_api_docs_source/public_api_missing_threshold.py +++ b/docs/sphinx_api_docs_source/public_api_missing_threshold.py @@ -80,6 +80,7 @@ "File: great_expectations/expectations/regex_based_column_map_expectation.py Name: register_metric", "File: great_expectations/expectations/set_based_column_map_expectation.py Name: register_metric", "File: great_expectations/expectations/set_based_column_map_expectation.py Name: validate_configuration", + "File: great_expectations/experimental/metric_repository/metric_retriever.py Name: get_validator", "File: great_expectations/experimental/datasource/fabric.py Name: build_batch_request", "File: great_expectations/experimental/datasource/fabric.py Name: get_batch_list_from_batch_request", "File: great_expectations/profile/base.py Name: validate", diff --git a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py index 9a983bec42f6..d4ca0040363e 100644 --- a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py @@ -4,20 +4,14 @@ from typing import TYPE_CHECKING, Any, List, Sequence from great_expectations.compatibility.typing_extensions import override -from great_expectations.core.domain import SemanticDomainTypes -from great_expectations.datasource.fluent.interfaces import Batch from great_expectations.experimental.metric_repository.metric_retriever import ( MetricRetriever, ) from great_expectations.experimental.metric_repository.metrics import ( ColumnMetric, Metric, - MetricException, TableMetric, ) -from great_expectations.rule_based_profiler.domain_builder import ColumnDomainBuilder -from great_expectations.validator.exception_info import ExceptionInfo -from great_expectations.validator.metric_configuration import MetricConfiguration if TYPE_CHECKING: from great_expectations.data_context import AbstractDataContext @@ -27,7 +21,6 @@ _MetricKey, _MetricsDict, ) - from great_expectations.validator.validator import Validator class ColumnDescriptiveMetricsMetricRetriever(MetricRetriever): @@ -35,12 +28,6 @@ class ColumnDescriptiveMetricsMetricRetriever(MetricRetriever): def __init__(self, context: AbstractDataContext): super().__init__(context=context) - self._validator: Validator | None = None - - def get_validator(self, batch_request: BatchRequest) -> Validator: - if self._validator is None: - self._validator = self._context.get_validator(batch_request=batch_request) - return self._validator @override def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: @@ -171,52 +158,6 @@ def _get_table_column_types( exception=exception, ) - def _get_columns_to_exclude(self, table_column_types: Metric) -> List[str]: - columns_to_skip: List[str] = [] - for column_type in 
table_column_types.value: - if not column_type.get("type"): - columns_to_skip.append(column_type["name"]) - return columns_to_skip - - def _get_column_metrics( - self, - batch_request: BatchRequest, - column_list: List[str], - column_metric_names: List[str], - column_metric_type: type[ColumnMetric[Any]], - ) -> Sequence[Metric]: - column_metric_configs = self._generate_column_metric_configurations( - column_list, column_metric_names - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, column_metric_configs - ) - - # Convert computed_metrics - metrics: list[Metric] = [] - metric_lookup_key: _MetricKey - - for metric_name in column_metric_names: - for column in column_list: - metric_lookup_key = (metric_name, f"column={column}", tuple()) - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - metric_lookup_key=metric_lookup_key, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - metrics.append( - column_metric_type( - batch_id=batch_id, - metric_name=metric_name, - column=column, - value=value, - exception=exception, - ) - ) - - return metrics - def _get_numeric_column_metrics( self, batch_request: BatchRequest, column_list: List[str] ) -> Sequence[Metric]: @@ -290,137 +231,3 @@ def _get_non_numeric_column_metrics( ) return metrics - - def _get_all_column_names(self, metrics: Sequence[Metric]) -> List[str]: - column_list: List[str] = [] - for metric in metrics: - if metric.metric_name == "table.columns": - column_list = metric.value - return column_list - - def _get_numeric_column_names( - self, - batch_request: BatchRequest, - exclude_column_names: List[str], - ) -> list[str]: - """Get the names of all numeric columns in the batch.""" - return self._get_column_names_for_semantic_types( - batch_request=batch_request, - include_semantic_types=[SemanticDomainTypes.NUMERIC], - exclude_column_names=exclude_column_names, - ) - - def _get_timestamp_column_names( - self, - batch_request: BatchRequest, - exclude_column_names: List[str], - ) -> list[str]: - """Get the names of all timestamp columns in the batch.""" - return self._get_column_names_for_semantic_types( - batch_request=batch_request, - include_semantic_types=[SemanticDomainTypes.DATETIME], - exclude_column_names=exclude_column_names, - ) - - def _get_column_names_for_semantic_types( - self, - batch_request: BatchRequest, - include_semantic_types: List[SemanticDomainTypes], - exclude_column_names: List[str], - ) -> list[str]: - """Get the names of all columns matching semantic types in the batch.""" - validator = self.get_validator(batch_request=batch_request) - domain_builder = ColumnDomainBuilder( - include_semantic_types=include_semantic_types, # type: ignore[arg-type] # ColumnDomainBuilder supports other ways of specifying semantic types - exclude_column_names=exclude_column_names, - ) - assert isinstance( - validator.active_batch, Batch - ), f"validator.active_batch is type {type(validator.active_batch).__name__} instead of type {Batch.__name__}" - batch_id = validator.active_batch.id - column_names = domain_builder.get_effective_column_names( - validator=validator, - batch_ids=[batch_id], - ) - return column_names - - def _generate_table_metric_configurations( - self, table_metric_names: list[str] - ) -> list[MetricConfiguration]: - table_metric_configs = [ - MetricConfiguration( - metric_name=metric_name, metric_domain_kwargs={}, metric_value_kwargs={} - ) - for metric_name in table_metric_names - ] - return table_metric_configs - - 
def _generate_column_metric_configurations( - self, column_list: list[str], column_metric_names: list[str] - ) -> list[MetricConfiguration]: - column_metric_configs: List[MetricConfiguration] = list() - for metric_name in column_metric_names: - for column in column_list: - column_metric_configs.append( - MetricConfiguration( - metric_name=metric_name, - metric_domain_kwargs={"column": column}, - metric_value_kwargs={}, - ) - ) - return column_metric_configs - - def _compute_metrics( - self, batch_request: BatchRequest, metric_configs: list[MetricConfiguration] - ) -> tuple[str, _MetricsDict, _AbortedMetricsInfoDict]: - validator = self.get_validator(batch_request=batch_request) - # The runtime configuration catch_exceptions is explicitly set to True to catch exceptions - # that are thrown when computing metrics. This is so we can capture the error for later - # surfacing, and not have the entire metric run fail so that other metrics will still be - # computed. - ( - computed_metrics, - aborted_metrics, - ) = validator.compute_metrics( - metric_configurations=metric_configs, - runtime_configuration={"catch_exceptions": True}, - ) - assert isinstance( - validator.active_batch, Batch - ), f"validator.active_batch is type {type(validator.active_batch).__name__} instead of type {Batch.__name__}" - batch_id = validator.active_batch.id - return batch_id, computed_metrics, aborted_metrics - - def _get_metric_from_computed_metrics( - self, - metric_name: str, - computed_metrics: _MetricsDict, - aborted_metrics: _AbortedMetricsInfoDict, - metric_lookup_key: _MetricKey | None = None, - ) -> tuple[Any, MetricException | None]: - if metric_lookup_key is None: - metric_lookup_key = ( - metric_name, - tuple(), - tuple(), - ) - value = None - metric_exception = None - if metric_lookup_key in computed_metrics: - value = computed_metrics[metric_lookup_key] - elif metric_lookup_key in aborted_metrics: - exception = aborted_metrics[metric_lookup_key] - exception_info = exception["exception_info"] - exception_type = "Unknown" # Note: we currently only capture the message and traceback, not the type - if isinstance(exception_info, ExceptionInfo): - exception_message = exception_info.exception_message - metric_exception = MetricException( - type=exception_type, message=exception_message - ) - else: - metric_exception = MetricException( - type="Not found", - message="Metric was not successfully computed but exception was not found.", - ) - - return value, metric_exception diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index c5f431bd9cc9..c202f0723b49 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -2,19 +2,59 @@ import abc import uuid -from typing import TYPE_CHECKING, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Hashable, + List, + Sequence, + Tuple, + Union, +) + +from great_expectations.core.domain import SemanticDomainTypes +from great_expectations.datasource.fluent.interfaces import Batch +from great_expectations.experimental.metric_repository.metrics import ( + ColumnMetric, + MetricException, + MetricTypes, +) +from great_expectations.rule_based_profiler.domain_builder import ColumnDomainBuilder +from great_expectations.validator.computed_metric import MetricValue +from great_expectations.validator.exception_info import ExceptionInfo +from 
great_expectations.validator.metric_configuration import MetricConfiguration

 if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
     from great_expectations.data_context import AbstractDataContext
     from great_expectations.datasource.fluent import BatchRequest
     from great_expectations.experimental.metric_repository.metrics import Metric
+    from great_expectations.validator.validator import Validator
+
+
+_MetricKey: TypeAlias = Union[Tuple[str, Hashable, Hashable], Tuple[str, str, str]]
+_MetricsDict: TypeAlias = Dict[_MetricKey, MetricValue]
+_AbortedMetricsInfoDict: TypeAlias = Dict[
+    _MetricKey,
+    Dict[str, Union[MetricConfiguration, ExceptionInfo, int]],
+]


 class MetricRetriever(abc.ABC):
-    """A MetricRetriever is responsible for retrieving metrics for a batch of data."""
+    """A MetricRetriever is responsible for retrieving metrics for a batch of data. It is an ABC that contains base logic and
+    methods shared by both the ColumnDescriptiveMetricsMetricRetriever and the MetricListMetricRetriever.
+    """

     def __init__(self, context: AbstractDataContext):
         self._context = context
+        self._validator: Validator | None = None
+
+    def get_validator(self, batch_request: BatchRequest) -> Validator:
+        if self._validator is None:
+            self._validator = self._context.get_validator(batch_request=batch_request)
+        return self._validator

     @abc.abstractmethod
     def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]:
@@ -22,3 +62,182 @@ def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]:

     def _generate_metric_id(self) -> uuid.UUID:
         return uuid.uuid4()
+
+    def _get_metric_from_computed_metrics(
+        self,
+        metric_name: str,
+        computed_metrics: _MetricsDict,
+        aborted_metrics: _AbortedMetricsInfoDict,
+        metric_lookup_key: _MetricKey | None = None,
+    ) -> tuple[Any, MetricException | None]:
+        if metric_lookup_key is None:
+            metric_lookup_key = (
+                metric_name,
+                tuple(),
+                tuple(),
+            )
+        value = None
+        metric_exception = None
+        if metric_lookup_key in computed_metrics:
+            value = computed_metrics[metric_lookup_key]
+        elif metric_lookup_key in aborted_metrics:
+            exception = aborted_metrics[metric_lookup_key]
+            exception_info = exception["exception_info"]
+            exception_type = "Unknown"  # Note: we currently only capture the message and traceback, not the type
+            if isinstance(exception_info, ExceptionInfo):
+                exception_message = exception_info.exception_message
+                metric_exception = MetricException(
+                    type=exception_type, message=exception_message
+                )
+        else:
+            metric_exception = MetricException(
+                type="Not found",
+                message="Metric was not successfully computed but exception was not found.",
+            )
+
+        return value, metric_exception
+
+    def _generate_table_metric_configurations(
+        self, table_metric_names: list[str]
+    ) -> list[MetricConfiguration]:
+        table_metric_configs = [
+            MetricConfiguration(
+                metric_name=metric_name, metric_domain_kwargs={}, metric_value_kwargs={}
+            )
+            for metric_name in table_metric_names
+        ]
+        return table_metric_configs
+
+    def _compute_metrics(
+        self, batch_request: BatchRequest, metric_configs: list[MetricConfiguration]
+    ) -> tuple[str, _MetricsDict, _AbortedMetricsInfoDict]:
+        validator = self.get_validator(batch_request=batch_request)
+        # The runtime configuration catch_exceptions is explicitly set to True to catch exceptions
+        # that are thrown when computing metrics. This is so we can capture the error for later
+        # surfacing, and not have the entire metric run fail so that other metrics will still be
+        # computed.
+ ( + computed_metrics, + aborted_metrics, + ) = validator.compute_metrics( + metric_configurations=metric_configs, + runtime_configuration={"catch_exceptions": True}, + ) + assert isinstance( + validator.active_batch, Batch + ), f"validator.active_batch is type {type(validator.active_batch).__name__} instead of type {Batch.__name__}" + batch_id = validator.active_batch.id + return batch_id, computed_metrics, aborted_metrics + + def _get_columns_to_exclude(self, table_column_types: Metric) -> List[str]: + columns_to_skip: List[str] = [] + for column_type in table_column_types.value: + if not column_type.get("type"): + columns_to_skip.append(column_type["name"]) + return columns_to_skip + + def _get_numeric_column_names( + self, + batch_request: BatchRequest, + exclude_column_names: List[str], + ) -> list[str]: + """Get the names of all numeric columns in the batch.""" + return self._get_column_names_for_semantic_types( + batch_request=batch_request, + include_semantic_types=[SemanticDomainTypes.NUMERIC], + exclude_column_names=exclude_column_names, + ) + + def _get_timestamp_column_names( + self, + batch_request: BatchRequest, + exclude_column_names: List[str], + ) -> list[str]: + """Get the names of all timestamp columns in the batch.""" + return self._get_column_names_for_semantic_types( + batch_request=batch_request, + include_semantic_types=[SemanticDomainTypes.DATETIME], + exclude_column_names=exclude_column_names, + ) + + def _get_column_names_for_semantic_types( + self, + batch_request: BatchRequest, + include_semantic_types: List[SemanticDomainTypes], + exclude_column_names: List[str], + ) -> list[str]: + """Get the names of all columns matching semantic types in the batch.""" + validator = self.get_validator(batch_request=batch_request) + domain_builder = ColumnDomainBuilder( + include_semantic_types=include_semantic_types, # type: ignore[arg-type] # ColumnDomainBuilder supports other ways of specifying semantic types + exclude_column_names=exclude_column_names, + ) + assert isinstance( + validator.active_batch, Batch + ), f"validator.active_batch is type {type(validator.active_batch).__name__} instead of type {Batch.__name__}" + batch_id = validator.active_batch.id + column_names = domain_builder.get_effective_column_names( + validator=validator, + batch_ids=[batch_id], + ) + return column_names + + def _get_column_metrics( + self, + batch_request: BatchRequest, + column_list: List[str], + column_metric_names: List[MetricTypes | str], + column_metric_type: type[ColumnMetric[Any]], + ) -> Sequence[Metric]: + column_metric_configs = self._generate_column_metric_configurations( + column_list, column_metric_names + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, column_metric_configs + ) + + # Convert computed_metrics + metrics: list[Metric] = [] + metric_lookup_key: _MetricKey + + for metric_name in column_metric_names: + for column in column_list: + metric_lookup_key = (metric_name, f"column={column}", tuple()) + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + metric_lookup_key=metric_lookup_key, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + metrics.append( + column_metric_type( + batch_id=batch_id, + metric_name=metric_name, + column=column, + value=value, + exception=exception, + ) + ) + return metrics + + def _generate_column_metric_configurations( + self, column_list: list[str], column_metric_names: list[str | MetricTypes] + ) -> list[MetricConfiguration]: + 
column_metric_configs: List[MetricConfiguration] = list() + for metric_name in column_metric_names: + for column in column_list: + column_metric_configs.append( + MetricConfiguration( + metric_name=metric_name, + metric_domain_kwargs={"column": column}, + metric_value_kwargs={}, + ) + ) + return column_metric_configs + + def _get_all_column_names(self, metrics: Sequence[Metric]) -> List[str]: + column_list: List[str] = [] + for metric in metrics: + if metric.metric_name == MetricTypes.TABLE_COLUMNS: + column_list = metric.value + return column_list From 8a0ce38b31a0b127ed79a269679e67bc8a4b3d15 Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 13:13:09 -0700 Subject: [PATCH 02/15] clean up --- .../metric_repository/metric_retriever.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index c202f0723b49..fba838f6c90a 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -5,12 +5,8 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Hashable, List, Sequence, - Tuple, - Union, ) from great_expectations.core.domain import SemanticDomainTypes @@ -21,25 +17,19 @@ MetricTypes, ) from great_expectations.rule_based_profiler.domain_builder import ColumnDomainBuilder -from great_expectations.validator.computed_metric import MetricValue from great_expectations.validator.exception_info import ExceptionInfo from great_expectations.validator.metric_configuration import MetricConfiguration if TYPE_CHECKING: - from typing_extensions import TypeAlias - from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent import BatchRequest from great_expectations.experimental.metric_repository.metrics import Metric - from great_expectations.validator.validator import Validator - - -_MetricKey: TypeAlias = Union[Tuple[str, Hashable, Hashable], Tuple[str, str, str]] -_MetricsDict: TypeAlias = Dict[_MetricKey, MetricValue] -_AbortedMetricsInfoDict: TypeAlias = Dict[ - _MetricKey, - Dict[str, Union[MetricConfiguration, ExceptionInfo, int]], -] + from great_expectations.validator.validator import ( + Validator, + _AbortedMetricsInfoDict, + _MetricKey, + _MetricsDict, + ) class MetricRetriever(abc.ABC): From 51a119b9234b8ee24d58454d64c8091f369ea4c4 Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 13:29:30 -0700 Subject: [PATCH 03/15] send this up first --- .../metric_list_metric_retriever.py | 316 +++++++++ .../test_metric_list_metric_retriever.py | 640 ++++++++++++++++++ ...etric_list_metric_retriever_integration.py | 286 ++++++++ 3 files changed, 1242 insertions(+) create mode 100644 great_expectations/experimental/metric_repository/metric_list_metric_retriever.py create mode 100644 tests/experimental/metric_repository/test_metric_list_metric_retriever.py create mode 100644 tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py new file mode 100644 index 000000000000..404eec0dc40a --- /dev/null +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +from itertools 
import chain +from typing import TYPE_CHECKING, Any, List, Optional, Sequence + +from great_expectations.compatibility.typing_extensions import override +from great_expectations.experimental.metric_repository.metric_retriever import ( + MetricRetriever, +) +from great_expectations.experimental.metric_repository.metrics import ( + ColumnMetric, + Metric, + MetricTypes, + TableMetric, +) + +if TYPE_CHECKING: + from great_expectations.data_context import AbstractDataContext + from great_expectations.datasource.fluent.batch_request import BatchRequest + from great_expectations.validator.validator import ( + Validator, + _MetricKey, + ) + + +class MetricListMetricRetriever(MetricRetriever): + def __init__(self, context: AbstractDataContext): + super().__init__(context=context) + self._validator: Validator | None = None + + @override + def get_metrics( + self, + batch_request: BatchRequest, + metric_list: Optional[List[MetricTypes]] = None, + ) -> Sequence[Metric]: + metrics_result: List[Metric] = [] + + if not metric_list: + raise ValueError("metric_list cannot be empty") + + self._check_valid_metric_types(metric_list) + + table_metrics = self._get_table_metrics( + batch_request=batch_request, metric_list=metric_list + ) + metrics_result.extend(table_metrics) + + # exit early if only Table Metrics exist + if not self._column_metrics_in_metric_list(metric_list): + return metrics_result + + table_column_types = list( + filter( + lambda m: m.metric_name == MetricTypes.TABLE_COLUMN_TYPES, table_metrics + ) + )[0] + + # We need to skip columns that do not report a type, because the metric computation + # to determine semantic type will fail. + exclude_column_names = self._get_columns_to_exclude(table_column_types) + + numeric_column_names = self._get_numeric_column_names( + batch_request=batch_request, exclude_column_names=exclude_column_names + ) + timestamp_column_names = self._get_timestamp_column_names( + batch_request=batch_request, exclude_column_names=exclude_column_names + ) + numeric_column_metrics = self._get_numeric_column_metrics( + metric_list, batch_request, numeric_column_names + ) + timestamp_column_metrics = self._get_timestamp_column_metrics( + metric_list, batch_request, timestamp_column_names + ) + all_column_names: List[str] = self._get_all_column_names(table_metrics) + non_numeric_column_metrics = self._get_non_numeric_column_metrics( + metric_list, batch_request, all_column_names + ) + + bundled_list = list( + chain( + table_metrics, + numeric_column_metrics, + timestamp_column_metrics, + non_numeric_column_metrics, + ) + ) + + return bundled_list + + def _get_non_numeric_column_metrics( + self, + metrics_list: List[MetricTypes], + batch_request: BatchRequest, + column_list: List[str], + ) -> Sequence[Metric]: + column_metric_names = {MetricTypes.COLUMN_NULL_COUNT} + metrics: list[Metric] = [] + metrics_list_as_set = set(metrics_list) + metrics_to_calculate = sorted( + column_metric_names.intersection(metrics_list_as_set) + ) + + if not metrics_to_calculate: + return metrics + + column_metric_configs = self._generate_column_metric_configurations( + column_list, list(metrics_to_calculate) + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, column_metric_configs + ) + + # Convert computed_metrics + ColumnMetric.update_forward_refs() + metric_lookup_key: _MetricKey + + for metric_name in metrics_to_calculate: + for column in column_list: + metric_lookup_key = (metric_name, f"column={column}", tuple()) + value, exception = 
self._get_metric_from_computed_metrics( + metric_name=metric_name, + metric_lookup_key=metric_lookup_key, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + metrics.append( + ColumnMetric[int]( + batch_id=batch_id, + metric_name=metric_name, + column=column, + value=value, + exception=exception, + ) + ) + + return metrics + + def _get_numeric_column_metrics( + self, + metrics_list: List[MetricTypes], + batch_request: BatchRequest, + column_list: List[str], + ) -> Sequence[Metric]: + metrics: list[Metric] = [] + column_metric_names = { + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + } + metrics_list_as_set = set(metrics_list) + metrics_to_calculate = sorted( + column_metric_names.intersection(metrics_list_as_set) + ) + if not metrics_to_calculate: + return metrics + + return self._get_column_metrics( + batch_request=batch_request, + column_list=column_list, + column_metric_names=list(metrics_to_calculate), + column_metric_type=ColumnMetric[float], + ) + + def _get_timestamp_column_metrics( + self, + metrics_list: List[MetricTypes], + batch_request: BatchRequest, + column_list: List[str], + ) -> Sequence[Metric]: + metrics: list[Metric] = [] + column_metric_names = { + "column.min", + "column.max", + # "column.mean", # Currently not supported for timestamp in Snowflake + # "column.median", # Currently not supported for timestamp in Snowflake + } + metrics_list_as_set = set(metrics_list) + metrics_to_calculate = sorted( + column_metric_names.intersection(metrics_list_as_set) + ) + if not metrics_to_calculate: + return metrics + + # Note: Timestamps are returned as strings for Snowflake, this may need to be adjusted + # when we support other datasources. For example in Pandas, timestamps can be returned as Timestamp(). 
+ return self._get_column_metrics( + batch_request=batch_request, + column_list=column_list, + column_metric_names=list(metrics_to_calculate), + column_metric_type=ColumnMetric[str], + ) + + def _get_table_metrics( + self, batch_request: BatchRequest, metric_list: List[MetricTypes] + ) -> List[Metric]: + metrics: List[Metric] = [] + if MetricTypes.TABLE_ROW_COUNT in metric_list: + metrics.append(self._get_table_row_count(batch_request=batch_request)) + if MetricTypes.TABLE_COLUMNS in metric_list: + metrics.append(self._get_table_columns(batch_request=batch_request)) + if MetricTypes.TABLE_COLUMN_TYPES in metric_list: + metrics.append(self._get_table_column_types(batch_request=batch_request)) + return metrics + + def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: + table_metric_configs = self._generate_table_metric_configurations( + table_metric_names=[MetricTypes.TABLE_ROW_COUNT] + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, table_metric_configs + ) + metric_name = MetricTypes.TABLE_ROW_COUNT + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + return TableMetric[int]( + batch_id=batch_id, + metric_name=metric_name, + value=value, + exception=exception, + ) + + def _get_table_columns(self, batch_request: BatchRequest) -> Metric: + table_metric_configs = self._generate_table_metric_configurations( + table_metric_names=[MetricTypes.TABLE_COLUMNS] + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, table_metric_configs + ) + metric_name = MetricTypes.TABLE_COLUMNS + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + return TableMetric[List[str]]( + batch_id=batch_id, + metric_name=metric_name, + value=value, + exception=exception, + ) + + def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: + table_metric_configs = self._generate_table_metric_configurations( + table_metric_names=[MetricTypes.TABLE_COLUMN_TYPES] + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, table_metric_configs + ) + metric_name = MetricTypes.TABLE_COLUMN_TYPES + metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + metric_lookup_key=metric_lookup_key, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + raw_column_types: list[dict[str, Any]] = value + # If type is not found, don't add empty type field. This can happen if our db introspection fails. + column_types_converted_to_str: list[dict[str, str]] = [] + for raw_column_type in raw_column_types: + if raw_column_type.get("type"): + column_types_converted_to_str.append( + { + "name": raw_column_type["name"], + "type": str(raw_column_type["type"]), + } + ) + else: + column_types_converted_to_str.append({"name": raw_column_type["name"]}) + + return TableMetric[List[str]]( + batch_id=batch_id, + metric_name=metric_name, + value=column_types_converted_to_str, + exception=exception, + ) + + def _check_valid_metric_types(self, metric_list: List[MetricTypes]) -> bool: + """Check whether all the metric types in the list are valid. + + Args: + metric_list (List[MetricTypes]): list of MetricTypes that are passed in to MetricListMetricRetriever. 
+ + Returns: + bool: True if all the metric types in the list are valid, False otherwise. + """ + for metric in metric_list: + if metric not in MetricTypes: + return False + return True + + def _column_metrics_in_metric_list(self, metric_list: List[MetricTypes]) -> bool: + """Helper method to check whether any column metrics are present in the metric list. + + Args: + metric_list (List[MetricTypes]): list of MetricTypes that are passed in to MetricListMetricRetriever. + + Returns: + bool: True if any column metrics are present in the metric list, False otherwise. + """ + column_metrics: List[MetricTypes] = [ + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_NULL_COUNT, + ] + for metric in column_metrics: + if metric in metric_list: + return True + return False diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py new file mode 100644 index 000000000000..e3cf0902fd26 --- /dev/null +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py @@ -0,0 +1,640 @@ +from typing import List +from unittest import mock +from unittest.mock import Mock + +import pytest + +from great_expectations.data_context import CloudDataContext +from great_expectations.datasource.fluent import BatchRequest +from great_expectations.datasource.fluent.interfaces import Batch +from great_expectations.experimental.metric_repository.metric_list_metric_retriever import ( + MetricListMetricRetriever, +) +from great_expectations.experimental.metric_repository.metrics import ( + ColumnMetric, + MetricException, + MetricTypes, + TableMetric, +) +from great_expectations.rule_based_profiler.domain_builder import ColumnDomainBuilder +from great_expectations.validator.exception_info import ExceptionInfo +from great_expectations.validator.validator import Validator + +pytestmark = pytest.mark.unit + + +def test_get_metrics_table_metrics_only(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + computed_metrics = { + ("table.row_count", (), ()): 2, + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + } + table_metrics_list = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + ] + aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, + metric_list=table_metrics_list, + ) + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=2, + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + exception=None, + ), + ] + + +def test_get_metrics_full_list(): + mock_context = Mock(spec=CloudDataContext) 
+ mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + computed_metrics = { + ("table.row_count", (), ()): 2, + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + ("column.min", "column=col1", ()): 2.5, + ("column.min", "column=col2", ()): 2.7, + ("column.max", "column=col1", ()): 5.5, + ("column.max", "column=col2", ()): 5.7, + ("column.mean", "column=col1", ()): 2.5, + ("column.mean", "column=col2", ()): 2.7, + ("column.median", "column=col1", ()): 2.5, + ("column.median", "column=col2", ()): 2.7, + ("column_values.null.count", "column=col1", ()): 1, + ("column_values.null.count", "column=col2", ()): 1, + } + cdm_metrics_list = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_NULL_COUNT, + ] + aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", + return_value=["col1", "col2"], + ), mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", + return_value=[], + ): + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, + metric_list=cdm_metrics_list, + ) + + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", metric_name="table.row_count", value=2, exception=None + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + exception=None, + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.max", + value=5.5, + exception=None, + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.max", + value=5.7, + exception=None, + column="col2", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.mean", + value=2.5, + exception=None, + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.mean", + value=2.7, + exception=None, + column="col2", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.median", + value=2.5, + exception=None, + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.median", + value=2.7, + exception=None, + column="col2", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=2.5, + exception=None, + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=2.7, + exception=None, + column="col2", + ), + ColumnMetric[int]( + batch_id="batch_id", + metric_name="column_values.null.count", + value=1, + exception=None, + column="col1", + ), + ColumnMetric[int]( + batch_id="batch_id", + metric_name="column_values.null.count", + value=1, + exception=None, + 
column="col2", + ), + ] + + +def test_get_metrics_metrics_missing(): + """This test is meant to simulate metrics missing from the computed metrics.""" + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + mock_computed_metrics = { + # ("table.row_count", (), ()): 2, # Missing table.row_count metric + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + # ("column.min", "column=col1", ()): 2.5, # Missing column.min metric for col1 + ("column.min", "column=col2", ()): 2.7, + } + + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + ] + mock_aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + mock_computed_metrics, + mock_aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", + return_value=["col1", "col2"], + ), mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", + return_value=[], + ): + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=None, + exception=MetricException( + type="Not found", + message="Metric was not successfully computed but exception was not found.", + ), + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + exception=None, + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=None, + exception=MetricException( + type="Not found", + message="Metric was not successfully computed but exception was not found.", + ), + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=2.7, + exception=None, + column="col2", + ), + ] + + +def test_get_metrics_with_exception(): + """This test is meant to simulate failed metrics in the computed metrics.""" + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + + exception_info = ExceptionInfo( + exception_traceback="test exception traceback", + exception_message="test exception message", + raised_exception=True, + ) + aborted_metrics = { + ("table.row_count", (), ()): { + "metric_configuration": {}, # Leaving out for brevity + "num_failures": 3, + "exception_info": exception_info, + }, + ("column.min", "column=col1", ()): { + "metric_configuration": {}, # Leaving out for brevity + "num_failures": 3, + "exception_info": exception_info, + }, + } + computed_metrics = { + # ("table.row_count", (), ()): 2, # Error in table.row_count metric + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ 
+ {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + } + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + ] + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=None, + exception=MetricException(type="Unknown", message="test exception message"), + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + exception=None, + ), + ] + + +def test_get_metrics_with_column_type_missing(): + """This test is meant to simulate failed metrics in the computed metrics.""" + """This test is meant to simulate failed metrics in the computed metrics.""" + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + + exception_info = ExceptionInfo( + exception_traceback="test exception traceback", + exception_message="test exception message", + raised_exception=True, + ) + + aborted_metrics = { + ("table.row_count", (), ()): { + "metric_configuration": {}, # Leaving out for brevity + "num_failures": 3, + "exception_info": exception_info, + }, + ("column.min", "column=col1", ()): { + "metric_configuration": {}, # Leaving out for brevity + "num_failures": 3, + "exception_info": exception_info, + }, + } + + computed_metrics = { + # ("table.row_count", (), ()): 2, # Error in table.row_count metric + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + { + "name": "col2", + }, # Missing type for col2 + ], + # ("column.min", "column=col1", ()): 2.5, # Error in column.min metric for col1 + ("column.min", "column=col2", ()): 2.7, + } + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + ] + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", + return_value=["col1", "col2"], + ), mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", + return_value=[], + ): + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + # why is this not sorted? 
+ assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=None, + exception=MetricException(type="Unknown", message="test exception message"), + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + { + "name": "col2", + }, # Note: No type for col2 + ], + exception=None, + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + column="col1", + value=None, + exception=MetricException(type="Unknown", message="test exception message"), + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + column="col2", + value=2.7, + exception=None, + ), + ] + + +def test_get_metrics_with_timestamp_columns(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + computed_metrics = { + ("table.row_count", (), ()): 2, + ("table.columns", (), ()): ["timestamp_col"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "timestamp_col", "type": "TIMESTAMP_NTZ"}, + ], + ("column.min", "column=timestamp_col", ()): "2023-01-01T00:00:00", + ("column.max", "column=timestamp_col", ()): "2023-12-31T00:00:00", + ("column_values.null.count", "column=timestamp_col", ()): 1, + } + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_NULL_COUNT, + ] + aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", + return_value=[], + ), mock.patch( + f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", + return_value=["timestamp_col"], + ): + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=2, + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["timestamp_col"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[{"name": "timestamp_col", "type": "TIMESTAMP_NTZ"}], + exception=None, + ), + ColumnMetric[str]( + batch_id="batch_id", + metric_name="column.max", + value="2023-12-31T00:00:00", + exception=None, + column="timestamp_col", + ), + ColumnMetric[str]( + batch_id="batch_id", + metric_name="column.min", + value="2023-01-01T00:00:00", + exception=None, + column="timestamp_col", + ), + ColumnMetric[int]( + batch_id="batch_id", + metric_name="column_values.null.count", + value=1, + exception=None, + column="timestamp_col", + ), + ] + + +def test_get_metrics_only_gets_a_validator_once(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator 
+ + aborted_metrics = {} + + computed_metrics = { + ("table.row_count", (), ()): 2, + ("table.columns", (), ()): ["col1", "col2"], + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + } + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + ] + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with mock.patch( + f"{ColumnDomainBuilder.__module__}.{ColumnDomainBuilder.__name__}.get_effective_column_names", + return_value=["col1", "col2"], + ): + metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + + mock_context.get_validator.assert_called_once_with(batch_request=mock_batch_request) diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py new file mode 100644 index 000000000000..3726f32422df --- /dev/null +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py @@ -0,0 +1,286 @@ +"""Test using actual sample data.""" +from __future__ import annotations + +from typing import List + +import pandas as pd +import pytest +from pandas import Timestamp + +from great_expectations.data_context import CloudDataContext +from great_expectations.datasource.fluent.batch_request import BatchRequest +from great_expectations.experimental.metric_repository.metric_list_metric_retriever import ( + MetricListMetricRetriever, +) +from great_expectations.experimental.metric_repository.metrics import ( + ColumnMetric, + MetricTypes, + TableMetric, +) + + +@pytest.fixture +def cloud_context_and_batch_request_with_simple_dataframe( + empty_cloud_context_fluent: CloudDataContext, # used as a fixture +): + context = empty_cloud_context_fluent + datasource = context.sources.add_pandas(name="my_pandas_datasource") + + d = { + "numeric_with_nulls_1": [1, 2, None], + "numeric_with_nulls_2": [3, 4, None], + "string": ["a", "b", "c"], + "string_with_nulls": ["a", "b", None], + "boolean": [True, False, True], + "datetime": [ + pd.to_datetime("2020-01-01"), + pd.to_datetime("2020-01-02"), + pd.to_datetime("2020-01-03"), + ], + } + df = pd.DataFrame(data=d) + + name = "dataframe" + data_asset = datasource.add_dataframe_asset(name=name) + batch_request = data_asset.build_batch_request(dataframe=df) + return context, batch_request + + +@pytest.mark.cloud +def test_get_metrics_table_metrics_only( + cloud_context_and_batch_request_with_simple_dataframe: tuple[ + CloudDataContext, BatchRequest + ], +): + context, batch_request = cloud_context_and_batch_request_with_simple_dataframe + table_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + ] + metric_retriever = MetricListMetricRetriever(context) + metrics = metric_retriever.get_metrics( + batch_request=batch_request, metric_list=table_metrics_list + ) + validator = context.get_validator(batch_request=batch_request) + batch_id = validator.active_batch.id + + expected_metrics = [ + TableMetric[int]( + batch_id=batch_id, + metric_name="table.row_count", + value=3, + 
exception=None, + ), + TableMetric[List[str]]( + batch_id=batch_id, + metric_name="table.columns", + value=[ + "numeric_with_nulls_1", + "numeric_with_nulls_2", + "string", + "string_with_nulls", + "boolean", + "datetime", + ], + exception=None, + ), + TableMetric[List[str]]( + batch_id=batch_id, + metric_name="table.column_types", + value=[ + {"name": "numeric_with_nulls_1", "type": "float64"}, + {"name": "numeric_with_nulls_2", "type": "float64"}, + {"name": "string", "type": "object"}, + {"name": "string_with_nulls", "type": "object"}, + {"name": "boolean", "type": "bool"}, + {"name": "datetime", "type": "datetime64[ns]"}, + ], + exception=None, + ), + ] + + # Assert each metric so it is easier to see which one fails (instead of assert metrics == expected_metrics): + assert len(metrics) == len(expected_metrics) + for metric in metrics: + assert metric.dict() in [ + expected_metric.dict() for expected_metric in expected_metrics + ] + + +@pytest.mark.cloud +def test_get_metrics_full_cdm( + cloud_context_and_batch_request_with_simple_dataframe: tuple[ + CloudDataContext, BatchRequest + ], +): + context, batch_request = cloud_context_and_batch_request_with_simple_dataframe + cdm_metrics_list: List[MetricTypes] = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_NULL_COUNT, + ] + metric_retriever = MetricListMetricRetriever(context) + metrics = metric_retriever.get_metrics( + batch_request=batch_request, metric_list=cdm_metrics_list + ) + validator = context.get_validator(batch_request=batch_request) + batch_id = validator.active_batch.id + + expected_metrics = [ + TableMetric[int]( + batch_id=batch_id, + metric_name="table.row_count", + value=3, + exception=None, + ), + TableMetric[List[str]]( + batch_id=batch_id, + metric_name="table.columns", + value=[ + "numeric_with_nulls_1", + "numeric_with_nulls_2", + "string", + "string_with_nulls", + "boolean", + "datetime", + ], + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.min", + column="numeric_with_nulls_1", + value=1, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.min", + column="numeric_with_nulls_2", + value=3, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.max", + column="numeric_with_nulls_1", + value=2, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.max", + column="numeric_with_nulls_2", + value=4, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.mean", + column="numeric_with_nulls_1", + value=1.5, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.mean", + column="numeric_with_nulls_2", + value=3.5, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.median", + column="numeric_with_nulls_1", + value=1.5, + exception=None, + ), + ColumnMetric[float]( + batch_id=batch_id, + metric_name="column.median", + column="numeric_with_nulls_2", + value=3.5, + exception=None, + ), + TableMetric[List[str]]( + batch_id=batch_id, + metric_name="table.column_types", + value=[ + {"name": "numeric_with_nulls_1", "type": "float64"}, + {"name": "numeric_with_nulls_2", "type": "float64"}, + {"name": "string", "type": "object"}, + {"name": "string_with_nulls", "type": "object"}, + {"name": "boolean", "type": 
"bool"}, + {"name": "datetime", "type": "datetime64[ns]"}, + ], + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="numeric_with_nulls_1", + value=1, + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="numeric_with_nulls_2", + value=1, + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="string", + value=0, + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="string_with_nulls", + value=1, + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="boolean", + value=0, + exception=None, + ), + ColumnMetric[int]( + batch_id=batch_id, + metric_name="column_values.null.count", + column="datetime", + value=0, + exception=None, + ), + ColumnMetric[str]( + batch_id=batch_id, + metric_name="column.min", + value=Timestamp("2020-01-01 00:00:00"), + exception=None, + column="datetime", + ), + ColumnMetric[str]( + batch_id=batch_id, + metric_name="column.max", + value=Timestamp("2020-01-03 00:00:00"), + exception=None, + column="datetime", + ), + ] + + assert len(metrics) == len(expected_metrics) + for metric in metrics: + assert metric.dict() in [ + expected_metric.dict() for expected_metric in expected_metrics + ] From 6f6b66ce12be6979ec806c6bf5cedde23f72c352 Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 13:30:42 -0700 Subject: [PATCH 04/15] from metrics_calculator --- .../experimental/metric_repository/metric_retriever.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index fba838f6c90a..8ef508bae84b 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -24,12 +24,14 @@ from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent import BatchRequest from great_expectations.experimental.metric_repository.metrics import Metric - from great_expectations.validator.validator import ( - Validator, + from great_expectations.validator.metrics_calculator import ( _AbortedMetricsInfoDict, _MetricKey, _MetricsDict, ) + from great_expectations.validator.validator import ( + Validator, + ) class MetricRetriever(abc.ABC): From 35201de57277b15fe0a154930477a1a62567d557 Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 15:34:49 -0700 Subject: [PATCH 05/15] pushing doc strings --- .../metric_list_metric_retriever.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index 404eec0dc40a..2608d95d705b 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -94,6 +94,17 @@ def _get_non_numeric_column_metrics( batch_request: BatchRequest, column_list: List[str], ) -> Sequence[Metric]: + """Calculate column metrics for non-numeric columns. + + Args: + metrics_list (List[MetricTypes]): list of metrics sent from Agent. 
+ batch_request (BatchRequest): for current batch. + column_list (List[str]): list of non-numeric columns. + + Returns: + Sequence[Metric]: List of metrics for non-numeric columns. + """ + # currently only the null-count is supported. If more metrics are added, this set will need to be updated. column_metric_names = {MetricTypes.COLUMN_NULL_COUNT} metrics: list[Metric] = [] metrics_list_as_set = set(metrics_list) @@ -142,6 +153,16 @@ def _get_numeric_column_metrics( batch_request: BatchRequest, column_list: List[str], ) -> Sequence[Metric]: + """Calculate column metrics for numeric columns. + + Args: + metrics_list (List[MetricTypes]): list of metrics sent from Agent. + batch_request (BatchRequest): for current batch. + column_list (List[str]): list of numeric columns. + + Returns: + Sequence[Metric]: List of metrics for numeric columns. + """ metrics: list[Metric] = [] column_metric_names = { MetricTypes.COLUMN_MIN, @@ -169,6 +190,16 @@ def _get_timestamp_column_metrics( batch_request: BatchRequest, column_list: List[str], ) -> Sequence[Metric]: + """Calculate column metrics for timestamp columns. + + Args: + metrics_list (List[MetricTypes]): list of metrics sent from Agent. + batch_request (BatchRequest): for current batch. + column_list (List[str]): list of timestamp columns. + + Returns: + Sequence[Metric]: List of metrics for timestamp columns. + """ metrics: list[Metric] = [] column_metric_names = { "column.min", @@ -195,6 +226,15 @@ def _get_timestamp_column_metrics( def _get_table_metrics( self, batch_request: BatchRequest, metric_list: List[MetricTypes] ) -> List[Metric]: + """Calculate column metrics for table metrics, which include row_count, column names and types. + + Args: + metrics_list (List[MetricTypes]): list of metrics sent from Agent. + batch_request (BatchRequest): for current batch. + + Returns: + Sequence[Metric]: List of table metrics. + """ metrics: List[Metric] = [] if MetricTypes.TABLE_ROW_COUNT in metric_list: metrics.append(self._get_table_row_count(batch_request=batch_request)) @@ -205,6 +245,14 @@ def _get_table_metrics( return metrics def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: + """Return row_count for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Row count for the table. + """ table_metric_configs = self._generate_table_metric_configurations( table_metric_names=[MetricTypes.TABLE_ROW_COUNT] ) @@ -225,6 +273,14 @@ def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: ) def _get_table_columns(self, batch_request: BatchRequest) -> Metric: + """Return column names for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Column names for the table. + """ table_metric_configs = self._generate_table_metric_configurations( table_metric_names=[MetricTypes.TABLE_COLUMNS] ) @@ -245,6 +301,14 @@ def _get_table_columns(self, batch_request: BatchRequest) -> Metric: ) def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: + """Return column types for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Column types for the table. 
+ """ table_metric_configs = self._generate_table_metric_configurations( table_metric_names=[MetricTypes.TABLE_COLUMN_TYPES] ) From 22bfddd7923762a33c5488ae84bccbafc4228eb4 Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 16:52:13 -0700 Subject: [PATCH 06/15] added more unit test --- .../metric_list_metric_retriever.py | 8 +- .../experimental/metric_repository/metrics.py | 9 +- .../test_metric_list_metric_retriever.py | 149 ++++++++++++++++++ 3 files changed, 161 insertions(+), 5 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index 2608d95d705b..1ea9f08f889d 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -202,10 +202,10 @@ def _get_timestamp_column_metrics( """ metrics: list[Metric] = [] column_metric_names = { - "column.min", - "column.max", - # "column.mean", # Currently not supported for timestamp in Snowflake - # "column.median", # Currently not supported for timestamp in Snowflake + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + # MetricTypes.COLUMN_MEAN, # Currently not supported for timestamp in Snowflake + # MetricTypes.COLUMN_MEDIAN, # Currently not supported for timestamp in Snowflake } metrics_list_as_set = set(metrics_list) metrics_to_calculate = sorted( diff --git a/great_expectations/experimental/metric_repository/metrics.py b/great_expectations/experimental/metric_repository/metrics.py index 8886b3dbc22b..fb0a3b5a838f 100644 --- a/great_expectations/experimental/metric_repository/metrics.py +++ b/great_expectations/experimental/metric_repository/metrics.py @@ -24,7 +24,14 @@ AbstractSetIntStr = AbstractSet[Union[int, str]] -class MetricTypes(str, enum.Enum): +class MetricTypesMeta(enum.EnumMeta): + """Metaclass definition for MetricTypes that allows for membership checking.""" + + def __contains__(cls, item): + return item in cls.__members__.values() + + +class MetricTypes(str, enum.Enum, metaclass=MetricTypesMeta): """Represents Metric types in OSS that are used for ColumnDescriptiveMetrics and MetricRepository. More Metric types will be added in the future. 
diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py index e3cf0902fd26..cff59a5d69a3 100644 --- a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py @@ -638,3 +638,152 @@ def test_get_metrics_only_gets_a_validator_once(): ) mock_context.get_validator.assert_called_once_with(batch_request=mock_batch_request) + + +def test_get_metrics_with_no_metrics(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + computed_metrics = {} + cdm_metrics_list: List[MetricTypes] = [] + aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + + mock_batch_request = Mock(spec=BatchRequest) + + with pytest.raises(ValueError): + metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + + +def test_valid_metric_types_true(): + mock_context = Mock(spec=CloudDataContext) + metric_retriever = MetricListMetricRetriever(context=mock_context) + + valid_metric_types = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + MetricTypes.COLUMN_MAX, + MetricTypes.COLUMN_MEAN, + MetricTypes.COLUMN_MEDIAN, + MetricTypes.COLUMN_NULL_COUNT, + ] + assert metric_retriever._check_valid_metric_types(valid_metric_types) is True + + +def test_valid_metric_types_false(): + mock_context = Mock(spec=CloudDataContext) + metric_retriever = MetricListMetricRetriever(context=mock_context) + + invalid_metric_type = ["I_am_invalid"] + assert metric_retriever._check_valid_metric_types(invalid_metric_type) is False + + +def test_column_metrics_in_metrics_list_only_table_metrics(): + mock_context = Mock(spec=CloudDataContext) + metric_retriever = MetricListMetricRetriever(context=mock_context) + table_metrics_only = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + ] + assert metric_retriever._column_metrics_in_metric_list(table_metrics_only) is False + + +def test_column_metrics_in_metrics_list_with_column_metrics(): + mock_context = Mock(spec=CloudDataContext) + metric_retriever = MetricListMetricRetriever(context=mock_context) + metrics_list_with_column_metrics = [ + MetricTypes.TABLE_ROW_COUNT, + MetricTypes.TABLE_COLUMNS, + MetricTypes.TABLE_COLUMN_TYPES, + MetricTypes.COLUMN_MIN, + ] + assert ( + metric_retriever._column_metrics_in_metric_list( + metrics_list_with_column_metrics + ) + is True + ) + + +def test_get_table_column_types(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + mock_batch_request = Mock(spec=BatchRequest) + computed_metrics = { + ("table.column_types", (), "include_nested=True"): [ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + } + aborted_metrics = {} + mock_validator.compute_metrics.return_value = ( + computed_metrics, + aborted_metrics, + ) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = 
MetricListMetricRetriever(context=mock_context) + ret = metric_retriever._get_table_column_types(mock_batch_request) + print(ret) + + +def test_get_table_columns(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + mock_batch_request = Mock(spec=BatchRequest) + computed_metrics = { + ("table.columns", (), ()): ["col1", "col2"], + } + aborted_metrics = {} + mock_validator.compute_metrics.return_value = (computed_metrics, aborted_metrics) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + ret = metric_retriever._get_table_columns(mock_batch_request) + assert ret == TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ) + + +def test_get_table_row_count(): + mock_context = Mock(spec=CloudDataContext) + mock_validator = Mock(spec=Validator) + mock_context.get_validator.return_value = mock_validator + mock_batch_request = Mock(spec=BatchRequest) + computed_metrics = {("table.row_count", (), ()): 2} + aborted_metrics = {} + mock_validator.compute_metrics.return_value = (computed_metrics, aborted_metrics) + mock_batch = Mock(spec=Batch) + mock_batch.id = "batch_id" + mock_validator.active_batch = mock_batch + + metric_retriever = MetricListMetricRetriever(context=mock_context) + ret = metric_retriever._get_table_row_count(mock_batch_request) + assert ret == TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=2, + exception=None, + ) From 2ceb2c2f1e0b81431bc77658da1abd31b1ea322a Mon Sep 17 00:00:00 2001 From: William Shin Date: Tue, 12 Mar 2024 17:12:45 -0700 Subject: [PATCH 07/15] pushing static --- .../metric_repository/metric_list_metric_retriever.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index 1ea9f08f889d..2b791f906921 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -17,9 +17,11 @@ if TYPE_CHECKING: from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent.batch_request import BatchRequest + from great_expectations.validator.metrics_calculator import ( + _MetricKey, + ) from great_expectations.validator.validator import ( Validator, - _MetricKey, ) From 86545b6106ef95cc5a5dba9b2bfc911fe6e09fbd Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 13 Mar 2024 15:55:29 -0700 Subject: [PATCH 08/15] pushing metric_list changes --- .../metric_list_metric_retriever.py | 129 ++++-------------- .../metric_repository/metric_retriever.py | 60 ++++++++ 2 files changed, 83 insertions(+), 106 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index 2b791f906921..d8ce6ffc0c0d 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, 
List, Optional, Sequence +from typing import TYPE_CHECKING, List, Optional, Sequence from great_expectations.compatibility.typing_extensions import override from great_expectations.experimental.metric_repository.metric_retriever import ( @@ -17,9 +17,6 @@ if TYPE_CHECKING: from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent.batch_request import BatchRequest - from great_expectations.validator.metrics_calculator import ( - _MetricKey, - ) from great_expectations.validator.validator import ( Validator, ) @@ -43,7 +40,7 @@ def get_metrics( self._check_valid_metric_types(metric_list) - table_metrics = self._get_table_metrics( + table_metrics = self._calculate_table_metrics( batch_request=batch_request, metric_list=metric_list ) metrics_result.extend(table_metrics) @@ -116,38 +113,13 @@ def _get_non_numeric_column_metrics( if not metrics_to_calculate: return metrics - - column_metric_configs = self._generate_column_metric_configurations( - column_list, list(metrics_to_calculate) - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, column_metric_configs - ) - - # Convert computed_metrics - ColumnMetric.update_forward_refs() - metric_lookup_key: _MetricKey - - for metric_name in metrics_to_calculate: - for column in column_list: - metric_lookup_key = (metric_name, f"column={column}", tuple()) - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - metric_lookup_key=metric_lookup_key, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - metrics.append( - ColumnMetric[int]( - batch_id=batch_id, - metric_name=metric_name, - column=column, - value=value, - exception=exception, - ) - ) - - return metrics + else: + return self._get_column_metrics( + batch_request=batch_request, + column_list=column_list, + column_metric_names=list(metrics_to_calculate), + column_metric_type=ColumnMetric[int], + ) def _get_numeric_column_metrics( self, @@ -225,10 +197,10 @@ def _get_timestamp_column_metrics( column_metric_type=ColumnMetric[str], ) - def _get_table_metrics( + def _calculate_table_metrics( self, batch_request: BatchRequest, metric_list: List[MetricTypes] ) -> List[Metric]: - """Calculate column metrics for table metrics, which include row_count, column names and types. + """Calculate table metrics, which include row_count, column names and types. Args: metrics_list (List[MetricTypes]): list of metrics sent from Agent. @@ -255,23 +227,10 @@ def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: Returns: Metric: Row count for the table. """ - table_metric_configs = self._generate_table_metric_configurations( - table_metric_names=[MetricTypes.TABLE_ROW_COUNT] - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, table_metric_configs - ) - metric_name = MetricTypes.TABLE_ROW_COUNT - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - return TableMetric[int]( - batch_id=batch_id, - metric_name=metric_name, - value=value, - exception=exception, + return self._get_table_metrics( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_ROW_COUNT, + metric_type=TableMetric[int], ) def _get_table_columns(self, batch_request: BatchRequest) -> Metric: @@ -283,23 +242,10 @@ def _get_table_columns(self, batch_request: BatchRequest) -> Metric: Returns: Metric: Column names for the table. 
""" - table_metric_configs = self._generate_table_metric_configurations( - table_metric_names=[MetricTypes.TABLE_COLUMNS] - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, table_metric_configs - ) - metric_name = MetricTypes.TABLE_COLUMNS - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - return TableMetric[List[str]]( - batch_id=batch_id, - metric_name=metric_name, - value=value, - exception=exception, + return self._get_table_metrics( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_COLUMNS, + metric_type=TableMetric[List[str]], ) def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: @@ -311,39 +257,10 @@ def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: Returns: Metric: Column types for the table. """ - table_metric_configs = self._generate_table_metric_configurations( - table_metric_names=[MetricTypes.TABLE_COLUMN_TYPES] - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, table_metric_configs - ) - metric_name = MetricTypes.TABLE_COLUMN_TYPES - metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - metric_lookup_key=metric_lookup_key, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - raw_column_types: list[dict[str, Any]] = value - # If type is not found, don't add empty type field. This can happen if our db introspection fails. - column_types_converted_to_str: list[dict[str, str]] = [] - for raw_column_type in raw_column_types: - if raw_column_type.get("type"): - column_types_converted_to_str.append( - { - "name": raw_column_type["name"], - "type": str(raw_column_type["type"]), - } - ) - else: - column_types_converted_to_str.append({"name": raw_column_type["name"]}) - - return TableMetric[List[str]]( - batch_id=batch_id, - metric_name=metric_name, - value=column_types_converted_to_str, - exception=exception, + return self._get_table_metrics_column_types( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_COLUMN_TYPES, + metric_type=TableMetric[List[str]], ) def _check_valid_metric_types(self, metric_list: List[MetricTypes]) -> bool: diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index 8ef508bae84b..233d1954b321 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -174,6 +174,65 @@ def _get_column_names_for_semantic_types( ) return column_names + def _get_table_metrics( + self, + batch_request: BatchRequest, + metric_name: MetricTypes | str, + metric_type: type[Metric], + ) -> Metric: + metric_configs = self._generate_table_metric_configurations([metric_name]) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, metric_configs + ) + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + return metric_type( + batch_id=batch_id, metric_name=metric_name, value=value, exception=exception + ) + + def _get_table_metrics_column_types( + self, + batch_request: BatchRequest, + metric_name: MetricTypes | str, + metric_type: type[Metric], + ) 
-> Metric: + metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") + table_metric_configs = self._generate_table_metric_configurations( + table_metric_names=[MetricTypes.TABLE_COLUMN_TYPES] + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, table_metric_configs + ) + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + metric_lookup_key=metric_lookup_key, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + raw_column_types: list[dict[str, Any]] = value + # If type is not found, don't add empty type field. This can happen if our db introspection fails. + column_types_converted_to_str: list[dict[str, str]] = [] + for raw_column_type in raw_column_types: + if raw_column_type.get("type"): + column_types_converted_to_str.append( + { + "name": raw_column_type["name"], + "type": str(raw_column_type["type"]), + } + ) + else: + column_types_converted_to_str.append({"name": raw_column_type["name"]}) + + return metric_type( + batch_id=batch_id, + metric_name=metric_name, + value=column_types_converted_to_str, + exception=exception, + ) + def _get_column_metrics( self, batch_request: BatchRequest, @@ -189,6 +248,7 @@ def _get_column_metrics( ) # Convert computed_metrics + ColumnMetric.update_forward_refs() metrics: list[Metric] = [] metric_lookup_key: _MetricKey From ab611b0781266a35ed1fdde87f31048d29ac683d Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 13 Mar 2024 16:00:53 -0700 Subject: [PATCH 09/15] change name --- .../column_descriptive_metrics_metric_retriever.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py index d4ca0040363e..b0eba3369f1b 100644 --- a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py @@ -31,7 +31,7 @@ def __init__(self, context: AbstractDataContext): @override def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: - table_metrics = self._get_table_metrics(batch_request) + table_metrics = self._calculate_table_metrics(batch_request) # We need to skip columns that do not report a type, because the metric computation # to determine semantic type will fail. 
@@ -68,7 +68,7 @@ def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: ) return bundled_list - def _get_table_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: + def _calculate_table_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: table_metric_names = ["table.row_count", "table.columns", "table.column_types"] table_metric_configs = self._generate_table_metric_configurations( table_metric_names From 4d209c23406db9d07d7f7ef8e8803fc44f085cb7 Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 13 Mar 2024 16:17:37 -0700 Subject: [PATCH 10/15] clean up CDM --- ...mn_descriptive_metrics_metric_retriever.py | 93 +------------------ .../metric_list_metric_retriever.py | 46 --------- .../metric_repository/metric_retriever.py | 46 +++++++++ 3 files changed, 50 insertions(+), 135 deletions(-) diff --git a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py index b0eba3369f1b..ae8fa6dc4cf7 100644 --- a/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/column_descriptive_metrics_metric_retriever.py @@ -1,7 +1,7 @@ from __future__ import annotations from itertools import chain -from typing import TYPE_CHECKING, Any, List, Sequence +from typing import TYPE_CHECKING, List, Sequence from great_expectations.compatibility.typing_extensions import override from great_expectations.experimental.metric_repository.metric_retriever import ( @@ -10,16 +10,13 @@ from great_expectations.experimental.metric_repository.metrics import ( ColumnMetric, Metric, - TableMetric, ) if TYPE_CHECKING: from great_expectations.data_context import AbstractDataContext from great_expectations.datasource.fluent import BatchRequest from great_expectations.validator.metrics_calculator import ( - _AbortedMetricsInfoDict, _MetricKey, - _MetricsDict, ) @@ -69,95 +66,13 @@ def get_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: return bundled_list def _calculate_table_metrics(self, batch_request: BatchRequest) -> Sequence[Metric]: - table_metric_names = ["table.row_count", "table.columns", "table.column_types"] - table_metric_configs = self._generate_table_metric_configurations( - table_metric_names - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, table_metric_configs - ) - metrics = [ - self._get_table_row_count(batch_id, computed_metrics, aborted_metrics), - self._get_table_columns(batch_id, computed_metrics, aborted_metrics), - self._get_table_column_types(batch_id, computed_metrics, aborted_metrics), + self._get_table_row_count(batch_request), + self._get_table_columns(batch_request), + self._get_table_column_types(batch_request), ] - return metrics - def _get_table_row_count( - self, - batch_id: str, - computed_metrics: _MetricsDict, - aborted_metrics: _AbortedMetricsInfoDict, - ) -> Metric: - metric_name = "table.row_count" - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - return TableMetric[int]( - batch_id=batch_id, - metric_name=metric_name, - value=value, - exception=exception, - ) - - def _get_table_columns( - self, - batch_id: str, - computed_metrics: _MetricsDict, - aborted_metrics: _AbortedMetricsInfoDict, - ) -> Metric: - metric_name = "table.columns" - value, 
exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - return TableMetric[List[str]]( - batch_id=batch_id, - metric_name=metric_name, - value=value, - exception=exception, - ) - - def _get_table_column_types( - self, - batch_id: str, - computed_metrics: _MetricsDict, - aborted_metrics: _AbortedMetricsInfoDict, - ) -> Metric: - metric_name = "table.column_types" - metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - metric_lookup_key=metric_lookup_key, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - raw_column_types: list[dict[str, Any]] = value - # If type is not found, don't add empty type field. This can happen if our db introspection fails. - column_types_converted_to_str: list[dict[str, str]] = [] - for raw_column_type in raw_column_types: - if raw_column_type.get("type"): - column_types_converted_to_str.append( - { - "name": raw_column_type["name"], - "type": str(raw_column_type["type"]), - } - ) - else: - column_types_converted_to_str.append({"name": raw_column_type["name"]}) - - return TableMetric[List[str]]( - batch_id=batch_id, - metric_name=metric_name, - value=column_types_converted_to_str, - exception=exception, - ) - def _get_numeric_column_metrics( self, batch_request: BatchRequest, column_list: List[str] ) -> Sequence[Metric]: diff --git a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py index d8ce6ffc0c0d..efbb08e00991 100644 --- a/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_list_metric_retriever.py @@ -11,7 +11,6 @@ ColumnMetric, Metric, MetricTypes, - TableMetric, ) if TYPE_CHECKING: @@ -218,51 +217,6 @@ def _calculate_table_metrics( metrics.append(self._get_table_column_types(batch_request=batch_request)) return metrics - def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: - """Return row_count for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Row count for the table. - """ - return self._get_table_metrics( - batch_request=batch_request, - metric_name=MetricTypes.TABLE_ROW_COUNT, - metric_type=TableMetric[int], - ) - - def _get_table_columns(self, batch_request: BatchRequest) -> Metric: - """Return column names for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Column names for the table. - """ - return self._get_table_metrics( - batch_request=batch_request, - metric_name=MetricTypes.TABLE_COLUMNS, - metric_type=TableMetric[List[str]], - ) - - def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: - """Return column types for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Column types for the table. - """ - return self._get_table_metrics_column_types( - batch_request=batch_request, - metric_name=MetricTypes.TABLE_COLUMN_TYPES, - metric_type=TableMetric[List[str]], - ) - def _check_valid_metric_types(self, metric_list: List[MetricTypes]) -> bool: """Check whether all the metric types in the list are valid. 
diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index 233d1954b321..e950ccca54dc 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -15,6 +15,7 @@ ColumnMetric, MetricException, MetricTypes, + TableMetric, ) from great_expectations.rule_based_profiler.domain_builder import ColumnDomainBuilder from great_expectations.validator.exception_info import ExceptionInfo @@ -293,3 +294,48 @@ def _get_all_column_names(self, metrics: Sequence[Metric]) -> List[str]: if metric.metric_name == MetricTypes.TABLE_COLUMNS: column_list = metric.value return column_list + + def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: + """Return row_count for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Row count for the table. + """ + return self._get_table_metrics( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_ROW_COUNT, + metric_type=TableMetric[int], + ) + + def _get_table_columns(self, batch_request: BatchRequest) -> Metric: + """Return column names for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Column names for the table. + """ + return self._get_table_metrics( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_COLUMNS, + metric_type=TableMetric[List[str]], + ) + + def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: + """Return column types for the table. + + Args: + batch_request (BatchRequest): For current batch. + + Returns: + Metric: Column types for the table. + """ + return self._get_table_metrics_column_types( + batch_request=batch_request, + metric_name=MetricTypes.TABLE_COLUMN_TYPES, + metric_type=TableMetric[List[str]], + ) From 777814d073c7d9570bec0051f053c78b7ad59d5c Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 13 Mar 2024 16:29:45 -0700 Subject: [PATCH 11/15] clean up --- .../metric_repository/metric_retriever.py | 78 ++++++++----------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index e950ccca54dc..7e02fe021e8e 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -194,46 +194,6 @@ def _get_table_metrics( batch_id=batch_id, metric_name=metric_name, value=value, exception=exception ) - def _get_table_metrics_column_types( - self, - batch_request: BatchRequest, - metric_name: MetricTypes | str, - metric_type: type[Metric], - ) -> Metric: - metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") - table_metric_configs = self._generate_table_metric_configurations( - table_metric_names=[MetricTypes.TABLE_COLUMN_TYPES] - ) - batch_id, computed_metrics, aborted_metrics = self._compute_metrics( - batch_request, table_metric_configs - ) - value, exception = self._get_metric_from_computed_metrics( - metric_name=metric_name, - metric_lookup_key=metric_lookup_key, - computed_metrics=computed_metrics, - aborted_metrics=aborted_metrics, - ) - raw_column_types: list[dict[str, Any]] = value - # If type is not found, don't add empty type field. This can happen if our db introspection fails. 
- column_types_converted_to_str: list[dict[str, str]] = [] - for raw_column_type in raw_column_types: - if raw_column_type.get("type"): - column_types_converted_to_str.append( - { - "name": raw_column_type["name"], - "type": str(raw_column_type["type"]), - } - ) - else: - column_types_converted_to_str.append({"name": raw_column_type["name"]}) - - return metric_type( - batch_id=batch_id, - metric_name=metric_name, - value=column_types_converted_to_str, - exception=exception, - ) - def _get_column_metrics( self, batch_request: BatchRequest, @@ -334,8 +294,38 @@ def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: Returns: Metric: Column types for the table. """ - return self._get_table_metrics_column_types( - batch_request=batch_request, - metric_name=MetricTypes.TABLE_COLUMN_TYPES, - metric_type=TableMetric[List[str]], + metric_name = MetricTypes.TABLE_COLUMN_TYPES + + metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") + table_metric_configs = self._generate_table_metric_configurations( + table_metric_names=[metric_name] + ) + batch_id, computed_metrics, aborted_metrics = self._compute_metrics( + batch_request, table_metric_configs + ) + value, exception = self._get_metric_from_computed_metrics( + metric_name=metric_name, + metric_lookup_key=metric_lookup_key, + computed_metrics=computed_metrics, + aborted_metrics=aborted_metrics, + ) + raw_column_types: list[dict[str, Any]] = value + # If type is not found, don't add empty type field. This can happen if our db introspection fails. + column_types_converted_to_str: list[dict[str, str]] = [] + for raw_column_type in raw_column_types: + if raw_column_type.get("type"): + column_types_converted_to_str.append( + { + "name": raw_column_type["name"], + "type": str(raw_column_type["type"]), + } + ) + else: + column_types_converted_to_str.append({"name": raw_column_type["name"]}) + + return TableMetric[List[str]]( + batch_id=batch_id, + metric_name=metric_name, + value=column_types_converted_to_str, + exception=exception, ) From ac8365175ae0dd95ce6b45a8f471aa5416ec6f6d Mon Sep 17 00:00:00 2001 From: William Shin Date: Wed, 13 Mar 2024 16:31:55 -0700 Subject: [PATCH 12/15] remove docstring --- .../metric_repository/metric_retriever.py | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/great_expectations/experimental/metric_repository/metric_retriever.py b/great_expectations/experimental/metric_repository/metric_retriever.py index 7e02fe021e8e..7b1a472c28e1 100644 --- a/great_expectations/experimental/metric_repository/metric_retriever.py +++ b/great_expectations/experimental/metric_repository/metric_retriever.py @@ -256,14 +256,6 @@ def _get_all_column_names(self, metrics: Sequence[Metric]) -> List[str]: return column_list def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: - """Return row_count for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Row count for the table. - """ return self._get_table_metrics( batch_request=batch_request, metric_name=MetricTypes.TABLE_ROW_COUNT, @@ -271,14 +263,6 @@ def _get_table_row_count(self, batch_request: BatchRequest) -> Metric: ) def _get_table_columns(self, batch_request: BatchRequest) -> Metric: - """Return column names for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Column names for the table. 
- """ return self._get_table_metrics( batch_request=batch_request, metric_name=MetricTypes.TABLE_COLUMNS, @@ -286,14 +270,6 @@ def _get_table_columns(self, batch_request: BatchRequest) -> Metric: ) def _get_table_column_types(self, batch_request: BatchRequest) -> Metric: - """Return column types for the table. - - Args: - batch_request (BatchRequest): For current batch. - - Returns: - Metric: Column types for the table. - """ metric_name = MetricTypes.TABLE_COLUMN_TYPES metric_lookup_key: _MetricKey = (metric_name, tuple(), "include_nested=True") From 6581df4f5825e8a758ba547973d3ff88ae1cbe3e Mon Sep 17 00:00:00 2001 From: William Shin Date: Thu, 14 Mar 2024 09:50:30 -0700 Subject: [PATCH 13/15] remove stray comment --- .../metric_repository/test_metric_list_metric_retriever.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py index cff59a5d69a3..ea453fe29423 100644 --- a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py @@ -466,7 +466,6 @@ def test_get_metrics_with_column_type_missing(): metrics = metric_retriever.get_metrics( batch_request=mock_batch_request, metric_list=cdm_metrics_list ) - # why is this not sorted? assert metrics == [ TableMetric[int]( batch_id="batch_id", From 5035801afff8bb981980d75eefc5d3e3503740ee Mon Sep 17 00:00:00 2001 From: William Shin Date: Thu, 14 Mar 2024 11:07:33 -0700 Subject: [PATCH 14/15] pushing mocker changes --- .../test_metric_list_metric_retriever.py | 155 +++++++++--------- ...etric_list_metric_retriever_integration.py | 1 + 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py index ea453fe29423..5807b738615f 100644 --- a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py @@ -1,6 +1,4 @@ -from typing import List -from unittest import mock -from unittest.mock import Mock +from typing import Dict, List import pytest @@ -22,10 +20,12 @@ pytestmark = pytest.mark.unit +from pytest_mock import MockerFixture -def test_get_metrics_table_metrics_only(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) + +def test_get_metrics_table_metrics_only(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator computed_metrics = { ("table.row_count", (), ()): 2, @@ -40,18 +40,18 @@ def test_get_metrics_table_metrics_only(): MetricTypes.TABLE_COLUMNS, MetricTypes.TABLE_COLUMN_TYPES, ] - aborted_metrics = {} + aborted_metrics: Dict[str, str] = {} mock_validator.compute_metrics.return_value = ( computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) metrics = metric_retriever.get_metrics( batch_request=mock_batch_request, @@ -82,9 +82,9 @@ def test_get_metrics_table_metrics_only(): ] -def test_get_metrics_full_list(): - 
mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_metrics_full_list(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator computed_metrics = { ("table.row_count", (), ()): 2, @@ -114,23 +114,23 @@ def test_get_metrics_full_list(): MetricTypes.COLUMN_MEDIAN, MetricTypes.COLUMN_NULL_COUNT, ] - aborted_metrics = {} + aborted_metrics: Dict[str, str] = {} mock_validator.compute_metrics.return_value = ( computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) - with mock.patch( + with mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mock.patch( + ), mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], ): @@ -231,10 +231,10 @@ def test_get_metrics_full_list(): ] -def test_get_metrics_metrics_missing(): +def test_get_metrics_metrics_missing(mocker: MockerFixture): """This test is meant to simulate metrics missing from the computed metrics.""" - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator mock_computed_metrics = { # ("table.row_count", (), ()): 2, # Missing table.row_count metric @@ -258,18 +258,18 @@ def test_get_metrics_metrics_missing(): mock_computed_metrics, mock_aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) - with mock.patch( + with mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mock.patch( + ), mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], ): @@ -321,10 +321,10 @@ def test_get_metrics_metrics_missing(): ] -def test_get_metrics_with_exception(): +def test_get_metrics_with_exception(mocker: MockerFixture): """This test is meant to simulate failed metrics in the computed metrics.""" - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator exception_info = ExceptionInfo( @@ -356,7 +356,7 @@ def test_get_metrics_with_exception(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch @@ -368,7 +368,7 @@ def test_get_metrics_with_exception(): metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = 
mocker.Mock(spec=BatchRequest) metrics = metric_retriever.get_metrics( batch_request=mock_batch_request, metric_list=cdm_metrics_list @@ -399,11 +399,10 @@ def test_get_metrics_with_exception(): ] -def test_get_metrics_with_column_type_missing(): - """This test is meant to simulate failed metrics in the computed metrics.""" +def test_get_metrics_with_column_type_missing(mocker: MockerFixture): """This test is meant to simulate failed metrics in the computed metrics.""" - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator exception_info = ExceptionInfo( @@ -441,7 +440,7 @@ def test_get_metrics_with_column_type_missing(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch @@ -454,12 +453,12 @@ def test_get_metrics_with_column_type_missing(): metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) - with mock.patch( + with mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mock.patch( + ), mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], ): @@ -507,9 +506,9 @@ def test_get_metrics_with_column_type_missing(): ] -def test_get_metrics_with_timestamp_columns(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_metrics_with_timestamp_columns(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator computed_metrics = { ("table.row_count", (), ()): 2, @@ -534,18 +533,18 @@ def test_get_metrics_with_timestamp_columns(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) - with mock.patch( + with mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=[], - ), mock.patch( + ), mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=["timestamp_col"], ): @@ -596,9 +595,9 @@ def test_get_metrics_with_timestamp_columns(): ] -def test_get_metrics_only_gets_a_validator_once(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_metrics_only_gets_a_validator_once(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator aborted_metrics = {} @@ -620,15 +619,15 @@ def test_get_metrics_only_gets_a_validator_once(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = 
MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) - with mock.patch( + with mocker.patch( f"{ColumnDomainBuilder.__module__}.{ColumnDomainBuilder.__name__}.get_effective_column_names", return_value=["col1", "col2"], ): @@ -639,9 +638,9 @@ def test_get_metrics_only_gets_a_validator_once(): mock_context.get_validator.assert_called_once_with(batch_request=mock_batch_request) -def test_get_metrics_with_no_metrics(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_metrics_with_no_metrics(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator computed_metrics = {} cdm_metrics_list: List[MetricTypes] = [] @@ -650,13 +649,13 @@ def test_get_metrics_with_no_metrics(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch metric_retriever = MetricListMetricRetriever(context=mock_context) - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) with pytest.raises(ValueError): metric_retriever.get_metrics( @@ -664,8 +663,8 @@ def test_get_metrics_with_no_metrics(): ) -def test_valid_metric_types_true(): - mock_context = Mock(spec=CloudDataContext) +def test_valid_metric_types_true(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) metric_retriever = MetricListMetricRetriever(context=mock_context) valid_metric_types = [ @@ -681,16 +680,16 @@ def test_valid_metric_types_true(): assert metric_retriever._check_valid_metric_types(valid_metric_types) is True -def test_valid_metric_types_false(): - mock_context = Mock(spec=CloudDataContext) +def test_valid_metric_types_false(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) metric_retriever = MetricListMetricRetriever(context=mock_context) invalid_metric_type = ["I_am_invalid"] assert metric_retriever._check_valid_metric_types(invalid_metric_type) is False -def test_column_metrics_in_metrics_list_only_table_metrics(): - mock_context = Mock(spec=CloudDataContext) +def test_column_metrics_in_metrics_list_only_table_metrics(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) metric_retriever = MetricListMetricRetriever(context=mock_context) table_metrics_only = [ MetricTypes.TABLE_ROW_COUNT, @@ -700,8 +699,8 @@ def test_column_metrics_in_metrics_list_only_table_metrics(): assert metric_retriever._column_metrics_in_metric_list(table_metrics_only) is False -def test_column_metrics_in_metrics_list_with_column_metrics(): - mock_context = Mock(spec=CloudDataContext) +def test_column_metrics_in_metrics_list_with_column_metrics(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) metric_retriever = MetricListMetricRetriever(context=mock_context) metrics_list_with_column_metrics = [ MetricTypes.TABLE_ROW_COUNT, @@ -717,11 +716,11 @@ def test_column_metrics_in_metrics_list_with_column_metrics(): ) -def test_get_table_column_types(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_table_column_types(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator - 
mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) computed_metrics = { ("table.column_types", (), "include_nested=True"): [ {"name": "col1", "type": "float"}, @@ -733,7 +732,7 @@ def test_get_table_column_types(): computed_metrics, aborted_metrics, ) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch @@ -742,17 +741,17 @@ def test_get_table_column_types(): print(ret) -def test_get_table_columns(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_table_columns(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) computed_metrics = { ("table.columns", (), ()): ["col1", "col2"], } aborted_metrics = {} mock_validator.compute_metrics.return_value = (computed_metrics, aborted_metrics) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch @@ -766,15 +765,15 @@ def test_get_table_columns(): ) -def test_get_table_row_count(): - mock_context = Mock(spec=CloudDataContext) - mock_validator = Mock(spec=Validator) +def test_get_table_row_count(mocker: MockerFixture): + mock_context = mocker.Mock(spec=CloudDataContext) + mock_validator = mocker.Mock(spec=Validator) mock_context.get_validator.return_value = mock_validator - mock_batch_request = Mock(spec=BatchRequest) + mock_batch_request = mocker.Mock(spec=BatchRequest) computed_metrics = {("table.row_count", (), ()): 2} aborted_metrics = {} mock_validator.compute_metrics.return_value = (computed_metrics, aborted_metrics) - mock_batch = Mock(spec=Batch) + mock_batch = mocker.Mock(spec=Batch) mock_batch.id = "batch_id" mock_validator.active_batch = mock_batch diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py index 3726f32422df..7e0aea8c65a5 100644 --- a/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever_integration.py @@ -1,4 +1,5 @@ """Test using actual sample data.""" + from __future__ import annotations from typing import List From 220a649c67b6bf2962bf431237659aba94ebd756 Mon Sep 17 00:00:00 2001 From: William Shin Date: Thu, 14 Mar 2024 12:19:01 -0700 Subject: [PATCH 15/15] clean up context manager --- .../test_metric_list_metric_retriever.py | 146 +++++++++--------- 1 file changed, 75 insertions(+), 71 deletions(-) diff --git a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py index 5807b738615f..a6c0c3919262 100644 --- a/tests/experimental/metric_repository/test_metric_list_metric_retriever.py +++ b/tests/experimental/metric_repository/test_metric_list_metric_retriever.py @@ -127,17 +127,18 @@ def test_get_metrics_full_list(mocker: MockerFixture): mock_batch_request = mocker.Mock(spec=BatchRequest) - with mocker.patch( + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mocker.patch( + ) + mocker.patch( 
f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], - ): - metrics = metric_retriever.get_metrics( - batch_request=mock_batch_request, - metric_list=cdm_metrics_list, - ) + ) + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, + metric_list=cdm_metrics_list, + ) assert metrics == [ TableMetric[int]( @@ -266,59 +267,60 @@ def test_get_metrics_metrics_missing(mocker: MockerFixture): mock_batch_request = mocker.Mock(spec=BatchRequest) - with mocker.patch( + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mocker.patch( + ) + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], - ): - metrics = metric_retriever.get_metrics( - batch_request=mock_batch_request, metric_list=cdm_metrics_list - ) - assert metrics == [ - TableMetric[int]( - batch_id="batch_id", - metric_name="table.row_count", - value=None, - exception=MetricException( - type="Not found", - message="Metric was not successfully computed but exception was not found.", - ), - ), - TableMetric[List[str]]( - batch_id="batch_id", - metric_name="table.columns", - value=["col1", "col2"], - exception=None, - ), - TableMetric[List[str]]( - batch_id="batch_id", - metric_name="table.column_types", - value=[ - {"name": "col1", "type": "float"}, - {"name": "col2", "type": "float"}, - ], - exception=None, - ), - ColumnMetric[float]( - batch_id="batch_id", - metric_name="column.min", - value=None, - exception=MetricException( - type="Not found", - message="Metric was not successfully computed but exception was not found.", - ), - column="col1", + ) + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) + assert metrics == [ + TableMetric[int]( + batch_id="batch_id", + metric_name="table.row_count", + value=None, + exception=MetricException( + type="Not found", + message="Metric was not successfully computed but exception was not found.", ), - ColumnMetric[float]( - batch_id="batch_id", - metric_name="column.min", - value=2.7, - exception=None, - column="col2", + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.columns", + value=["col1", "col2"], + exception=None, + ), + TableMetric[List[str]]( + batch_id="batch_id", + metric_name="table.column_types", + value=[ + {"name": "col1", "type": "float"}, + {"name": "col2", "type": "float"}, + ], + exception=None, + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=None, + exception=MetricException( + type="Not found", + message="Metric was not successfully computed but exception was not found.", ), - ] + column="col1", + ), + ColumnMetric[float]( + batch_id="batch_id", + metric_name="column.min", + value=2.7, + exception=None, + column="col2", + ), + ] def test_get_metrics_with_exception(mocker: MockerFixture): @@ -455,16 +457,17 @@ def test_get_metrics_with_column_type_missing(mocker: MockerFixture): mock_batch_request = mocker.Mock(spec=BatchRequest) - with mocker.patch( + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=["col1", "col2"], - ), mocker.patch( + ) + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=[], - ): - metrics = 
metric_retriever.get_metrics( - batch_request=mock_batch_request, metric_list=cdm_metrics_list - ) + ) + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) assert metrics == [ TableMetric[int]( batch_id="batch_id", @@ -541,16 +544,17 @@ def test_get_metrics_with_timestamp_columns(mocker: MockerFixture): mock_batch_request = mocker.Mock(spec=BatchRequest) - with mocker.patch( + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_numeric_column_names", return_value=[], - ), mocker.patch( + ) + mocker.patch( f"{MetricListMetricRetriever.__module__}.{MetricListMetricRetriever.__name__}._get_timestamp_column_names", return_value=["timestamp_col"], - ): - metrics = metric_retriever.get_metrics( - batch_request=mock_batch_request, metric_list=cdm_metrics_list - ) + ) + metrics = metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) assert metrics == [ TableMetric[int]( @@ -627,13 +631,13 @@ def test_get_metrics_only_gets_a_validator_once(mocker: MockerFixture): mock_batch_request = mocker.Mock(spec=BatchRequest) - with mocker.patch( + mocker.patch( f"{ColumnDomainBuilder.__module__}.{ColumnDomainBuilder.__name__}.get_effective_column_names", return_value=["col1", "col2"], - ): - metric_retriever.get_metrics( - batch_request=mock_batch_request, metric_list=cdm_metrics_list - ) + ) + metric_retriever.get_metrics( + batch_request=mock_batch_request, metric_list=cdm_metrics_list + ) mock_context.get_validator.assert_called_once_with(batch_request=mock_batch_request)
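On the test changes in the last two patches: the pytest-mock `mocker` fixture applies a patch as soon as `mocker.patch(...)` is called and undoes it automatically at test teardown, which is why the `with` context managers (and the direct `unittest.mock` imports) can be dropped without changing behavior. A small sketch of the idiom under that assumption; it patches `os.getcwd` purely for illustration:

    # Requires pytest and pytest-mock; the `mocker` fixture comes from the plugin.
    import os

    from pytest_mock import MockerFixture


    def test_cwd_is_patched(mocker: MockerFixture):
        # No `with` block needed: the patch is active from this call onward
        # and is reverted automatically when the test finishes.
        mocker.patch("os.getcwd", return_value="/fake/dir")
        assert os.getcwd() == "/fake/dir"


    def test_cwd_is_restored():
        # The previous test's patch has already been undone by pytest-mock.
        assert os.getcwd() != "/fake/dir"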