From 4a63b54fb6ec450ffcae010327e5261bd4558792 Mon Sep 17 00:00:00 2001
From: R-Palazzo
Date: Fri, 2 Aug 2024 14:52:45 +0200
Subject: [PATCH] add integration test

---
 .../integration/single_table/test_copulas.py  | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py
index a2553c90e..c94f1c7fd 100644
--- a/tests/integration/single_table/test_copulas.py
+++ b/tests/integration/single_table/test_copulas.py
@@ -469,3 +469,43 @@ def test_datetime_values_inside_real_data_range():
     assert check_in_synthetic.max() <= check_in_real.max()
     assert check_out_synthetic.min() >= check_out_real.min()
     assert check_out_synthetic.max() <= check_out_real.max()
+
+
+def test_support_new_pandas_dtypes():
+    """Check that nullable pandas dtypes survive a fit/sample round trip.
+
+    Fits a ``GaussianCopulaSynthesizer`` on a table whose columns use the nullable
+    extension dtypes, each holding one ``pd.NA``, and checks that the sampled table
+    keeps every dtype and that the float columns gain no extra decimal places.
+    """
+    # Setup
+    column_values = {
+        'Int8': [1, 2, -3, pd.NA],
+        'Int16': [1, 2, -3, pd.NA],
+        'Int32': [1, 2, -3, pd.NA],
+        'Int64': [1, 2, pd.NA, -3],
+        'Float32': [1.1, 2.2, 3.3, pd.NA],
+        'Float64': [1.113, 2.22, 3.3, pd.NA],
+    }
+    # Each column is named after its dtype, so the key doubles as the dtype string.
+    data = pd.DataFrame({
+        name: pd.Series(values, dtype=name) for name, values in column_values.items()
+    })
+    metadata = SingleTableMetadata().load_from_dict({
+        'columns': {
+            name: {'sdtype': 'numerical', 'computer_representation': name}
+            for name in column_values
+        }
+    })
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+
+    # Run
+    synthesizer.fit(data)
+    synthetic_data = synthesizer.sample(10)
+
+    # Assert
+    assert (synthetic_data.dtypes == data.dtypes).all()
+    # The training floats have 1 (Float32) and 3 (Float64) decimal places.
+    for column, decimals in (('Float32', 1), ('Float64', 3)):
+        sampled = synthetic_data[column]
+        assert (sampled == sampled.round(decimals)).all(skipna=True)