diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 70ea3981..63838237 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,10 @@ Changelog - Added a ``name`` parameter to all ``add_x_constraint`` methods of ``WithinRequirement`` and ``BetweenRequirement``. This will give pytest test a custom name. - Added preliminary support for Impala. +**Other changes** + +- Improve assertion error for :meth:`~datajudge.WithinRequirement.add_row_matching_equality_constraint`. + 1.2.0 - 2022.10.21 ------------------ diff --git a/src/datajudge/constraints/row.py b/src/datajudge/constraints/row.py index 0eef0c8c..28c27c0c 100644 --- a/src/datajudge/constraints/row.py +++ b/src/datajudge/constraints/row.py @@ -177,7 +177,7 @@ def __init__( ) def test(self, engine: sa.engine.Engine) -> TestResult: - missing_fraction, selections = db_access.get_row_mismatch( + missing_fraction, n_rows_match, selections = db_access.get_row_mismatch( engine, self.ref, self.ref2, self.match_and_compare ) self.factual_selections = selections @@ -187,10 +187,10 @@ def test(self, engine: sa.engine.Engine) -> TestResult: return TestResult.success() assertion_message = ( f"{missing_fraction} > " - f"{max_missing_fraction} of rows matched " - f"between {self.ref.get_string()} and " + f"{max_missing_fraction} of the rows differ " + f"on a match of {n_rows_match} rows between {self.ref.get_string()} and " f"{self.ref2.get_string()}. " - f"{self.condition_string}." + f"{self.condition_string}" f"{self.match_and_compare} " ) return TestResult.failure(assertion_message) diff --git a/src/datajudge/db_access.py b/src/datajudge/db_access.py index 7555a77f..a4a6d782 100644 --- a/src/datajudge/db_access.py +++ b/src/datajudge/db_access.py @@ -924,11 +924,15 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare): avg_match_column = sa.func.avg(sa.case([(compare, 0.0)], else_=1.0)) - selection = sa.select([avg_match_column]).select_from( + selection_difference = sa.select([avg_match_column]).select_from( subselection1.join(subselection2, match) ) - result = engine.connect().execute(selection).scalar() - return result, [selection] + selection_n_rows = sa.select(sa.func.count()).select_from( + subselection1.join(subselection2, match) + ) + result_mismatch = engine.connect().execute(selection_difference).scalar() + result_n_rows = engine.connect().execute(selection_n_rows).scalar() + return result_mismatch, result_n_rows, [selection_difference, selection_n_rows] def get_duplicate_sample(engine, ref): diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index ea3aa1f9..a523ab83 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1926,7 +1926,8 @@ def test_row_matching_equality(engine, row_match_table1, row_match_table2, data) condition1=condition1, condition2=condition2, ) - assert operation(req[0].test(engine).outcome) + test_result = req[0].test(engine) + assert operation(test_result.outcome), test_result.failure_message @pytest.mark.parametrize("key", [("some_id",), ("some_id", "extra_id")])