diff --git a/.gitignore b/.gitignore index 7cdd5b60..1da51175 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ __pycache__/ .envrc .manage +.venv +.coverage +.vscode +cov.xml *.DS_Store* diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..a5a5b873 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --cov=aequitas diff --git a/requirements/cli.txt b/requirements/cli.txt index 7a63afc6..b43f09ad 100644 --- a/requirements/cli.txt +++ b/requirements/cli.txt @@ -1,5 +1,5 @@ SQLAlchemy>=1.1.1 tabulate==0.8.2 -xhtml2pdf==0.2.2 +xhtml2pdf==0.2.15 ohio>=0.2.0 markdown2==2.3.5 diff --git a/requirements/main.txt b/requirements/main.txt index 5dbd830b..cad49c51 100644 --- a/requirements/main.txt +++ b/requirements/main.txt @@ -12,4 +12,5 @@ fairlearn>=0.8.0 hydra-core>=1.3.0 validators>=0.22.0 hyperparameter-tuning>=0.3.1 -numpy==1.23.5 \ No newline at end of file +numpy==1.23.5 +fastparquet==2024.2.0 \ No newline at end of file diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 00000000..cf5ce115 --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +# Create a virtual environment +python3 -m venv .venv + +# Activate the virtual environment +source .venv/bin/activate + +# Install the required packages +pip install -r ./requirements/main.txt -r ./requirements/cli.txt -r ./requirements/webapp.txt +pip install -e . + +pip install pytest pytest-cov + +# Run the tests +coverage run -m pytest --cov-report xml:cov.xml --cov-report term + +# Deactivate the virtual environment +deactivate diff --git a/src/tests/README.rst b/src/tests/README.rst deleted file mode 100644 index 776e9eaa..00000000 --- a/src/tests/README.rst +++ /dev/null @@ -1,18 +0,0 @@ -# Testing using pytest - -These tests were developed using pytest version 3.5.1 - -- Install pytest - - pip install -U pytest - -- Run pytest -in a terminal - - pytest - -For debugging: - pytest -pdb - -For more output (e.g. to see which columns are being compared) - pytest -s \ No newline at end of file diff --git a/src/tests/test_bias_report.py b/src/tests/test_bias_report.py deleted file mode 100644 index 978dadb0..00000000 --- a/src/tests/test_bias_report.py +++ /dev/null @@ -1,120 +0,0 @@ -# Aequitas -# -# Test code for aequitas_audit - - -import os -import sys - -timeout = 60 - -sys.path.append(os.getcwd()) - -# Get the test files from the same directory as -# this file. 
-BASE_DIR = os.path.dirname(__file__) - -import pytest -import pandas as pd -import numpy as np -from aequitas_cli.aequitas_audit import audit -from aequitas_cli.utils.configs_loader import Configs - -def helper(input_filename, expected_filename, config_file): - ''' - - ''' - - input_filename = os.path.join(BASE_DIR, input_filename) - expected_df = pd.read_csv(os.path.join(BASE_DIR, expected_filename)) - - if config_file: - config_file = os.path.join(BASE_DIR, config_file) - - config = Configs.load_configs(config_file) - - test_df, _ = audit(pd.read_csv(os.path.join(BASE_DIR, input_filename)), config) - - # match expected_df columns - shared_columns = [c for c in expected_df.columns if c in test_df.columns] - - try: - expected_df = expected_df[shared_columns] - test_df = test_df[shared_columns] - combined_data = pd.merge(expected_df, test_df, on=['attribute_name', 'attribute_value']) - # subtract expected_df from test_df - except: - # collect output for - print('could not merge') - return (test_df, expected_df) - # see if close enough to 0 - - s = "" - EPS = 1e-6 - for col in shared_columns: - if col not in {'attribute_value', 'attribute_name'}: - print('testing {} ...'.format(col)) - - try: - # TypeError: numpy boolean subtract, the `-` operator, is - # deprecated, use the bitwise_xor, the `^` operator, or the - # logical_xor function instead. - # found online that casting as float64 will go around, but - # would like to get Jesse's take on best way to avoid issue. - if np.mean(combined_data[col + "_x"].astype("float64") - combined_data[col + "_y"].astype("float64")) > EPS: - exp_mean = np.mean(combined_data[col + "_x"]) - aeq_mean = np.mean(combined_data[col + "_y"]) - s += "{} fails: Expected {} on average, but aequitas returned {}\n".format(col, exp_mean, aeq_mean) - - pytest.fail(s) - - except: - if not all(combined_data[col + "_x"] == combined_data[col + "_y"]): - s += "{} fails: at least one entry was not the same between data sets\n".format(col) - pytest.fail(s) - - -# simplest tests -def test_group_class_1(): - # test that the results from group are as expected - return helper('test_1.csv', 'expected_output_group_test_1.csv', 'test_1.yaml') - - -def test_bias_class_1(): - # test that the results from bias are as expected (note it also tests group) - return helper('test_1.csv', 'expected_output_bias_test_1.csv', 'test_1.yaml') - - -def test_fairness_class_1(): - # test that the results from fairness are as expected (note it also tests bias and group) - return helper('test_1.csv', 'expected_output_fairness_test_1.csv', 'test_1.yaml') - - -def test_common_attributes_2(): - # test that aequitas deals with shared group attribute labels - return helper('test_2.csv', 'expected_output_test_2.csv', 'test_2.yaml') - - -def test_all_1_scores_3(): - return helper('test_3.csv', 'expected_output_test_3.csv', 'test_1.yaml') - - -def test_all_0_scores_4(): - return helper('test_4.csv', 'expected_output_test_4.csv', 'test_1.yaml') - - -def test_all_1_labels_5(): - return helper('test_5.csv', 'expected_output_test_5.csv', 'test_1.yaml') - - -def test_all_0_labels_6(): - return helper('test_6.csv', 'expected_output_test_6.csv', 'test_1.yaml') - - -def test_threshold_7(): - return helper('test_1.csv', 'expected_output_test_7.csv', 'test_3.yaml') - - -def test_threshold_8(): - return helper('test_1.csv', 'expected_output_test_8.csv', 'test_4.yaml') - diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..3ca5a175 --- /dev/null +++ b/tests/README.md @@ -0,0 +1,13 @@ +# Unit 
Testing in Aequitas
+
+To run the unit tests locally, run the following commands from the base folder of the project.
+
+The command below only needs to be run once, to give the testing script the necessary permission to run:
+
+`chmod +x run_tests.sh`
+
+The testing script itself can then be run with the following command:
+
+`./run_tests.sh`
+
+The script runs all defined tests in a virtual environment and prints the coverage report in the terminal. It also writes an XML version of the coverage report, `cov.xml`, to the base folder.
diff --git a/src/tests/expected_output_bias_test_1.csv b/tests/test_artifacts/test_bias_report/expected_output_bias_test_1.csv
similarity index 100%
rename from src/tests/expected_output_bias_test_1.csv
rename to tests/test_artifacts/test_bias_report/expected_output_bias_test_1.csv
diff --git a/src/tests/expected_output_fairness_test_1.csv b/tests/test_artifacts/test_bias_report/expected_output_fairness_test_1.csv
similarity index 100%
rename from src/tests/expected_output_fairness_test_1.csv
rename to tests/test_artifacts/test_bias_report/expected_output_fairness_test_1.csv
diff --git a/src/tests/expected_output_group_test_1.csv b/tests/test_artifacts/test_bias_report/expected_output_group_test_1.csv
similarity index 100%
rename from src/tests/expected_output_group_test_1.csv
rename to tests/test_artifacts/test_bias_report/expected_output_group_test_1.csv
diff --git a/src/tests/expected_output_test_1.csv b/tests/test_artifacts/test_bias_report/expected_output_test_1.csv
similarity index 100%
rename from src/tests/expected_output_test_1.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_1.csv
diff --git a/src/tests/expected_output_test_10.csv b/tests/test_artifacts/test_bias_report/expected_output_test_10.csv
similarity index 100%
rename from src/tests/expected_output_test_10.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_10.csv
diff --git a/src/tests/expected_output_test_2.csv b/tests/test_artifacts/test_bias_report/expected_output_test_2.csv
similarity index 100%
rename from src/tests/expected_output_test_2.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_2.csv
diff --git a/src/tests/expected_output_test_3.csv b/tests/test_artifacts/test_bias_report/expected_output_test_3.csv
similarity index 100%
rename from src/tests/expected_output_test_3.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_3.csv
diff --git a/src/tests/expected_output_test_4.csv b/tests/test_artifacts/test_bias_report/expected_output_test_4.csv
similarity index 100%
rename from src/tests/expected_output_test_4.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_4.csv
diff --git a/src/tests/expected_output_test_5.csv b/tests/test_artifacts/test_bias_report/expected_output_test_5.csv
similarity index 100%
rename from src/tests/expected_output_test_5.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_5.csv
diff --git a/src/tests/expected_output_test_6.csv b/tests/test_artifacts/test_bias_report/expected_output_test_6.csv
similarity index 100%
rename from src/tests/expected_output_test_6.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_6.csv
diff --git a/src/tests/expected_output_test_7.csv b/tests/test_artifacts/test_bias_report/expected_output_test_7.csv
similarity index 100%
rename from src/tests/expected_output_test_7.csv
rename to tests/test_artifacts/test_bias_report/expected_output_test_7.csv
diff --git 
a/src/tests/expected_output_test_8.csv b/tests/test_artifacts/test_bias_report/expected_output_test_8.csv similarity index 100% rename from src/tests/expected_output_test_8.csv rename to tests/test_artifacts/test_bias_report/expected_output_test_8.csv diff --git a/src/tests/test_1.csv b/tests/test_artifacts/test_bias_report/test_1.csv similarity index 100% rename from src/tests/test_1.csv rename to tests/test_artifacts/test_bias_report/test_1.csv diff --git a/src/tests/test_1.yaml b/tests/test_artifacts/test_bias_report/test_1.yaml similarity index 100% rename from src/tests/test_1.yaml rename to tests/test_artifacts/test_bias_report/test_1.yaml diff --git a/src/tests/test_10.csv b/tests/test_artifacts/test_bias_report/test_10.csv similarity index 100% rename from src/tests/test_10.csv rename to tests/test_artifacts/test_bias_report/test_10.csv diff --git a/src/tests/test_1_aequitas_20180501-114554.csv b/tests/test_artifacts/test_bias_report/test_1_aequitas_20180501-114554.csv similarity index 100% rename from src/tests/test_1_aequitas_20180501-114554.csv rename to tests/test_artifacts/test_bias_report/test_1_aequitas_20180501-114554.csv diff --git a/src/tests/test_2.csv b/tests/test_artifacts/test_bias_report/test_2.csv similarity index 100% rename from src/tests/test_2.csv rename to tests/test_artifacts/test_bias_report/test_2.csv diff --git a/src/tests/test_2.yaml b/tests/test_artifacts/test_bias_report/test_2.yaml similarity index 100% rename from src/tests/test_2.yaml rename to tests/test_artifacts/test_bias_report/test_2.yaml diff --git a/src/tests/test_3.csv b/tests/test_artifacts/test_bias_report/test_3.csv similarity index 100% rename from src/tests/test_3.csv rename to tests/test_artifacts/test_bias_report/test_3.csv diff --git a/src/tests/test_3.yaml b/tests/test_artifacts/test_bias_report/test_3.yaml similarity index 100% rename from src/tests/test_3.yaml rename to tests/test_artifacts/test_bias_report/test_3.yaml diff --git a/src/tests/test_4.csv b/tests/test_artifacts/test_bias_report/test_4.csv similarity index 100% rename from src/tests/test_4.csv rename to tests/test_artifacts/test_bias_report/test_4.csv diff --git a/src/tests/test_4.yaml b/tests/test_artifacts/test_bias_report/test_4.yaml similarity index 100% rename from src/tests/test_4.yaml rename to tests/test_artifacts/test_bias_report/test_4.yaml diff --git a/src/tests/test_5.csv b/tests/test_artifacts/test_bias_report/test_5.csv similarity index 100% rename from src/tests/test_5.csv rename to tests/test_artifacts/test_bias_report/test_5.csv diff --git a/src/tests/test_6.csv b/tests/test_artifacts/test_bias_report/test_6.csv similarity index 100% rename from src/tests/test_6.csv rename to tests/test_artifacts/test_bias_report/test_6.csv diff --git a/src/tests/test_7.csv b/tests/test_artifacts/test_bias_report/test_7.csv similarity index 100% rename from src/tests/test_7.csv rename to tests/test_artifacts/test_bias_report/test_7.csv diff --git a/src/tests/test_8.csv b/tests/test_artifacts/test_bias_report/test_8.csv similarity index 100% rename from src/tests/test_8.csv rename to tests/test_artifacts/test_bias_report/test_8.csv diff --git a/src/tests/test_9.csv b/tests/test_artifacts/test_bias_report/test_9.csv similarity index 100% rename from src/tests/test_9.csv rename to tests/test_artifacts/test_bias_report/test_9.csv diff --git a/tests/test_artifacts/test_generic/data.csv b/tests/test_artifacts/test_generic/data.csv new file mode 100644 index 00000000..547dc41c --- /dev/null +++ 
b/tests/test_artifacts/test_generic/data.csv @@ -0,0 +1,11 @@ +label,sensitive,feature1,feature2 +1,A,0.5,0.1 +0,B,0.2,0.4 +1,C,0.7,0.6 +0,B,0.3,0.8 +1,A,0.9,0.3 +1,A,0.5,0.5 +0,A,0.2,0.2 +1,C,0.7,0.7 +0,B,0.3,0.3 +1,A,0.9,0.9 diff --git a/tests/test_artifacts/test_generic/data.parquet b/tests/test_artifacts/test_generic/data.parquet new file mode 100644 index 00000000..ee60c7d2 Binary files /dev/null and b/tests/test_artifacts/test_generic/data.parquet differ diff --git a/tests/test_artifacts/test_generic/data_test.csv b/tests/test_artifacts/test_generic/data_test.csv new file mode 100644 index 00000000..a3e53a36 --- /dev/null +++ b/tests/test_artifacts/test_generic/data_test.csv @@ -0,0 +1,2 @@ +label,sensitive,feature1,feature2 +1,A,0.9,0.9 diff --git a/tests/test_artifacts/test_generic/data_train.csv b/tests/test_artifacts/test_generic/data_train.csv new file mode 100644 index 00000000..a34ba5e5 --- /dev/null +++ b/tests/test_artifacts/test_generic/data_train.csv @@ -0,0 +1,8 @@ +label,sensitive,feature1,feature2 +1,A,0.5,0.1 +0,B,0.2,0.4 +1,C,0.7,0.6 +0,B,0.3,0.8 +1,A,0.9,0.3 +1,A,0.5,0.5 +0,A,0.2,0.2 diff --git a/tests/test_artifacts/test_generic/data_validation.csv b/tests/test_artifacts/test_generic/data_validation.csv new file mode 100644 index 00000000..e07a1b29 --- /dev/null +++ b/tests/test_artifacts/test_generic/data_validation.csv @@ -0,0 +1,3 @@ +label,sensitive,feature1,feature2 +1,C,0.7,0.7 +0,B,0.3,0.3 diff --git a/tests/test_bias_report.py b/tests/test_bias_report.py new file mode 100644 index 00000000..8d54d61e --- /dev/null +++ b/tests/test_bias_report.py @@ -0,0 +1,170 @@ +# Aequitas +# +# Test code for aequitas_audit + + +import os +import sys + +timeout = 60 + +sys.path.append(os.getcwd()) + +# Get the test files from the same directory as +# this file. +BASE_DIR = os.path.dirname(__file__) + +import pytest +import pandas as pd +import numpy as np +from aequitas_cli.aequitas_audit import audit +from aequitas_cli.utils.configs_loader import Configs + + +def helper(input_filename, expected_filename, config_file): + """ """ + + input_filename = os.path.join(BASE_DIR, input_filename) + expected_df = pd.read_csv(os.path.join(BASE_DIR, expected_filename)) + + if config_file: + config_file = os.path.join(BASE_DIR, config_file) + + config = Configs.load_configs(config_file) + + test_df, _ = audit(pd.read_csv(os.path.join(BASE_DIR, input_filename)), config) + + # match expected_df columns + shared_columns = [c for c in expected_df.columns if c in test_df.columns] + + try: + expected_df = expected_df[shared_columns] + test_df = test_df[shared_columns] + combined_data = pd.merge( + expected_df, test_df, on=["attribute_name", "attribute_value"] + ) + # subtract expected_df from test_df + except: + # collect output for + print("could not merge") + return (test_df, expected_df) + # see if close enough to 0 + + s = "" + EPS = 1e-6 + for col in shared_columns: + if col not in {"attribute_value", "attribute_name"}: + print("testing {} ...".format(col)) + + try: + # TypeError: numpy boolean subtract, the `-` operator, is + # deprecated, use the bitwise_xor, the `^` operator, or the + # logical_xor function instead. + # found online that casting as float64 will go around, but + # would like to get Jesse's take on best way to avoid issue. 
+ if ( + np.mean( + combined_data[col + "_x"].astype("float64") + - combined_data[col + "_y"].astype("float64") + ) + > EPS + ): + exp_mean = np.mean(combined_data[col + "_x"]) + aeq_mean = np.mean(combined_data[col + "_y"]) + s += "{} fails: Expected {} on average, but aequitas returned {}\n".format( + col, exp_mean, aeq_mean + ) + + pytest.fail(s) + + except: + if not all(combined_data[col + "_x"] == combined_data[col + "_y"]): + s += "{} fails: at least one entry was not the same between data sets\n".format( + col + ) + pytest.fail(s) + + +# simplest tests +def test_group_class_1(): + # test that the results from group are as expected + return helper( + "test_artifacts/test_bias_report/test_1.csv", + "test_artifacts/test_bias_report/expected_output_group_test_1.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_bias_class_1(): + # test that the results from bias are as expected (note it also tests group) + return helper( + "test_artifacts/test_bias_report/test_1.csv", + "test_artifacts/test_bias_report/expected_output_bias_test_1.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_fairness_class_1(): + # test that the results from fairness are as expected (note it also tests bias and group) + return helper( + "test_artifacts/test_bias_report/test_1.csv", + "test_artifacts/test_bias_report/expected_output_fairness_test_1.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_common_attributes_2(): + # test that aequitas deals with shared group attribute labels + return helper( + "test_artifacts/test_bias_report/test_2.csv", + "test_artifacts/test_bias_report/expected_output_test_2.csv", + "test_artifacts/test_bias_report/test_2.yaml", + ) + + +def test_all_1_scores_3(): + return helper( + "test_artifacts/test_bias_report/test_3.csv", + "test_artifacts/test_bias_report/expected_output_test_3.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_all_0_scores_4(): + return helper( + "test_artifacts/test_bias_report/test_4.csv", + "test_artifacts/test_bias_report/expected_output_test_4.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_all_1_labels_5(): + return helper( + "test_artifacts/test_bias_report/test_5.csv", + "test_artifacts/test_bias_report/expected_output_test_5.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_all_0_labels_6(): + return helper( + "test_artifacts/test_bias_report/test_6.csv", + "test_artifacts/test_bias_report/expected_output_test_6.csv", + "test_artifacts/test_bias_report/test_1.yaml", + ) + + +def test_threshold_7(): + return helper( + "test_artifacts/test_bias_report/test_1.csv", + "test_artifacts/test_bias_report/expected_output_test_7.csv", + "test_artifacts/test_bias_report/test_3.yaml", + ) + + +def test_threshold_8(): + return helper( + "test_artifacts/test_bias_report/test_1.csv", + "test_artifacts/test_bias_report/expected_output_test_8.csv", + "test_artifacts/test_bias_report/test_4.yaml", + ) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 00000000..8653ba2f --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,249 @@ +import unittest +import os +import pandas as pd +from aequitas.flow.datasets.generic import GenericDataset + +BASE_DIR = os.path.dirname(__file__) + + +class TestGenericDataset(unittest.TestCase): + def setUp(self): + # Create a sample dataset for testing + self.df = pd.DataFrame( + { + "label": [1, 0, 1, 0, 1, 1, 0, 1, 0, 1], + "sensitive": ["A", "B", "C", "B", "A", "A", "A", 
"C", "B", "A"], + "feature1": [0.5, 0.2, 0.7, 0.3, 0.9, 0.5, 0.2, 0.7, 0.3, 0.9], + "feature2": [0.1, 0.4, 0.6, 0.8, 0.3, 0.5, 0.2, 0.7, 0.3, 0.9], + } + ) + + def test_load_data_from_dataframe(self): + dataset = GenericDataset( + label_column="label", sensitive_column="sensitive", df=self.df + ) + dataset.load_data() + self.assertEqual(len(dataset.data), len(self.df)) + self.assertTrue("label" in dataset.data.columns) + self.assertTrue("sensitive" in dataset.data.columns) + + def test_load_data_from_path_parquet(self): + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + extension="parquet", + dataset_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data.parquet" + ), + ) + dataset.load_data() + self.assertEqual(len(dataset.data), 10) + self.assertTrue("label" in dataset.data.columns) + self.assertTrue("sensitive" in dataset.data.columns) + + def test_load_data_from_path_csv(self): + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + extension="csv", + dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + ) + dataset.load_data() + self.assertEqual(len(dataset.data), 10) + self.assertTrue("label" in dataset.data.columns) + self.assertTrue("sensitive" in dataset.data.columns) + + def test_load_data_from_multiple_paths(self): + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + extension="csv", + train_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_train.csv" + ), + validation_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + ), + test_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_test.csv" + ), + ) + dataset.load_data() + dataset.create_splits() + self.assertEqual(len(dataset.train), 7) + self.assertEqual(len(dataset.validation), 2) + self.assertEqual(len(dataset.test), 1) + self.assertTrue("label" in dataset.data.columns) + self.assertTrue("sensitive" in dataset.data.columns) + + def test_create_splits_random(self): + dataset = GenericDataset( + label_column="label", sensitive_column="sensitive", df=self.df + ) + dataset.load_data() + dataset.create_splits() + self.assertEqual(len(dataset.train), 7) + self.assertEqual(len(dataset.validation), 2) + self.assertEqual(len(dataset.test), 1) + + def test_create_splits_column(self): + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + df=self.df, + split_type="column", + split_column="sensitive", + split_values={"train": ["A"], "validation": ["B"], "test": ["C"]}, + ) + dataset.create_splits() + self.assertEqual(len(dataset.train), 5) + self.assertEqual(len(dataset.validation), 3) + self.assertEqual(len(dataset.test), 2) + + def test_create_splits_column_from_path(self): + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + split_type="column", + extension="csv", + split_column="sensitive", + split_values={"train": ["A"], "validation": ["B"], "test": ["C"]}, + ) + dataset.load_data() + dataset.create_splits() + self.assertEqual(len(dataset.train), 5) + self.assertEqual(len(dataset.validation), 3) + self.assertEqual(len(dataset.test), 2) + + def test_all_paths_provided(self): + self.assertRaisesRegex( + ValueError, + "If single dataset path is passed, the other paths must be None.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + train_path=os.path.join( + BASE_DIR, 
"test_artifacts/test_generic/data_train.csv" + ), + validation_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + ), + test_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_test.csv" + ), + dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + ) + + def test_missing_paths(self): + self.assertRaisesRegex( + ValueError, + "If multiple dataset paths are passed, the single path must be" "`None`.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + train_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_train.csv" + ), + validation_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + ), + ) + + def test_invalid_path(self): + self.assertRaisesRegex( + ValueError, + "Invalid path:*", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + train_path=os.path.join(BASE_DIR, "test_artifacts/data_train.csv"), + validation_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + ), + test_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data_test.csv" + ), + ) + + def test_missing_split_key(self): + self.assertRaisesRegex( + ValueError, + "Missing key in passed splits: test", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + split_values={"train": 0.63, "validation": 0.37}, + ) + + def test_invalid_splits(self): + self.assertRaisesRegex( + ValueError, + "Invalid split sizes. Make sure the sum of proportions for all the" + " datasets is equal to or lower than 1.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + split_values={"train": 0.63, "validation": 0.37, "test": 0.2}, + ) + + def test_invalid_splits_warn(self): + with self.assertLogs("datasets.GenericDataset", level="WARN") as cm: + dataset = GenericDataset( + label_column="label", + sensitive_column="sensitive", + dataset_path=os.path.join( + BASE_DIR, "test_artifacts/test_generic/data.csv" + ), + split_values={"train": 0.3, "validation": 0.1, "test": 0.2}, + ) + self.assertEqual( + cm.output, + [ + "WARNING:datasets.GenericDataset:Using only 0.6000000000000001 of the dataset." + ], + ) + + def test_missing_splits_column(self): + self.assertRaisesRegex( + ValueError, + "Split column must be specified when using column split.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + df=self.df, + split_type="column", + split_values={"train": ["A"], "validation": ["B"], "test": ["C"]}, + ) + + def test_wrong_splits_column(self): + self.assertRaisesRegex( + ValueError, + "Split column must be a column in the dataset.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + df=self.df, + split_type="column", + split_column="test", + split_values={"train": ["A"], "validation": ["B"], "test": ["C"]}, + ) + + def test_wrong_splits_value(self): + self.assertRaisesRegex( + ValueError, + "Split values must be present in the split column.", + GenericDataset, + label_column="label", + sensitive_column="sensitive", + split_type="column", + df=self.df, + split_column="sensitive", + split_values={"train": ["D"], "validation": ["B"], "test": ["C"]}, + ) + + +if __name__ == "__main__": + unittest.main()