Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup testing #247

Merged
merged 5 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions src/y0/r_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
"""General utilities for :mod:`rpy2`."""

from __future__ import annotations

import logging
from collections.abc import Callable, Iterable
from functools import lru_cache, wraps
from typing import Any, TypeVar, cast

from rpy2.robjects.packages import importr, isinstalled
from rpy2.robjects.packages import InstalledPackage, InstalledSTPackage, importr, isinstalled
from rpy2.robjects.vectors import StrVector

from .dsl import Variable

__all__ = ["uses_r"]
__all__ = ["uses_r", "prepare_renv", "prepare_default_renv"]

logger = logging.getLogger(__name__)

Expand All @@ -26,10 +28,11 @@
Func = Callable[..., T]


def prepare_renv(requirements: Iterable[str]) -> None:
def prepare_renv(requirements: Iterable[str]) -> list[InstalledSTPackage | InstalledPackage]:
"""Ensure the given R packages are installed.

:param requirements: A list of R packages to ensure are installed
:param requirements: A list of R package names to ensure are installed
:returns: A list of R packages

.. seealso:: https://rpy2.github.io/doc/v3.4.x/html/introduction.html#installing-packages
"""
Expand All @@ -46,8 +49,7 @@ def prepare_renv(requirements: Iterable[str]) -> None:
logger.warning("installing R packages: %s", uninstalled_requirements)
utils.install_packages(StrVector(uninstalled_requirements))

for requirement in requirements:
importr(requirement)
return [importr(requirement) for requirement in requirements]


@lru_cache(maxsize=1)
Expand Down
94 changes: 47 additions & 47 deletions tests/test_algorithm/test_counterfactual_transportability.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def test_7(self):
test7_in = W @ +X
test7_out = {W @ +X}
result = get_ancestors_of_counterfactual(event=test7_in, graph=figure_2a_graph)
logger.warning("In test_7: result = " + str(result))
logger.debug("In test_7: result = " + str(result))
self.assertTrue(variable in test7_out for variable in result)


Expand Down Expand Up @@ -506,7 +506,7 @@ def test_inconsistent_1(self):
"""
event = [(Y @ -X, -Y), (Y @ -X, +Y)]
result = simplify(event=event, graph=figure_2a_graph)
logger.warning("Result for test_inconsistent_1 is " + str(result))
logger.debug("Result for test_inconsistent_1 is " + str(result))
self.assertIsNone(simplify(event=event, graph=figure_2a_graph))

def test_inconsistent_2(self):
Expand Down Expand Up @@ -630,11 +630,11 @@ def test_line_2_1(self):
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(+Y)

logger.warning(
logger.debug(
"In test_line_2_1: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_1: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -652,11 +652,11 @@ def test_line_2_2(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
logger.warning(
logger.debug(
"In test_line_2_2: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_2: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -674,11 +674,11 @@ def test_line_2_10(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)
logger.warning(
logger.debug(
"In test_line_2_10: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_10: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -695,11 +695,11 @@ def test_line_2_3(self):
reflexive_variable_to_value_mappings[Y @ -Y].add(+Y)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_3: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_3: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -721,11 +721,11 @@ def test_line_2_11(self):
reflexive_variable_to_value_mappings[Y @ -Y].add(None)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_11: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_11: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -743,11 +743,11 @@ def test_line_2_4(self):

nonreflexive_variable_to_value_mappings = defaultdict(set)

logger.warning(
logger.debug(
"In test_line_2_4: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_4: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -764,11 +764,11 @@ def test_line_2_5(self):
reflexive_variable_to_value_mappings[Y @ +Y].add(-Y)

nonreflexive_variable_to_value_mappings = defaultdict(set)
logger.warning(
logger.debug(
"In test_line_2_5: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_5: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -786,11 +786,11 @@ def test_line_2_6(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -Z].add(+Y)
logger.warning(
logger.debug(
"In test_line_2_6: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_6: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -810,11 +810,11 @@ def test_line_2_7(self):

nonreflexive_variable_to_value_mappings = defaultdict(set)

logger.warning(
logger.debug(
"In test_line_2_7: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_7: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -841,11 +841,11 @@ def test_line_2_8(self):
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)
nonreflexive_variable_to_value_mappings[Y @ -X].add(None)

logger.warning(
logger.debug(
"In test_line_2_8: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_8: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand All @@ -869,11 +869,11 @@ def test_line_2_9(self):
nonreflexive_variable_to_value_mappings = defaultdict(set)
nonreflexive_variable_to_value_mappings[Y @ -X].add(-Y)

logger.warning(
logger.debug(
"In test_line_2_9: nonreflexive_variable_to_value_mappings = "
+ str(nonreflexive_variable_to_value_mappings)
)
logger.warning(
logger.debug(
"In test_line_2_9: reflexive_variable_to_value_mappings = "
+ str(reflexive_variable_to_value_mappings)
)
Expand Down Expand Up @@ -1777,7 +1777,7 @@ def test_transport_district_intervening_on_parents_2(self):
]
domain_data = [({X}, PP[Pi1](W, X, Y, Z)), (set(), PP[Pi2](W, X, Y, Z))]
expected_result = PP[Pi2](X | Z) * PP[Pi2](Z)
logger.warning(
logger.debug(
"In test_transport_district_intervening_on_parents_2: expected_result is "
+ expected_result.to_latex()
)
Expand Down Expand Up @@ -2485,8 +2485,8 @@ def test_transport_unconditional_counterfactual_query_1(self):
domain_graphs=domain_graphs,
domain_data=domain_data,
)
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result, result_expr)

def test_transport_unconditional_counterfactual_query_2(self):
Expand Down Expand Up @@ -2550,8 +2550,8 @@ def test_transport_unconditional_counterfactual_query_3(self):
domain_graphs=domain_graphs,
domain_data=domain_data,
)
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result, result_expr)
self.assertCountEqual(event, result_event)
# Test sending variables with a value of None into this algorithm
Expand Down Expand Up @@ -3185,10 +3185,10 @@ def test_transport_conditional_counterfactual_query_1(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3239,10 +3239,10 @@ def test_transport_conditional_counterfactual_query_2(self):
domain_graphs=self.example_2_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3588,10 +3588,10 @@ def test_transport_conditional_counterfactual_query_5(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -3621,10 +3621,10 @@ def test_transport_conditional_counterfactual_query_6(self):
domain_graphs=self.example_1_domain_graphs,
domain_data=domain_data,
)
logger.warning("expected_result_expr = " + expected_result_expr.to_latex())
logger.warning("expected_result_event = " + str(expected_result_event))
logger.warning("Result_expr = " + result_expr.to_latex())
logger.warning("Result_event = " + str(result_event))
logger.debug("expected_result_expr = " + expected_result_expr.to_latex())
logger.debug("expected_result_event = " + str(expected_result_event))
logger.debug("Result_expr = " + result_expr.to_latex())
logger.debug("Result_event = " + str(result_event))
self.assert_expr_equal(expected_result_expr, result_expr)
self.assertCountEqual(expected_result_event, result_event)

Expand Down Expand Up @@ -6036,8 +6036,8 @@ def test_merge_frozen_sets_linked_by_bidirectional_edges(self):
input_sets=test_2_inputs, graph=graph_2
)
expected_result_2 = {frozenset([W, Y]), frozenset([X]), frozenset([W1]), frozenset([R, Z])}
logger.warning(str(expected_result_2))
logger.warning(str(result_2))
logger.debug(str(expected_result_2))
logger.debug(str(result_2))
self.assertSetEqual(result_2, expected_result_2)
graph_3 = NxMixedGraph.from_edges(directed=[], undirected=[(W, X), (X, Y), (R, Z), (W1, Y)])
result_3 = _merge_frozen_sets_linked_by_bidirectional_edges(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_algorithm/test_falsification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test falsification of testable implications given a graph."""

import unittest
import warnings

import numpy as np
import pandas as pd
Expand All @@ -19,7 +20,8 @@ def test_discrete_graph_falsifications(self):
for method in [None, *get_conditional_independence_tests()]:
if method == "pearson":
continue
with self.subTest(method=method):
with self.subTest(method=method), warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=FutureWarning)
issues = get_graph_falsifications(
asia_example.graph, asia_example.data, method=method
)
Expand Down
Loading
Loading