Skip to content

Commit

Permalink
Cleanup testing
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Sep 24, 2024
1 parent 66e5ac8 commit c45102c
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 71 deletions.
3 changes: 1 addition & 2 deletions src/y0/r_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ def prepare_renv(requirements: Iterable[str]) -> None:
logger.warning("installing R packages: %s", uninstalled_requirements)
utils.install_packages(StrVector(uninstalled_requirements))

for requirement in requirements:
importr(requirement)
return [importr(requirement) for requirement in requirements]


@lru_cache(maxsize=1)
Expand Down
94 changes: 47 additions & 47 deletions tests/test_algorithm/test_counterfactual_transportability.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def test_7(self):
test7_in = W @ +X
test7_out = {W @ +X}
result = get_ancestors_of_counterfactual(event=test7_in, graph=figure_2a_graph)
logger.warning("In test_7: result = " + str(result))
logger.debug("In test_7: result = " + str(result))
self.assertTrue(variable in test7_out for variable in result)


Expand Down Expand Up @@ -506,7 +506,7 @@ def test_inconsistent_1(self):
"""
event = [(Y @ -X, -Y), (Y @ -X, +Y)]
result = simplify(event=event, graph=figure_2a_graph)
logger.warning("Result for test_inconsistent_1 is " + str(result))
logger.debug("Result for test_inconsistent_1 is " + str(result))
self.assertIsNone(simplify(event=event, graph=figure_2a_graph))

def test_inconsistent_2(self):
Expand Down Expand Up @@ -630,11 +630,11 @@ def test_line_2_1(self):
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(+Y)

logger.warning(
logger.debug(
"In test_line_2_1: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_1: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -652,11 +652,11 @@ def test_line_2_2(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
logger.warning(
logger.debug(
"In test_line_2_2: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_2: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -674,11 +674,11 @@ def test_line_2_10(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)
logger.warning(
logger.debug(
"In test_line_2_10: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_10: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -695,11 +695,11 @@ def test_line_2_3(self):
reflexive_variable_to_value_mappings[Y @ -Y].add(+Y)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_3: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_3: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -721,11 +721,11 @@ def test_line_2_11(self):
reflexive_variable_to_value_mappings[Y @ -Y].add(None)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_11: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_11: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -743,11 +743,11 @@ def test_line_2_4(self):

nonreflexive_variable_to_value_mappings = defaultdict(set)

logger.warning(
logger.debug(
"In test_line_2_4: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_4: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -764,11 +764,11 @@ def test_line_2_5(self):
reflexive_variable_to_value_mappings[Y @ +Y].add(-Y)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_5: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_5: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -786,11 +786,11 @@ def test_line_2_6(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -Z].add(+Y)
logger.warning(
logger.debug(
"In test_line_2_6: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_6: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -810,11 +810,11 @@ def test_line_2_7(self):

nonreflexive_variable_to_value_mappings = defaultdict(set)

logger.warning(
logger.debug(
"In test_line_2_7: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_7: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -841,11 +841,11 @@ def test_line_2_8(self):
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)

logger.warning(
logger.debug(
"In test_line_2_8: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_8: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -869,11 +869,11 @@ def test_line_2_9(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)

logger.warning(
logger.debug(
"In test_line_2_9: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_9: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand Down Expand Up @@ -1777,7 +1777,7 @@ def test_transport_district_intervening_on_parents_2(self):
]
domain_data = [({X}, PP[Pi1](W, X, Y, Z)), (set(), PP[Pi2](W, X, Y, Z))]
expected_result = PP[Pi2](X | Z) * PP[Pi2](Z)
logger.warning(
logger.debug(
"In test_transport_district_intervening_on_parents_2: expected_result is "
+ expected_result.to_latex()
)
Expand Down Expand Up @@ -2485,8 +2485,8 @@ def test_transport_unconditional_counterfactual_query_1(self):
domain_graphs=domain_graphs,
domain_data=domain_data,
)
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result, result_expr)

def test_transport_unconditional_counterfactual_query_2(self):
Expand Down Expand Up @@ -2550,8 +2550,8 @@ def test_transport_unconditional_counterfactual_query_3(self):
domain_graphs=domain_graphs,
domain_data=domain_data,
)
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result, result_expr)
self.assertCountEqual(event, result_event)
# Test sending variables with a value of None into this algorithm
Expand Down Expand Up @@ -3185,10 +3185,10 @@ def test_transport_conditional_counterfactual_query_1(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3239,10 +3239,10 @@ def test_transport_conditional_counterfactual_query_2(self):
domain_graphs=self.example_2_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3588,10 +3588,10 @@ def test_transport_conditional_counterfactual_query_5(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3621,10 +3621,10 @@ def test_transport_conditional_counterfactual_query_6(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -6036,8 +6036,8 @@ def test_merge_frozen_sets_linked_by_bidirectional_edges(self):
input_sets=test_2_inputs, graph=graph_2
)
expected_result_2 = {frozenset([W, Y]), frozenset([X]), frozenset([W1]), frozenset([R, Z])}
logger.warning(str(expected_result_2))
logger.warning(str(result_2))
logger.debug(str(expected_result_2))
logger.debug(str(result_2))
self.assertSetEqual(result_2, expected_result_2)
graph_3 = NxMixedGraph.from_edges(directed=[], undirected=[(W, X), (X, Y), (R, Z), (W1, Y)])
result_3 = _merge_frozen_sets_linked_by_bidirectional_edges(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_algorithm/test_falsification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test falsification of testable implications given a graph."""

import unittest
import warnings

import numpy as np
import pandas as pd
Expand All @@ -19,7 +20,8 @@ def test_discrete_graph_falsifications(self):
for method in [None, *get_conditional_independence_tests()]:
if method == "pearson":
continue
with self.subTest(method=method):
with self.subTest(method=method), warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
issues = get_graph_falsifications(
asia_example.graph, asia_example.data, method=method
)
Expand Down
Loading

0 comments on commit c45102c

Please sign in to comment.