diff --git a/.github/workflows/dtypes_benchmark.yml b/.github/workflows/dtypes_benchmark.yml new file mode 100644 index 000000000..18355116c --- /dev/null +++ b/.github/workflows/dtypes_benchmark.yml @@ -0,0 +1,80 @@ +name: Data Types Benchmark + +on: + push: + branches: + - main + +jobs: + run_dtypes_benchmark: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install invoke .[test] + + - name: Create folder and JSON file + run: | + mkdir -p results + touch results/${{ matrix.python-version }}.json + + # Run the benchmarking + - name: Benchmark Data Types + env: + PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }} + run: | + invoke benchmark-dtypes + + # Upload the json files as artifacts + - name: Upload artifacts + uses: actions/upload-artifact@v3 + with: + name: results-${{ matrix.python-version }} + path: results/*.json + + generate_dtypes_report: + runs-on: ubuntu-latest + needs: run_dtypes_benchmark + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + # Set up Python 3.10 + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: '3.10' + + - name: Install dependencies for report + run: | + python -m pip install --upgrade pip + python -m pip install .[test] + + # Download the artifacts + - name: Download artifacts + uses: actions/download-artifact@v3 + with: + path: results/ + + # Generate the report + - name: Generate the report + env: + PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }} + SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} + + run: python -m tests.benchmark.utils diff --git a/.github/workflows/install.yaml b/.github/workflows/install.yaml index 3970f1b73..106d19272 100644 --- a/.github/workflows/install.yaml +++ b/.github/workflows/install.yaml @@ -5,6 +5,11 @@ on: push: branches: - main + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: install: name: ${{ matrix.python_version }} install diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 717b7fa3a..8ab8aecb0 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -5,6 +5,10 @@ on: pull_request: types: [opened, reopened] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: integration: runs-on: ${{ matrix.os }} @@ -29,4 +33,9 @@ jobs: python -m pip install --upgrade pip python -m pip install invoke .[test] - name: Run integration tests - run: invoke integration + env: + PYDRIVE_CREDENTIALS: ${{ secrets.PYDRIVE_CREDENTIALS }} + + run: | + invoke integration + invoke benchmark-dtypes diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml index 10ac72666..de032a13b 100644 --- a/.github/workflows/minimum.yml +++ b/.github/workflows/minimum.yml @@ -5,6 +5,10 @@ on: pull_request: types: [opened, reopened] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: minimum: runs-on: ${{ matrix.os }} diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml index 55883f20a..3b4217b1b 100644 --- a/.github/workflows/unit.yml +++ b/.github/workflows/unit.yml @@ -5,6 +5,10 @@ on: pull_request: types: [opened, reopened] +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + jobs: unit: runs-on: ${{ matrix.os }} diff --git a/HISTORY.md b/HISTORY.md index 482b05c84..43978481a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,6 +1,16 @@ # Release Notes -### v1.16.1 - 2024-08-27 +### v1.16.2 - 2024-09-25 + +### New Features + +* Supported data types benchmark - Issue [#2200](https://github.com/sdv-dev/SDV/issues/2200) by @pvk-developer + +### Bugs Fixed + +* The `_validate_circular_relationships` method may fail to detect circular relationships - Issue [#2205](https://github.com/sdv-dev/SDV/issues/2205) by @fealho + +## v1.16.1 - 2024-08-27 ### Internal diff --git a/latest_requirements.txt b/latest_requirements.txt index 15f5b8299..837722ec3 100644 --- a/latest_requirements.txt +++ b/latest_requirements.txt @@ -4,8 +4,8 @@ ctgan==0.10.1 deepecho==0.6.0 graphviz==0.20.3 numpy==1.26.4 -pandas==2.2.2 -platformdirs==4.2.2 -rdt==1.12.3 +pandas==2.2.3 +platformdirs==4.3.6 +rdt==1.12.4 sdmetrics==0.15.1 tqdm==4.66.5 diff --git a/pyproject.toml b/pyproject.toml index a5133e024..b68fa75b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ 'ctgan>=0.10.0', 'deepecho>=0.6.0', 'rdt>=1.12.3', - 'sdmetrics>=0.14.0', + 'sdmetrics>=0.16.0', 'platformdirs>=4.0', 'pyyaml>=6.0.1', ] @@ -62,6 +62,10 @@ test = [ 'rundoc>=0.4.3,<0.5', 'pytest-runner >= 2.11.1', 'tomli>=2.0.0,<3', + 'pydrive', + 'pyarrow', + 'gitpython', + 'slack-sdk>=3.23,<4.0', ] pomegranate = ['pomegranate>=0.14.3,<0.15'] dev = [ @@ -132,7 +136,7 @@ namespaces = false version = {attr = 'sdv.__version__'} [tool.bumpversion] -current_version = "1.16.1" +current_version = "1.16.2.dev1" parse = '(?P\d+)\.(?P\d+)\.(?P\d+)(\.(?P[a-z]+)(?P\d+))?' serialize = [ '{major}.{minor}.{patch}.{release}{candidate}', @@ -181,6 +185,7 @@ exclude = [ ".tox", ".git", "__pycache__", + "*.ipynb", ".ipynb_checkpoints", "tasks.py", ] diff --git a/sdv/__init__.py b/sdv/__init__.py index b265a2855..931dd4b69 100644 --- a/sdv/__init__.py +++ b/sdv/__init__.py @@ -6,7 +6,7 @@ __author__ = 'DataCebo, Inc.' __email__ = 'info@sdv.dev' -__version__ = '1.16.1' +__version__ = '1.16.2.dev1' import sys diff --git a/sdv/_utils.py b/sdv/_utils.py index 3ec466537..40464b76d 100644 --- a/sdv/_utils.py +++ b/sdv/_utils.py @@ -9,6 +9,7 @@ from pathlib import Path import pandas as pd +from pandas.api.types import is_float, is_integer from pandas.core.tools.datetimes import _guess_datetime_format_for_array from rdt.transformers.utils import _GENERATORS @@ -81,6 +82,7 @@ def _is_datetime_type(value): bool(_get_datetime_format([value])) or isinstance(value, pd.Timestamp) or isinstance(value, datetime) + or (isinstance(value, str) and pd.notna(pd.to_datetime(value, errors='coerce'))) ): return False @@ -439,3 +441,11 @@ def get_possible_chars(regex, num_subpatterns=None): possible_chars += _get_chars_for_option(option, params) return possible_chars + + +def _is_numerical(value): + """Determine if the input is a numerical type or not.""" + try: + return is_integer(value) or is_float(value) + except Exception: + return False diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 04ede421b..a2c89b00b 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -36,7 +36,7 @@ import numpy as np import pandas as pd -from sdv._utils import _convert_to_timedelta, _create_unique_name, _is_datetime_type +from sdv._utils import _convert_to_timedelta, _create_unique_name, _is_datetime_type, _is_numerical from sdv.constraints.base import Constraint from sdv.constraints.errors import ( AggregateConstraintsError, @@ -604,7 +604,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): sdtype = metadata.columns.get(column_name, {}).get('sdtype') value = kwargs.get('value') if sdtype == 'numerical': - if not isinstance(value, (int, float)): + if not _is_numerical(value): raise ConstraintMetadataError("'value' must be an int or float.") elif sdtype == 'datetime': @@ -632,7 +632,7 @@ def _validate_init_inputs(column_name, value, relation): if relation not in ['>', '>=', '<', '<=']: raise ValueError('`relation` must be one of the following: `>`, `>=`, `<`, `<=`') - if not (isinstance(value, (int, float)) or value_is_datetime): + if not (_is_numerical(value) or value_is_datetime): raise ValueError('`value` must be a number or a string that represents a datetime.') if value_is_datetime and not isinstance(value, str): @@ -1071,9 +1071,7 @@ def _validate_init_inputs(low_value, high_value): if values_are_datetimes and not values_are_strings: raise ValueError('Datetime must be represented as a string.') - values_are_numerical = bool( - isinstance(low_value, (int, float)) and isinstance(high_value, (int, float)) - ) + values_are_numerical = bool(_is_numerical(low_value) and _is_numerical(high_value)) if not (values_are_numerical or values_are_datetimes): raise ValueError( '``low_value`` and ``high_value`` must be a number or a string that ' @@ -1092,7 +1090,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): high_value = kwargs.get('high_value') low_value = kwargs.get('low_value') if sdtype == 'numerical': - if not isinstance(high_value, (int, float)) or not isinstance(low_value, (int, float)): + if not _is_numerical(high_value) or not _is_numerical(low_value): raise ConstraintMetadataError( "Both 'high_value' and 'low_value' must be ints or floats" ) @@ -1187,11 +1185,7 @@ def is_valid(self, table_data): self._operator(data, self._high_value), pd.isna(self._high_value), ) - - return np.logical_or( - np.logical_and(satisfy_low_bound, satisfy_high_bound), - pd.isna(data), - ) + return (satisfy_low_bound & satisfy_high_bound) | pd.isna(data) def _transform(self, table_data): """Transform the table data. @@ -1250,7 +1244,7 @@ def _reverse_transform(self, table_data): table_data[self._column_name] = data.round().astype(self._dtype) else: - table_data[self._column_name] = data.astype(self._dtype) + table_data[self._column_name] = data.astype(self._dtype, errors='ignore') table_data = table_data.drop(self._transformed_column, axis=1) return table_data diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 7c6e2a17e..61be08db8 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -126,10 +126,10 @@ def _validate_relationship_sdtypes( ) def _validate_circular_relationships( - self, parent, children=None, parents=None, child_map=None, errors=None + self, parent, children=None, visited=None, child_map=None, errors=None ): """Validate that there is no circular relationship in the metadata.""" - parents = set() if parents is None else parents + visited = set() if visited is None else visited if children is None: children = child_map[parent] @@ -137,15 +137,15 @@ def _validate_circular_relationships( errors.append(parent) for child in children: - if child in parents: - break + if child in visited: + continue - parents.add(child) + visited.add(child) self._validate_circular_relationships( parent, children=child_map.get(child, set()), child_map=child_map, - parents=parents, + visited=visited, errors=errors, ) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 5565428f9..83c80ae25 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -70,9 +70,10 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) + synthesizer_parameters = {'locales': self.locales} + synthesizer_parameters.update(self._table_parameters.get(table_name, {})) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, locales=self.locales, **synthesizer_parameters + metadata=table_metadata, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name @@ -340,6 +341,10 @@ def _store_and_convert_original_cols(self, data): data[table] = dataframe return list_of_changed_tables + def _transform_helper(self, data): + """Stub method for transforming data patterns.""" + return data + def preprocess(self, data): """Transform the raw data to numerical space. @@ -353,6 +358,7 @@ def preprocess(self, data): """ list_of_changed_tables = self._store_and_convert_original_cols(data) + data = self._transform_helper(data) self.validate(data) if self._fitted: warnings.warn( @@ -471,6 +477,10 @@ def reset_sampling(self): def _sample(self, scale): raise NotImplementedError() + def _reverse_transform_helper(self, sampled_data): + """Stub method for reverse transforming data patterns.""" + return sampled_data + def sample(self, scale=1.0): """Generate synthetic data for the entire dataset. @@ -495,6 +505,7 @@ def sample(self, scale=1.0): with self._set_temp_numpy_seed(), disable_single_table_logger(): sampled_data = self._sample(scale=scale) + sampled_data = self._reverse_transform_helper(sampled_data) total_rows = 0 total_columns = 0 diff --git a/static_code_analysis.txt b/static_code_analysis.txt index 39180cd7d..6841439dc 100644 --- a/static_code_analysis.txt +++ b/static_code_analysis.txt @@ -1,4 +1,4 @@ -Run started:2024-08-23 00:37:57.536879 +Run started:2024-08-27 21:12:42.890265 Test results: >> Issue: [B105:hardcoded_password_string] Possible hardcoded password: '# Release Notes @@ -69,7 +69,7 @@ Test results: -------------------------------------------------- Code scanned: - Total lines of code: 12451 + Total lines of code: 12453 Total lines skipped (#nosec): 0 Total potential issues skipped due to specifically being disabled (e.g., #nosec BXXX): 0 diff --git a/tasks.py b/tasks.py index f7a12f619..a98d438a1 100644 --- a/tasks.py +++ b/tasks.py @@ -34,6 +34,11 @@ def integration(c): c.run('python -m pytest ./tests/integration --reruns 3') +@task +def benchmark_dtypes(c): + c.run('python -m pytest ./tests/benchmark/supported_dtypes_benchmark.py') + + def _get_minimum_versions(dependencies, python_version): min_versions = {} for dependency in dependencies: diff --git a/tests/_external/__init__.py b/tests/_external/__init__.py new file mode 100644 index 000000000..dee965eb8 --- /dev/null +++ b/tests/_external/__init__.py @@ -0,0 +1 @@ +"""External utility functions.""" diff --git a/tests/_external/gdrive_utils.py b/tests/_external/gdrive_utils.py new file mode 100644 index 000000000..332787b77 --- /dev/null +++ b/tests/_external/gdrive_utils.py @@ -0,0 +1,140 @@ +"""Google Drive utils.""" + +import io +import json +import os +import pathlib +import tempfile +from datetime import date + +import git +import pandas as pd +import yaml +from pydrive.auth import GoogleAuth +from pydrive.drive import GoogleDrive + +PYDRIVE_CREDENTIALS = 'PYDRIVE_CREDENTIALS' + + +def _generate_filename(): + """Generate a filename with today's date and the commit id.""" + repo = git.Repo(search_parent_directories=True) + commit_id = repo.head.object.hexsha + today = str(date.today()) + return f'{today}-{commit_id}.xlsx' + + +def _get_drive_client(): + tmp_credentials = os.getenv(PYDRIVE_CREDENTIALS) + if not tmp_credentials: + gauth = GoogleAuth() + gauth.LocalWebserverAuth() + else: + with tempfile.TemporaryDirectory() as tempdir: + credentials_file_path = pathlib.Path(tempdir) / 'credentials.json' + credentials_file_path.write_text(tmp_credentials) + + credentials = json.loads(tmp_credentials) + + settings = { + 'client_config_backend': 'settings', + 'client_config': { + 'client_id': credentials['client_id'], + 'client_secret': credentials['client_secret'], + }, + 'save_credentials': True, + 'save_credentials_backend': 'file', + 'save_credentials_file': str(credentials_file_path), + 'get_refresh_token': True, + } + settings_file = pathlib.Path(tempdir) / 'settings.yaml' + settings_file.write_text(yaml.safe_dump(settings)) + + gauth = GoogleAuth(str(settings_file)) + gauth.LocalWebserverAuth() + + return GoogleDrive(gauth) + + +def get_latest_file(folder_id): + """Get the latest file from the given Google Drive folder. + + Args: + folder (str): + The string Google Drive folder ID. + """ + drive = _get_drive_client() + drive_query = drive.ListFile({ + 'q': f"'{folder_id}' in parents and trashed=False", + 'orderBy': 'modifiedDate desc', + 'maxResults': 1, + }) + file_list = drive_query.GetList() + if len(file_list) > 0: + return file_list[0] + + +def read_excel(file_id): + """Read a file as an XLSX from Google Drive. + + Args: + file_id (str): + The ID of the file to load. + + Returns: + pd.DataFrame or dict[pd.DataFrame]: + A DataFrame containing the body of file if single sheet else dict of DataFrames one for + each sheet + + """ + client = _get_drive_client() + drive_file = client.CreateFile({'id': file_id}) + xlsx_mime = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' + drive_file.FetchContent(mimetype=xlsx_mime) + return pd.read_excel(drive_file.content, sheet_name=None) + + +def _set_column_width(writer, results, sheet_name): + for column in results: + column_width = max(results[column].astype(str).map(len).max(), len(column)) + col_idx = results.columns.get_loc(column) + writer.sheets[sheet_name].set_column(col_idx, col_idx, column_width + 2) + + +def save_to_gdrive(output_folder, results, output_filename=None): + """Save a ``DataFrame`` to google drive folder as ``xlsx`` (spreadsheet). + + Given the output folder id (google drive folder id), store the given ``results`` as + ``spreadsheet``. If not ``output_filename`` is given, the spreadsheet is saved with the + current date and commit as name. + + Args: + output_folder (str): + String representing a google drive folder id. + results (pd.DataFrame or dict[pd.DataFrame]): + Dataframe to be stored as ``xlsx``, or dictionary mapping sheet names to dataframes for + storage in one ``xlsx`` file. + output_filename (str, optional): + String representing the filename to be used for the results spreadsheet. If None, + uses to the current date and commit as the name. Defaults to None. + + Returns: + str: + Google drive file id of uploaded file. + """ + if not output_filename: + output_filename = _generate_filename() + + output = io.BytesIO() + + with pd.ExcelWriter(output, engine='xlsxwriter') as writer: # pylint: disable=E0110 + for sheet_name, data in results.items(): + data.to_excel(writer, sheet_name=sheet_name, index=False) + _set_column_width(writer, data, sheet_name) + + file_config = {'title': output_filename, 'parents': [{'id': output_folder}]} + drive = _get_drive_client() + drive_file = drive.CreateFile(file_config) + drive_file.content = output + drive_file.Upload({'convert': True}) + return drive_file['id'] diff --git a/tests/_external/slack_utils.py b/tests/_external/slack_utils.py new file mode 100644 index 000000000..abc6c9648 --- /dev/null +++ b/tests/_external/slack_utils.py @@ -0,0 +1,65 @@ +"""Utility functions for Slack integration.""" + +import os + +from slack_sdk import WebClient + + +def _get_slack_client(): + """Create an authenticated Slack client. + + Returns: + WebClient: + An authenticated Slack WebClient instance. + """ + token = os.getenv('SLACK_TOKEN') + client = WebClient(token=token) + return client + + +def post_slack_message(channel, text): + """Post a message to a Slack channel. + + Args: + channel (str): + The name of the channel to post to. + text (str): + The message to send to the channel. + + Returns: + SlackResponse: + Response from Slack API call + """ + client = _get_slack_client() + response = client.chat_postMessage(channel=channel, text=text) + if not response['ok']: + error = response.get('error', 'unknown_error') + msg = f'{error} occured trying to post message to {channel}' + raise RuntimeError(msg) + + return response + + +def post_slack_message_in_thread(channel, text, thread_ts): + """Post a message as a threaded reply in a Slack channel. + + Args: + channel (str): + The name of the channel to post to. + text (str): + The message to send as a reply in the thread. + thread_ts (str): + The timestamp of the message that starts the thread. + + Returns: + SlackResponse: + Response from Slack API call. + """ + client = _get_slack_client() + response = client.chat_postMessage(channel=channel, text=text, thread_ts=thread_ts) + if not response['ok']: + error = response.get('error', 'unknown_error') + msg = f'{error} occurred trying to post threaded message to {channel}' + raise RuntimeError(msg) + + return response diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py new file mode 100644 index 000000000..c93ceaf9f --- /dev/null +++ b/tests/benchmark/__init__.py @@ -0,0 +1 @@ +"""SDV benchmarking module.""" diff --git a/tests/benchmark/excluded_tests.py b/tests/benchmark/excluded_tests.py new file mode 100644 index 000000000..78e3e2458 --- /dev/null +++ b/tests/benchmark/excluded_tests.py @@ -0,0 +1,123 @@ +"""Excluded tests from constraints due to hard crashing from NumPy or Pandas.""" + +EXCLUDED_CONSTRAINT_TESTS = [ + ('numerical', 'pd.boolean', 'FixedIncrements'), + ('numerical', 'pd.object', 'Positive'), + ('numerical', 'pd.object', 'Negative'), + ('numerical', 'pd.object', 'ScalarInequality'), + ('numerical', 'pd.object', 'ScalarRange'), + ('numerical', 'pd.string', 'Positive'), + ('numerical', 'pd.string', 'Negative'), + ('numerical', 'pd.string', 'ScalarInequality'), + ('numerical', 'pd.category', 'Positive'), + ('numerical', 'pd.category', 'Negative'), + ('numerical', 'pd.category', 'ScalarInequality'), + ('numerical', 'pd.category', 'ScalarRange'), + ('numerical', 'pd.datetime64', 'Positive'), + ('numerical', 'pd.datetime64', 'Negative'), + ('numerical', 'pd.datetime64', 'ScalarInequality'), + ('numerical', 'pd.timedelta64', 'Positive'), + ('numerical', 'pd.timedelta64', 'Negative'), + ('numerical', 'pd.timedelta64', 'ScalarInequality'), + ('numerical', 'pd.Period', 'Positive'), + ('numerical', 'pd.Period', 'Negative'), + ('numerical', 'pd.Period', 'ScalarInequality'), + ('numerical', 'pd.Period', 'FixedIncrements'), + ('numerical', 'np.object', 'Positive'), + ('numerical', 'np.object', 'Negative'), + ('numerical', 'np.object', 'ScalarInequality'), + ('numerical', 'np.string', 'Positive'), + ('numerical', 'np.string', 'Negative'), + ('numerical', 'np.string', 'ScalarInequality'), + ('numerical', 'np.unicode', 'Positive'), + ('numerical', 'np.unicode', 'Negative'), + ('numerical', 'np.unicode', 'ScalarInequality'), + ('numerical', 'np.datetime64', 'Positive'), + ('numerical', 'np.datetime64', 'Negative'), + ('numerical', 'np.datetime64', 'ScalarInequality'), + ('numerical', 'np.timedelta64', 'Positive'), + ('numerical', 'np.timedelta64', 'Negative'), + ('numerical', 'np.timedelta64', 'ScalarInequality'), + ('numerical', 'pa.string', 'Positive'), + ('numerical', 'pa.string', 'Negative'), + ('numerical', 'pa.string', 'ScalarInequality'), + ('numerical', 'pa.utf8', 'Positive'), + ('numerical', 'pa.utf8', 'Negative'), + ('numerical', 'pa.utf8', 'ScalarInequality'), + ('numerical', 'pa.binary', 'Positive'), + ('numerical', 'pa.binary', 'Negative'), + ('numerical', 'pa.binary', 'ScalarInequality'), + ('numerical', 'pa.binary', 'FixedIncrements'), + ('numerical', 'pa.large_binary', 'Positive'), + ('numerical', 'pa.large_binary', 'Negative'), + ('numerical', 'pa.large_binary', 'ScalarInequality'), + ('numerical', 'pa.large_binary', 'FixedIncrements'), + ('numerical', 'pa.large_string', 'Positive'), + ('numerical', 'pa.large_string', 'Negative'), + ('numerical', 'pa.large_string', 'ScalarInequality'), + ('numerical', 'pa.date32', 'Positive'), + ('numerical', 'pa.date32', 'Negative'), + ('numerical', 'pa.date32', 'ScalarInequality'), + ('numerical', 'pa.date64', 'Positive'), + ('numerical', 'pa.date64', 'Negative'), + ('numerical', 'pa.date64', 'ScalarInequality'), + ('numerical', 'pa.timestamp', 'Positive'), + ('numerical', 'pa.timestamp', 'Negative'), + ('numerical', 'pa.timestamp', 'ScalarInequality'), + ('numerical', 'pa.duration', 'Positive'), + ('numerical', 'pa.duration', 'Negative'), + ('numerical', 'pa.duration', 'ScalarInequality'), + ('numerical', 'pa.time32', 'Positive'), + ('numerical', 'pa.time32', 'Negative'), + ('numerical', 'pa.time32', 'ScalarInequality'), + ('numerical', 'pa.time64', 'Positive'), + ('numerical', 'pa.time64', 'Negative'), + ('numerical', 'pa.time64', 'ScalarInequality'), + ('numerical', 'pa.binary_view', 'Positive'), + ('numerical', 'pa.binary_view', 'Negative'), + ('numerical', 'pa.binary_view', 'ScalarInequality'), + ('numerical', 'pa.binary_view', 'FixedIncrements'), + ('numerical', 'pa.string_view', 'Positive'), + ('numerical', 'pa.string_view', 'Negative'), + ('numerical', 'pa.string_view', 'ScalarInequality'), + ('datetime', 'pd.object', 'ScalarRange'), + ('datetime', 'pd.category', 'ScalarRange'), + ('numerical', 'pd.category', 'Inequality'), + ('numerical', 'pd.category', 'Range'), + ('numerical', 'pd.datetime64', 'Inequality'), + ('numerical', 'pd.datetime64', 'Range'), + ('numerical', 'pd.Period', 'Inequality'), + ('numerical', 'pd.Period', 'Range'), + ('numerical', 'np.datetime64', 'Inequality'), + ('numerical', 'np.datetime64', 'Range'), + ('numerical', 'pa.bool', 'Inequality'), + ('numerical', 'pa.bool', 'Range'), + ('numerical', 'pa.large_binary', 'Inequality'), + ('numerical', 'pa.large_binary', 'Range'), + ('numerical', 'pa.date32', 'Inequality'), + ('numerical', 'pa.date32', 'Range'), + ('numerical', 'pa.date64', 'Inequality'), + ('numerical', 'pa.date64', 'Range'), + ('numerical', 'pa.timestamp', 'Inequality'), + ('numerical', 'pa.timestamp', 'Range'), + ('numerical', 'pa.time32', 'Inequality'), + ('numerical', 'pa.time32', 'Range'), + ('numerical', 'pa.time64', 'Inequality'), + ('numerical', 'pa.time64', 'Range'), + ('numerical', 'pa.string', 'FixedIncrements'), + ('numerical', 'pa.utf8', 'FixedIncrements'), + ('numerical', 'pa.large_string', 'FixedIncrements'), + ('numerical', 'pa.string_view', 'FixedIncrements'), + ('numerical', 'pa.string', 'Inequality'), + ('numerical', 'pa.string', 'Range'), + ('numerical', 'pa.utf8', 'Inequality'), + ('numerical', 'pa.utf8', 'Range'), + ('numerical', 'pa.binary', 'Inequality'), + ('numerical', 'pa.binary', 'Range'), + ('numerical', 'pa.large_string', 'Inequality'), + ('numerical', 'pa.large_string', 'Range'), + ('numerical', 'pa.binary_view', 'Inequality'), + ('numerical', 'pa.binary_view', 'Range'), + ('numerical', 'pa.string_view', 'Inequality'), + ('numerical', 'pa.string_view', 'Range'), +] diff --git a/tests/benchmark/numpy_dtypes.py b/tests/benchmark/numpy_dtypes.py new file mode 100644 index 000000000..fb46008d8 --- /dev/null +++ b/tests/benchmark/numpy_dtypes.py @@ -0,0 +1,87 @@ +import numpy as np +import pandas as pd + +NUMPY_DTYPES = { + 'np.int8': pd.DataFrame({ + 'np.int8': pd.Series([np.int8(1), np.int8(-1), np.int8(127)], dtype='int8') + }), + 'np.int16': pd.DataFrame({ + 'np.int16': pd.Series([np.int16(2), np.int16(-2), np.int16(32767)], dtype='int16') + }), + 'np.int32': pd.DataFrame({ + 'np.int32': pd.Series([np.int32(3), np.int32(-3), np.int32(2147483647)], dtype='int32') + }), + 'np.int64': pd.DataFrame({ + 'np.int64': pd.Series([np.int64(4), np.int64(-4), np.int64(922)], dtype='int64') + }), + 'np.uint8': pd.DataFrame({ + 'np.uint8': pd.Series([np.uint8(5), np.uint8(10), np.uint8(255)], dtype='uint8') + }), + 'np.uint16': pd.DataFrame({ + 'np.uint16': pd.Series([np.uint16(6), np.uint16(20), np.uint16(65535)], dtype='uint16') + }), + 'np.uint32': pd.DataFrame({ + 'np.uint32': pd.Series([np.uint32(7), np.uint32(30), np.uint32(42)], dtype='uint32') + }), + 'np.uint64': pd.DataFrame({ + 'np.uint64': pd.Series([np.uint64(8), np.uint64(40), np.uint64(184467)], dtype='uint64') + }), + 'np.float16': pd.DataFrame({ + 'np.float16': pd.Series( + [np.float16(9.1), np.float16(-9.1), np.float16(65.0)], dtype='float16' + ) + }), + 'np.float32': pd.DataFrame({ + 'np.float32': pd.Series( + [np.float32(1.2), np.float32(-1.2), np.float32(3.40)], dtype='float32' + ) + }), + 'np.float64': pd.DataFrame({ + 'np.float64': pd.Series( + [np.float64(1.3), np.float64(-11.3), np.float64(1.7)], dtype='float64' + ) + }), + 'np.complex64': pd.DataFrame({ + 'np.complex64': pd.Series( + [np.complex64(12 + 1j), np.complex64(-12 - 1j), np.complex64(3.4e38 + 1j)], + dtype='complex64', + ) + }), + 'np.complex128': pd.DataFrame({ + 'np.complex128': pd.Series( + [np.complex128(13 + 2j), np.complex128(-13 - 2j), np.complex128(1.7e308 + 2j)], + dtype='complex128', + ) + }), + 'np.bool': pd.DataFrame({ + 'np.bool': pd.Series([np.bool_(True), np.bool_(False), np.bool_(True)], dtype='bool') + }), + 'np.object': pd.DataFrame({ + 'np.object': pd.Series(['object1', 'object2', 'object3'], dtype='object') + }), + 'np.string': pd.DataFrame({ + 'np.string': pd.Series([ + np.string_('string1'), + np.string_('string2'), + np.string_('string3'), + ]) + }), + 'np.unicode': pd.DataFrame({ + 'np.unicode': pd.Series( + [np.unicode_('unicode1'), np.unicode_('unicode2'), np.unicode_('unicode3')], + dtype='string', + ) + }), + 'np.datetime64': pd.DataFrame({ + 'np.datetime64': pd.Series([ + np.datetime64('2023-01-01T00:00:00'), + np.datetime64('2024-01-01T00:00:00'), + np.datetime64('2025-01-01T00:00:00'), + ]) + }), + 'np.timedelta64': pd.DataFrame({ + 'np.timedelta64': pd.Series( + [np.timedelta64(1, 'D'), np.timedelta64(2, 'h'), np.timedelta64(3, 'm')], + ) + }), +} diff --git a/tests/benchmark/pandas_dtypes.py b/tests/benchmark/pandas_dtypes.py new file mode 100644 index 000000000..3b7b62809 --- /dev/null +++ b/tests/benchmark/pandas_dtypes.py @@ -0,0 +1,41 @@ +import pandas as pd + +PANDAS_DTYPES = { + 'pd.Int8': pd.DataFrame({'pd.Int8': pd.Series([1, 2, -3, None, 4, 5], dtype='Int8')}), + 'pd.Int16': pd.DataFrame({'pd.Int16': pd.Series([1, 2, -3, None, 4, 5], dtype='Int16')}), + 'pd.Int32': pd.DataFrame({'pd.Int32': pd.Series([1, 2, -3, None, 4, 5], dtype='Int32')}), + 'pd.Int64': pd.DataFrame({'pd.Int64': pd.Series([1, 2, -3, None, 4, 5], dtype='Int64')}), + 'pd.UInt8': pd.DataFrame({'pd.UInt8': pd.Series([1, 2, 3, None, 4, 5], dtype='UInt8')}), + 'pd.UInt16': pd.DataFrame({'pd.UInt16': pd.Series([1, 2, 3, None, 4, 5], dtype='UInt16')}), + 'pd.UInt32': pd.DataFrame({'pd.UInt32': pd.Series([1, 2, 3, None, 4, 5], dtype='UInt32')}), + 'pd.UInt64': pd.DataFrame({'pd.UInt64': pd.Series([1, 2, 3, None, 4, 5], dtype='UInt64')}), + 'pd.Float32': pd.DataFrame({ + 'pd.Float32': pd.Series([1.1, 1.2, 1.3, 1.4, None], dtype='Float32') + }), + 'pd.Float64': pd.DataFrame({ + 'pd.Float64': pd.Series([1.1, 1.2, 1.3, 1.4, None], dtype='Float64') + }), + 'pd.boolean': pd.DataFrame({ + 'pd.boolean': pd.Series([True, False, None, True, False], dtype='boolean') + }), + 'pd.object': pd.DataFrame({'pd.object': pd.Series(['A', 'B', None, 'C'], dtype='object')}), + 'pd.string': pd.DataFrame({'pd.string': pd.Series(['A', 'B', None, 'C'], dtype='string')}), + 'pd.category': pd.DataFrame({ + 'pd.category': pd.Series(['A', 'B', None, 'D'], dtype='category') + }), + 'pd.datetime64': pd.DataFrame({ + 'pd.datetime64': pd.Series(pd.date_range('2023-01-01', periods=3), dtype='datetime64[ns]') + }), + 'pd.timedelta64': pd.DataFrame({ + 'pd.timedelta64': pd.Series( + [pd.Timedelta(days=1), pd.Timedelta(days=2), pd.Timedelta(days=3)], + dtype='timedelta64[ns]', + ) + }), + 'pd.Period': pd.DataFrame({ + 'pd.Period': pd.Series(pd.period_range('2023-01', periods=3, freq='M')), + }), + 'pd.Complex': pd.DataFrame({ + 'pd.Complex': pd.Series([1 + 1j, 2 + 2j, 3 + 3j], dtype='complex128'), + }), +} diff --git a/tests/benchmark/pyarrow_dtypes.py b/tests/benchmark/pyarrow_dtypes.py new file mode 100644 index 000000000..476be7732 --- /dev/null +++ b/tests/benchmark/pyarrow_dtypes.py @@ -0,0 +1,126 @@ +import decimal + +import pandas as pd +import pyarrow as pa + +PYARROW_DTYPES = { + 'pa.int8': pd.DataFrame({ + 'pa.int8': pd.Series([1, 2, -3, None, 4, 5], dtype=pd.ArrowDtype(pa.int8())) + }), + 'pa.int16': pd.DataFrame({ + 'pa.int16': pd.Series([1, 2, -3, None, 4, 5], dtype=pd.ArrowDtype(pa.int16())) + }), + 'pa.int32': pd.DataFrame({ + 'pa.int32': pd.Series([1, 2, -3, None, 4, 5], dtype=pd.ArrowDtype(pa.int32())) + }), + 'pa.int64': pd.DataFrame({ + 'pa.int64': pd.Series([1, 2, -3, None, 4, 5], dtype=pd.ArrowDtype(pa.int64())) + }), + 'pa.uint8': pd.DataFrame({ + 'pa.uint8': pd.Series([1, 2, 3, None, 4, 5], dtype=pd.ArrowDtype(pa.uint8())) + }), + 'pa.uint16': pd.DataFrame({ + 'pa.uint16': pd.Series([1, 2, 3, None, 4, 5], dtype=pd.ArrowDtype(pa.uint16())) + }), + 'pa.uint32': pd.DataFrame({ + 'pa.uint32': pd.Series([1, 2, 3, None, 4, 5], dtype=pd.ArrowDtype(pa.uint32())) + }), + 'pa.uint64': pd.DataFrame({ + 'pa.uint64': pd.Series([1, 2, 3, None, 4, 5], dtype=pd.ArrowDtype(pa.uint64())) + }), + 'pa.float32': pd.DataFrame({ + 'pa.float32': pd.Series([1.1, 1.2, 1.3, None, 1.4], dtype=pd.ArrowDtype(pa.float32())) + }), + 'pa.float64': pd.DataFrame({ + 'pa.float64': pd.Series([1.1, 1.2, 1.3, None, 1.4], dtype=pd.ArrowDtype(pa.float64())) + }), + 'pa.bool': pd.DataFrame({ + 'pa.bool': pd.Series([True, False, None, True, False], dtype=pd.ArrowDtype(pa.bool_())) + }), + 'pa.string': pd.DataFrame({ + 'pa.string': pd.Series(['A', 'B', None, 'C'], dtype=pd.ArrowDtype(pa.string())) + }), + 'pa.utf8': pd.DataFrame({ + 'pa.utf8': pd.Series(['A', 'B', None, 'C'], dtype=pd.ArrowDtype(pa.utf8())) + }), + 'pa.binary': pd.DataFrame({ + 'pa.binary': pd.Series( + [b'binary1', b'binary2', None, b'binary3'], dtype=pd.ArrowDtype(pa.binary()) + ) + }), + 'pa.large_binary': pd.DataFrame({ + 'pa.large_binary': pd.Series( + [b'large_binary1', b'large_binary2', None, b'large_binary3'], + dtype=pd.ArrowDtype(pa.large_binary()), + ) + }), + 'pa.large_string': pd.DataFrame({ + 'pa.large_string': pd.Series(['A', 'B', None, 'C'], dtype=pd.ArrowDtype(pa.large_string())) + }), + 'pa.date32': pd.DataFrame({ + 'pa.date32': pd.Series( + [pd.Timestamp('2023-01-01'), pd.Timestamp('2024-01-01'), None], + dtype=pd.ArrowDtype(pa.date32()), + ) + }), + 'pa.date64': pd.DataFrame({ + 'pa.date64': pd.Series( + [pd.Timestamp('2023-01-01'), pd.Timestamp('2024-01-01'), None], + dtype=pd.ArrowDtype(pa.date64()), + ) + }), + 'pa.timestamp': pd.DataFrame({ + 'pa.timestamp': pd.Series( + [pd.Timestamp('2023-01-01T00:00:00'), pd.Timestamp('2024-01-01T00:00:00'), None], + dtype=pd.ArrowDtype(pa.timestamp('ms')), + ) + }), + 'pa.duration': pd.DataFrame({ + 'pa.duration': pd.Series( + [pd.Timedelta(days=1), pd.Timedelta(hours=2), None], + dtype=pd.ArrowDtype(pa.duration('s')), + ) + }), + 'pa.time32': pd.DataFrame({ + 'pa.time32': pd.Series( + [ + pd.Timestamp('2023-01-01T01:00:00').time(), + pd.Timestamp('2023-01-01T02:00:00').time(), + None, + ], + dtype=pd.ArrowDtype(pa.time32('s')), + ) + }), + 'pa.time64': pd.DataFrame({ + 'pa.time64': pd.Series( + [ + pd.Timestamp('2023-01-01T01:00:00').time(), + pd.Timestamp('2023-01-01T02:00:00').time(), + None, + ], + dtype=pd.ArrowDtype(pa.time64('ns')), + ) + }), + 'pa.binary_view': pd.DataFrame({ + 'pa.binary_view': pd.Series( + [b'view1', b'view2', None, b'view3'], dtype=pd.ArrowDtype(pa.binary()) + ) + }), + 'pa.string_view': pd.DataFrame({ + 'pa.string_view': pd.Series(['A', 'B', None, 'C'], dtype=pd.ArrowDtype(pa.string())) + }), + 'pa.decimal128': pd.DataFrame({ + 'pa.decimal128': pd.Series( + [ + decimal.Decimal('123.45'), + decimal.Decimal('88.90'), + decimal.Decimal('78.90'), + decimal.Decimal('98.90'), + decimal.Decimal('678.90'), + decimal.Decimal('6.90'), + None, + ], + dtype=pd.ArrowDtype(pa.decimal128(precision=10, scale=2)), + ) + }), +} diff --git a/tests/benchmark/supported_dtypes_benchmark.py b/tests/benchmark/supported_dtypes_benchmark.py new file mode 100644 index 000000000..adce06b6f --- /dev/null +++ b/tests/benchmark/supported_dtypes_benchmark.py @@ -0,0 +1,579 @@ +"""Benchmark for supported data types.""" + +import contextlib +import logging +from copy import deepcopy +from functools import partialmethod + +import numpy as np +import pandas as pd +import pytest +from tqdm import tqdm + +from sdv.metadata import SingleTableMetadata +from sdv.single_table import GaussianCopulaSynthesizer +from tests.benchmark.excluded_tests import EXCLUDED_CONSTRAINT_TESTS +from tests.benchmark.numpy_dtypes import NUMPY_DTYPES +from tests.benchmark.pandas_dtypes import PANDAS_DTYPES +from tests.benchmark.pyarrow_dtypes import PYARROW_DTYPES +from tests.benchmark.utils import get_previous_dtype_result, save_results_to_json + +LOGGER = logging.getLogger(__name__) + +SINGLE_COLUMN_PREDEFINED_CONSTRAINTS = { + 'Positive': { + 'constraint_class': 'Positive', + 'constraint_parameters': {'column_name': '', 'strict_boundaries': False}, + }, + 'Negative': { + 'constraint_class': 'Negative', + 'constraint_parameters': {'column_name': '', 'strict_boundaries': False}, + }, + 'ScalarInequality': { + 'constraint_class': 'ScalarInequality', + 'constraint_parameters': {'column_name': '', 'relation': '>=', 'value': 0}, + }, + 'ScalarRange': { + 'constraint_class': 'ScalarRange', + 'constraint_parameters': { + 'column_name': '', + 'low_value': 0, + 'high_value': 1, + 'strict_boundaries': False, + }, + }, + 'FixedIncrements': { + 'constraint_class': 'FixedIncrements', + 'constraint_parameters': { + 'column_name': '', + 'increment_value': 1, + }, + }, +} + +MULTI_COLUMN_PREDEFINED_CONSTRAINTS = { + 'FixedCombinations': { + 'constraint_class': 'FixedCombinations', + 'constraint_parameters': { + 'column_names': [], + }, + }, + 'Inequality': { + 'constraint_class': 'Inequality', + 'constraint_parameters': { + 'low_column_name': '', + 'high_column_name': '', + 'strict_boundaries': False, + }, + }, + 'Range': { + 'constraint_class': 'Range', + 'constraint_parameters': { + 'low_column_name': '', + 'middle_column_name': '', + 'high_column_name': '', + 'strict_boundaries': False, + }, + }, +} + + +METADATA_SDTYPES = ('numerical', 'id', 'datetime', 'categorical') + + +EXPECTED_METADATA_SDTYPES = { + # Pandas + 'pd.Int8': 'numerical', + 'pd.Int16': 'numerical', + 'pd.Int32': 'numerical', + 'pd.Int64': 'numerical', + 'pd.UInt8': 'numerical', + 'pd.UInt16': 'numerical', + 'pd.UInt32': 'numerical', + 'pd.UInt64': 'numerical', + 'pd.Float32': 'numerical', + 'pd.Float64': 'numerical', + 'pd.datetime64': 'datetime', + 'pd.boolean': 'categorical', + 'pd.object': 'categorical', + 'pd.category': 'categorical', + 'pd.string': 'categorical', + 'pd.timedelta64': 'datetime', + 'pd.Period': 'datetime', + 'pd.Complex': 'numerical', + # NumPy + 'np.int8': 'numerical', + 'np.int16': 'numerical', + 'np.int32': 'numerical', + 'np.int64': 'numerical', + 'np.uint8': 'numerical', + 'np.uint16': 'numerical', + 'np.uint32': 'numerical', + 'np.uint64': 'numerical', + 'np.float16': 'numerical', + 'np.float32': 'numerical', + 'np.float64': 'numerical', + 'np.complex64': 'numerical', + 'np.complex128': 'numerical', + 'np.datetime64': 'datetime', + 'np.timedelta64': 'datetime', + 'np.object': 'categorical', + 'np.bool': 'categorical', + 'np.string': 'categorical', + 'np.unicode': 'categorical', + # PyArrow + 'pa.int8': 'numerical', + 'pa.int16': 'numerical', + 'pa.int32': 'numerical', + 'pa.int64': 'numerical', + 'pa.uint8': 'numerical', + 'pa.uint16': 'numerical', + 'pa.uint32': 'numerical', + 'pa.uint64': 'numerical', + 'pa.float32': 'numerical', + 'pa.float64': 'numerical', + 'pa.bool': 'categorical', + 'pa.string': 'categorical', + 'pa.utf8': 'categorical', + 'pa.binary': 'categorical', + 'pa.large_binary': 'categorical', + 'pa.large_string': 'categorical', + 'pa.binary_view': 'categorical', + 'pa.string_view': 'categorical', + 'pa.date32': 'datetime', + 'pa.date64': 'datetime', + 'pa.timestamp': 'datetime', + 'pa.duration': 'datetime', + 'pa.time32': 'datetime', + 'pa.time64': 'datetime', + 'pa.decimal128': 'numerical', +} + + +@contextlib.contextmanager +def prevent_tqdm_output(): + """Temporarily disables tqdm for the conditional sampling.""" + tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) + try: + yield + finally: + tqdm.__init__ = partialmethod(tqdm.__init__, disable=False) + + +def _get_metadata_for_dtype_and_sdtype(dtype, sdtype): + """Return the expected metadata.""" + metadata = SingleTableMetadata.load_from_dict({'columns': {dtype: {'sdtype': sdtype}}}) + return metadata + + +@pytest.mark.parametrize('dtype, data', {**PANDAS_DTYPES, **NUMPY_DTYPES, **PYARROW_DTYPES}.items()) +@pytest.mark.parametrize('sdtype', METADATA_SDTYPES) +def test_metadata_detection(dtype, data, sdtype): + """Test metadata detection for data types using `SingleTableMetadata`. + + This test checks the ability of the `SingleTableMetadata` class to detect + metadata from data types coming from `Pandas` and `NumPy`. It compares the + detected metadata against expected results. + + Args: + dtype (str): + The data type to test. + data (pd.DataFrame): + The data for which metadata detection is performed. + + Raises: + AssertionError: + If the detected metadata is incorrect or the dtype is no longer supported. + + Test flow: + 1. Initialize `SingleTableMetadata`. + 2. Attempt to detect metadata from the provided data. + 3. Assert if the sdtype matches the expected one. + """ + metadata = SingleTableMetadata() + previous_result, _ = get_previous_dtype_result(dtype, sdtype, 'METADATA_DETECTION') + result = False + try: + metadata.detect_from_dataframe(data) + column = metadata.columns.get(dtype) + detected_sdtype = column.get('sdtype') + result = detected_sdtype == EXPECTED_METADATA_SDTYPES.get(dtype) + except BaseException as e: + LOGGER.debug(f"Error during 'metadata.validate_data' with dtype '{dtype}': {e}") + + assertion_message = f"{dtype} is no longer supported in 'METADATA_DETECTION'." + save_results_to_json({'dtype': dtype, 'sdtype': sdtype, 'METADATA_DETECTION': result}) + if result is False: + assert result == previous_result, assertion_message + + +@pytest.mark.parametrize('dtype, data', {**PANDAS_DTYPES, **NUMPY_DTYPES, **PYARROW_DTYPES}.items()) +@pytest.mark.parametrize('sdtype', METADATA_SDTYPES) +def test_metadata_validate_data(dtype, data, sdtype): + """Test the validation of data using `SingleTableMetadata`. + + This test checks whether the `validate_data` method of the metadata object + properly validates the given data for different data types coming from + `Pandas` and `NumPy`. + + Args: + dtype (str): + The data type to test. + data (pd.DataFrame): + The data for which metadata validation is performed. + + Raises: + AssertionError: + If the validation result does not match the previously recorded result + or if the dtype is no longer supported. + + Test flow: + 1. Create a predefined `SingleTableMetadata` for the given dtype. + 2. Attempt to validate the data using `metadata.validate_data` for the provided data. + 3. Assert if the result is as expected. + """ + metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype) + previous_result, _ = get_previous_dtype_result(dtype, sdtype, 'METADATA_VALIDATE_DATA') + result = False + try: + metadata.validate_data(data) + result = True + except BaseException as e: + LOGGER.debug(f"Error during 'metadata.validate_data' with dtype '{dtype}': {e}") + + save_results_to_json({'dtype': dtype, 'sdtype': sdtype, 'METADATA_VALIDATE_DATA': result}) + if result is False: + assertion_message = f"{dtype} is no longer supported by 'METADATA_VALIDATE_DATA'." + assert result == previous_result, assertion_message + + +@pytest.mark.parametrize('dtype, data', {**PANDAS_DTYPES, **NUMPY_DTYPES, **PYARROW_DTYPES}.items()) +@pytest.mark.parametrize('sdtype', METADATA_SDTYPES) +def test_fit_and_sample_synthesizer(dtype, data, sdtype): + """Test fitting and sampling a synthesizer for different data types. + + This test evaluates the `GaussianCopulaSynthesizer` to fit and + sample data for various data types from `Pandas` and `NumPy`. + It verifies that the synthesizer can successfully be fitted to the + data and generate synthetic data with matching data types. + The results are compared with previously recorded outcomes for both + fitting and sampling. + + Args: + dtype (str): + The data type to test. + data (pd.DataFrame): + The data for which the fitting and sampling is performed. + + Raises: + AssertionError: + If the fit or sample results do not match previously recorded results + or if the dtype is no longer supported. + + The test flow includes: + 1. Initializing the `GaussianCopulaSynthesizer` with the appropriate metadata. + 2. Compare the current fit result against previously recorded results. + 3. Using the synthesizer to sample data, compare the synthetic data types if they match the + input. + """ + metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype) + synthesizer = GaussianCopulaSynthesizer(metadata) + previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, 'SYNTHESIZER_FIT') + previous_sample_result, _ = get_previous_dtype_result(dtype, sdtype, 'SYNTHESIZER_SAMPLE') + fit_result = False + sample_result = False + + try: + synthesizer.fit(data) + fit_result = True + with prevent_tqdm_output(): + synthetic_data = synthesizer.sample(10) + + sample_result = synthetic_data.dtypes[dtype] == data.dtypes[dtype] + + except BaseException as e: + LOGGER.debug(f"Error during fitting/sampling with dtype '{dtype}': {e}") + + save_results_to_json({ + 'dtype': dtype, + 'sdtype': sdtype, + 'SYNTHESIZER_FIT': fit_result, + 'SYNTHESIZER_SAMPLE': sample_result, + }) + fit_assertion_message = f"{dtype} is no longer supported by 'SYNTHESIZER_FIT'." + if fit_result is False: + assert fit_result == previous_fit_result, fit_assertion_message + + sample_assertion_message = f"{dtype} is no longer supported by 'SYNTHESIZER_SAMPLE'." + if sample_result is False: + assert sample_result == previous_sample_result, sample_assertion_message + + +def convert_values(value, inequality): + """Convert the given value based on the specified inequality. + + This function checks the provided value and applies a conversion based on + the inequality function. If the value satisfies the inequality condition + when compared to 0, it multiplies the value by -1. If the value is `None`, + it returns `None`. + + Args: + value (numeric): + The value to be checked and potentially converted. It can be any numeric type or `None`. + inequality (function): + A comparison function (e.g., `operator.gt` or `operator.lt`) used to compare the + value with 0. + + Returns: + numeric or None: + The converted value if the inequality holds, or the original value otherwise. + Returns `None` if the value is `None`. + """ + if pd.isna(value): + return None + + if inequality(value, 0): + return value * -1 + + return value + + +def _create_single_column_constraint_and_data(constraint, data, dtype, sdtype): + constraint_class = constraint.get('constraint_class') + _dtype = data.dtypes[dtype] + constraint['constraint_parameters']['column_name'] = dtype + + if constraint_class == 'Positive' and sdtype == 'numerical': + data[dtype] = data[dtype].apply(convert_values, inequality=np.less) + elif constraint_class == 'Negative' and sdtype == 'numerical': + data[dtype] = data[dtype].apply(convert_values, inequality=np.greater) + elif constraint_class == 'ScalarInequality': + lower = 0 + if sdtype == 'numerical': + data[dtype] = data[dtype].apply(convert_values, inequality=np.less) + + elif sdtype == 'datetime': + # Make the lowest date to be 1971-01-01 + lower = '1971-01-01' + + constraint['constraint_parameters']['value'] = lower + + elif constraint_class == 'ScalarRange': + if sdtype in ('numerical', 'datetime'): + low_value = data[dtype].min() + high_value = data[dtype].max() + constraint['constraint_parameters']['low_value'] = low_value + constraint['constraint_parameters']['high_value'] = high_value + + elif constraint_class == 'FixedIncrements': + if sdtype == 'numerical': + values = [10, 20, 30, 40] + if dtype.startswith('pd'): + values.append(None) + + data[dtype] = pd.Series(values, dtype=_dtype) + constraint['constraint_parameters']['increment_value'] = 10 + + return constraint, data + + +def _create_multi_column_constraint_data_and_metadata(constraint, data, dtype, sdtype, metadata): + _dtype = data.dtypes[dtype] + constraint_class = constraint.get('constraint_class') + constraints = [] + if constraint_class == 'FixedCombinations': + for dtype_name, dtype_data in {**PANDAS_DTYPES, **NUMPY_DTYPES}.items(): + dtype_sdtype = EXPECTED_METADATA_SDTYPES.get(dtype_name, 'unknown') + if dtype_sdtype in ('categorical', 'boolean'): + data[f'{dtype}_{dtype_name}'] = data[dtype] + metadata.columns[f'{dtype}_{dtype_name}'] = {'sdtype': sdtype} + new_constraint = deepcopy(constraint) + data[dtype_name] = dtype_data[dtype_name] + dtype_sdtype = EXPECTED_METADATA_SDTYPES.get(dtype_name, 'unknown') + metadata.columns[dtype_name] = {'sdtype': dtype_sdtype} + new_constraint['constraint_parameters']['column_names'].append(dtype_name) + new_constraint['constraint_parameters']['column_names'].append( + f'{dtype}_{dtype_name}' + ) + constraints.append(new_constraint) + + elif constraint_class == 'Inequality': + if sdtype == 'numerical': + data['high'] = data[dtype] * 10 + metadata.columns['high'] = {'sdtype': 'numerical'} + + elif constraint_class == 'Range': + if sdtype == 'numerical': + data['mid'] = data[dtype] * 5 + data['high'] = data[dtype] * 10 + metadata.columns['mid'] = {'sdtype': 'numerical'} + metadata.columns['high'] = {'sdtype': 'numerical'} + + return constraints, data, metadata + + +@pytest.mark.parametrize( + 'constraint_name, constraint', SINGLE_COLUMN_PREDEFINED_CONSTRAINTS.items() +) +@pytest.mark.parametrize('dtype, data', {**PANDAS_DTYPES, **NUMPY_DTYPES, **PYARROW_DTYPES}.items()) +@pytest.mark.parametrize('sdtype', METADATA_SDTYPES) +def test_fit_and_sample_single_column_constraints(constraint_name, constraint, dtype, data, sdtype): + """Test fitting and sampling with single-column constraints for various data types. + + This test evaluates the `GaussianCopulaSynthesizer` to fit data and + generate synthetic data while applying single-column constraints to different + data types. It verifies that the synthesizer can respect the constraint and + successfully produce synthetic data with the same data types as the original. + + Args: + constraint_name (str): + The name of the constraint being tested. + constraint (dict): + The predefined constraint to apply to the data. + dtype (str): + The data type being tested. + data (pd.DataFrame): + The input data to fit and generate synthetic samples from. + + Raises: + AssertionError: + If the fit or sample results do not match previously recorded results or if the dtype + is no longer supported. + + The test flow includes: + 1. Initializing the `GaussianCopulaSynthesizer` with the metadata. + 2. Preparing the constraint and data for the test. + 3. Adding the constraint to the synthesizer, fitting the data, and verifying the fit result. + 4. Sampling synthetic data and checking that the dtype matches the original. + """ + if (sdtype, dtype, constraint_name) not in EXCLUDED_CONSTRAINT_TESTS: + metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype) + synthesizer = GaussianCopulaSynthesizer(metadata) + sdtype = metadata.columns[dtype].get('sdtype') + previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'{constraint_name}_FIT') + previous_sample_result, _ = get_previous_dtype_result( + dtype, sdtype, f'{constraint_name}_SAMPLE' + ) + + # Prepare the constraint and data + constraint, data = _create_single_column_constraint_and_data( + deepcopy(constraint), data.copy(), dtype, sdtype + ) + + # Initialize results + sample_result = False + fit_result = False + try: + synthesizer.add_constraints([constraint]) + synthesizer.fit(data) + fit_result = True + + # Sample Synthetic Data + with prevent_tqdm_output(): + synthetic_data = synthesizer.sample(10) + + sample_result = synthetic_data.dtypes[dtype] == data.dtypes[dtype] + + except BaseException as e: + LOGGER.debug( + f"Error during fitting/sampling with dtype '{dtype}' and constraint " + f"'{constraint_name}': {e}" + ) + + save_results_to_json({ + 'dtype': dtype, + 'sdtype': sdtype, + f'{constraint_name}_FIT': fit_result, + f'{constraint_name}_SAMPLE': sample_result, + }) + if fit_result is False: + fit_assertion_message = f"{dtype} is no longer supported by '{constraint_name}_FIT''." + assert fit_result == previous_fit_result, fit_assertion_message + + if sample_result is False: + sample_assertion_message = ( + f"{dtype} is no longer supported by '{constraint_name}_FIT''." + ) + assert sample_result == previous_sample_result, sample_assertion_message + + +@pytest.mark.parametrize('constraint_name, constraint', MULTI_COLUMN_PREDEFINED_CONSTRAINTS.items()) +@pytest.mark.parametrize('dtype, data', {**PANDAS_DTYPES, **NUMPY_DTYPES, **PYARROW_DTYPES}.items()) +@pytest.mark.parametrize('sdtype', METADATA_SDTYPES) +def test_fit_and_sample_multi_column_constraints(constraint_name, constraint, dtype, data, sdtype): + """Test fitting and sampling with multi-column constraints for various data types. + + This test evaluates the `GaussianCopulaSynthesizer` to fit data and + generate synthetic data while applying multi-column constraints. It ensures + that the synthesizer can handle constraints across multiple columns and produce + synthetic data with the expected data types. + + Args: + constraint_name (str): + The name of the multi-column constraint being tested. + constraint (dict): + The predefined multi-column constraint to apply to the data. + dtype (str): + The data type being tested. + data (pd.DataFrame): + The input data to fit and generate synthetic samples from. + + Raises: + AssertionError: + If the fit or sample results do not match previously recorded results or if + the dtype is no longer supported. + + The test flow includes: + 1. Preparing the constraints, data, and metadata for the test. + 2. Initializing the `GaussianCopulaSynthesizer` with the metadata. + 3. Adding the multi-column constraints to the synthesizer and fitting the data. + 4. Sampling synthetic data and ensuring the synthetic data types match the original. + """ + if (sdtype, dtype, constraint_name) not in EXCLUDED_CONSTRAINT_TESTS: + metadata = _get_metadata_for_dtype_and_sdtype(dtype, sdtype) + sdtype = metadata.columns[dtype].get('sdtype') + previous_fit_result, _ = get_previous_dtype_result(dtype, sdtype, f'{constraint_name}_FIT') + previous_sample_result, _ = get_previous_dtype_result( + dtype, sdtype, f'{constraint_name}_SAMPLE' + ) + + # Prepare constraints, data required and metadata + constraints, data, metadata = _create_multi_column_constraint_data_and_metadata( + deepcopy(constraint), data.copy(), dtype, sdtype, metadata + ) + + # Initialize results + sample_result = False + fit_result = False + + try: + synthesizer = GaussianCopulaSynthesizer(metadata) + synthesizer.add_constraints(constraints) + synthesizer.fit(data) + fit_result = True + + # Generate Synthetic Data + with prevent_tqdm_output(): + synthetic_data = synthesizer.sample(10) + + sample_result = synthetic_data.dtypes[dtype] == data.dtypes[dtype] + + except BaseException as e: + LOGGER.debug( + f"Error during fitting/sampling with dtype '{dtype}' and constraint " + f"'{constraint_name}': {e}" + ) + + save_results_to_json({ + 'dtype': dtype, + 'sdtype': sdtype, + f'{constraint_name}_FIT': fit_result, + f'{constraint_name}_SAMPLE': sample_result, + }) + if fit_result is False: + fit_message = f"{dtype} failed during '{constraint_name}_FIT'." + assert fit_result == previous_fit_result, fit_message + + if sample_result is False: + sample_msg = f"{dtype} failed during '{constraint_name}_SAMPLE'." + assert sample_result == previous_sample_result, sample_msg diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py new file mode 100644 index 000000000..082ff58d1 --- /dev/null +++ b/tests/benchmark/utils.py @@ -0,0 +1,191 @@ +"""Utility functions for the benchmarking.""" + +import json +import os +import sys +from datetime import date +from functools import lru_cache +from pathlib import Path + +import git +import pandas as pd + +from sdv.io.local import CSVHandler +from tests._external.gdrive_utils import get_latest_file, read_excel, save_to_gdrive +from tests._external.slack_utils import post_slack_message + +GDRIVE_OUTPUT_FOLDER = '16SkTOyQ3xkJDPJbyZCusb168JwreW5bm' +PYTHON_VERSION = f'{sys.version_info.major}.{sys.version_info.minor}' +TEMPRESULTS = Path(f'results/{sys.version_info.major}.{sys.version_info.minor}.json') + + +def get_previous_dtype_result(dtype, sdtype, method, python_version=PYTHON_VERSION): + """Return previous result for a given ``dtype`` and method.""" + data = get_previous_results() + df = data[python_version] + try: + filtered_row = df[(df['dtype'] == dtype) & (df['sdtype'] == sdtype)] + value = filtered_row[method].to_numpy()[0] + previously_seen = True + except (IndexError, KeyError): + value = False + previously_seen = False + + return value, previously_seen + + +@lru_cache() +def get_previous_results(): + """Get the last run for the dtype benchmarking.""" + latest_file = get_latest_file(GDRIVE_OUTPUT_FOLDER) + df = read_excel(latest_file['id']) + return df + + +def _load_temp_results(filename): + df = pd.read_json(filename) + df.iloc[:, 2:] = df.groupby(['dtype', 'sdtype']).transform(lambda x: x.ffill().bfill()) + for column in df.columns: + if column not in ('sdtype', 'dtype'): + df[column] = df[column].astype(bool) + + return df.drop_duplicates().reset_index(drop=True) + + +def _get_output_filename(): + repo = git.Repo(search_parent_directories=True) + commit_id = repo.head.object.hexsha + today = str(date.today()) + output_filename = f'{today}-{commit_id}' + return output_filename + + +def compare_previous_result_with_current(): + """Compare the previous result with the current and post a message on slack.""" + for result in Path('results/').rglob('*.json'): + python_version = result.stem + current_results = _load_temp_results(result) + csv_output = Path(f'results/{python_version}.csv') + current_results.to_csv(csv_output, index=False) + + new_supported_dtypes = [] + unsupported_dtypes = [] + previously_unseen_dtypes = [] + + for index, row in current_results.iterrows(): + dtype = row['dtype'] + sdtype = row['sdtype'] + for col in current_results.columns[1:]: + current_value = row[col] + stored_value, previously_seen = get_previous_dtype_result( + dtype, + sdtype, + col, + python_version, + ) + + if current_value and not stored_value: + new_supported_dtypes.append({ + 'dtype': dtype, + 'sdtype': sdtype, + 'method': col, + 'python_version': python_version, + }) + + elif not current_value and stored_value: + unsupported_dtypes.append({ + 'dtype': dtype, + 'sdtype': sdtype, + 'method': col, + 'python_version': python_version, + }) + + if not previously_seen: + previously_unseen_dtypes.append({ + 'dtype': dtype, + 'sdtype': sdtype, + 'method': col, + 'python_version': python_version, + }) + + return { + 'unsupported_dtypes': pd.DataFrame(unsupported_dtypes), + 'new_supported_dtypes': pd.DataFrame(new_supported_dtypes), + 'previously_unseen_dtypes': pd.DataFrame(previously_unseen_dtypes), + } + + +def save_results_to_json(results, filename=None): + """Save results to a JSON file, categorizing by `dtype`. + + This function saves the `results` dictionary to a specified JSON file. + The dictionary must contain a `dtype` key, which is used as a category + to group the results in the file. If the file already exists, it loads + the existing data, updates the `dtype` category with new values from + `results`, and saves the updated content back to the file. If the file + does not exist it doesn't write. + + Args: + results (dict): + A dictionary containing the data to save. Must include the + key `dtype` that specifies the category under which the data + will be stored in the JSON file. + filename (str, optional): + The name of the JSON file where the results will be saved. + Defaults to `None`. + """ + filename = filename or TEMPRESULTS + + if os.path.exists(filename): + with open(filename, 'r') as file: + try: + json_data = json.load(file) + except json.JSONDecodeError: + json_data = [] + + json_data.append(results) + with open(filename, 'w') as file: + json.dump(json_data, file, indent=4) + + +def calculate_support_percentage(df): + """Calculate the percentage of supported features (True) for each dtype in a DataFrame.""" + feature_columns = df.drop(columns=['dtype']) + # Calculate percentage of TRUE values for each row (dtype) + percentage_support = feature_columns.mean(axis=1) * 100 + return pd.DataFrame({'dtype': df['dtype'], 'percentage_supported': percentage_support}) + + +def compare_and_store_results_in_gdrive(): + csv_handler = CSVHandler() + comparison_results = compare_previous_result_with_current() + + results = csv_handler.read('results/') + sorted_results = {} + + slack_messages = [] + for key, value in comparison_results.items(): + if not value.empty: + sorted_results[key] = value + if key == 'unsupported_dtypes': + slack_messages.append(':fire: New unsupported DTypes!') + elif key == 'new_supported_dtypes': + slack_messages.append(':party_blob: New DTypes supported!') + + if len(slack_messages) == 0: + slack_messages.append(':dealwithit: No new changes to the DTypes in SDV.') + + for key, value in results.items(): + sorted_results[key] = value + + file_id = save_to_gdrive(GDRIVE_OUTPUT_FOLDER, sorted_results) + + slack_messages.append( + f'See ' + ) + slack_message = '\n'.join(slack_messages) + post_slack_message('sdv-alerts', slack_message) + + +if __name__ == '__main__': + compare_and_store_results_in_gdrive() diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index 8d3771164..3367bd0fa 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -7,7 +7,7 @@ def test_evaluation(): """Test ``evaluate_quality`` and ``run_diagnostic``.""" # Setup - table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]}) + table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4.0]}) slightly_different_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]}) data = { 'table1': table, diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index a43816100..43c9941a6 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1418,6 +1418,102 @@ def test_validate_data_datetime_warning(self): with pytest.warns(UserWarning, match=warning_msg): metadata.validate_data(data) + def test_add_relationship_circular_graph(self): + """Test that an error is raised when a circular relationship is detected. + + The graph has the cycle B->C->D->B. + Besides the cycle, the other relationships are: B->A, C->A, D->A. + """ + # Setup + metadata = MultiTableMetadata() + metadata.add_table('A') + metadata.add_column('A', 'id', sdtype='id') + metadata.add_column('A', 'fk', sdtype='id') + metadata.set_primary_key('A', 'id') + + metadata.add_table('B') + metadata.add_column('B', 'id', sdtype='id') + metadata.add_column('B', 'fk', sdtype='id') + metadata.set_primary_key('B', 'id') + + metadata.add_table('C') + metadata.add_column('C', 'id', sdtype='id') + metadata.add_column('C', 'fk', sdtype='id') + metadata.set_primary_key('C', 'id') + + metadata.add_table('D') + metadata.add_column('D', 'id', sdtype='id') + metadata.add_column('D', 'fk', sdtype='id') + metadata.set_primary_key('D', 'id') + + metadata.add_relationship('B', 'C', 'id', 'fk') + metadata.add_relationship('B', 'A', 'id', 'fk') + + metadata.add_relationship('C', 'D', 'id', 'fk') + metadata.add_relationship('C', 'A', 'id', 'fk') + + metadata.add_relationship('D', 'A', 'id', 'fk') + + # Run and Assert + error_msg = re.escape( + 'The relationships in the dataset describe a ' + "circular dependency between tables ['B', 'C', 'D']." + ) + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.add_relationship('D', 'B', 'id', 'fk') + + def test_add_relationship_circular_graph_complex(self): + """Test that an error is raised when a circular relationship is detected. + + The graph has the cycle C->E->D->C. + Besides the cycle, the other relationships are: C->B, D->B, E->B, E->A, A->B. + """ + # Setup + metadata = MultiTableMetadata() + metadata.add_table('A') + metadata.add_column('A', 'id', sdtype='id') + metadata.add_column('A', 'fk', sdtype='id') + metadata.set_primary_key('A', 'id') + + metadata.add_table('B') + metadata.add_column('B', 'id', sdtype='id') + metadata.add_column('B', 'fk', sdtype='id') + metadata.set_primary_key('B', 'id') + + metadata.add_table('C') + metadata.add_column('C', 'id', sdtype='id') + metadata.add_column('C', 'fk', sdtype='id') + metadata.set_primary_key('C', 'id') + + metadata.add_table('D') + metadata.add_column('D', 'id', sdtype='id') + metadata.add_column('D', 'fk', sdtype='id') + metadata.set_primary_key('D', 'id') + + metadata.add_table('E') + metadata.add_column('E', 'id', sdtype='id') + metadata.add_column('E', 'fk', sdtype='id') + metadata.set_primary_key('E', 'id') + + metadata.add_relationship('C', 'B', 'id', 'fk') + metadata.add_relationship('C', 'E', 'id', 'fk') + + metadata.add_relationship('D', 'B', 'id', 'fk') + metadata.add_relationship('D', 'C', 'id', 'fk') + + metadata.add_relationship('A', 'B', 'id', 'fk') + + metadata.add_relationship('E', 'A', 'id', 'fk') + metadata.add_relationship('E', 'B', 'id', 'fk') + + # Run and Assert + error_msg = re.escape( + 'The relationships in the dataset describe a ' + "circular dependency between tables ['C', 'D', 'E']." + ) + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.add_relationship('E', 'D', 'id', 'fk') + @patch('sdv.metadata.multi_table.SingleTableMetadata') def test_add_table(self, table_metadata_mock): """Test that the method adds the table name to ``instance.tables``.""" diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index f732fe827..48a32cd30 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -40,7 +40,9 @@ def test__initialize_models(self): locales = ['en_CA', 'fr_CA'] instance = Mock() instance._table_synthesizers = {} - instance._table_parameters = {'nesreca': {'default_distribution': 'gamma'}} + instance._table_parameters = { + 'nesreca': {'default_distribution': 'gamma', 'locales': ['en_US']}, + } instance.locales = locales instance.metadata = get_multi_table_metadata() @@ -57,7 +59,7 @@ def test__initialize_models(self): call( metadata=instance.metadata.tables['nesreca'], default_distribution='gamma', - locales=locales, + locales=['en_US'], ), call(metadata=instance.metadata.tables['oseba'], locales=locales), call(metadata=instance.metadata.tables['upravna_enota'], locales=locales), @@ -763,6 +765,7 @@ def test_preprocess(self): 'id_upravna_enota': np.arange(10), }), } + instance._transform_helper = Mock(return_value=data) synth_nesreca = Mock() synth_oseba = Mock() @@ -782,6 +785,7 @@ def test_preprocess(self): 'oseba': synth_oseba._preprocess.return_value, 'upravna_enota': synth_upravna_enota._preprocess.return_value, } + instance._transform_helper.assert_called_once_with(data) instance.validate.assert_called_once_with(data) assert instance.metadata._get_all_foreign_keys.call_args_list == [ call('nesreca'), @@ -1212,6 +1216,7 @@ def test_sample(self, mock_datetime, caplog): 'table2': pd.DataFrame({'id': [1, 2, 3], 'name': ['John', 'Johanna', 'Doe']}), } instance._sample = Mock(return_value=data) + instance._reverse_transform_helper = Mock(return_value=data) synth_id = 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5' instance._synthesizer_id = synth_id @@ -1222,6 +1227,7 @@ def test_sample(self, mock_datetime, caplog): # Assert instance._sample.assert_called_once_with(scale=1.5) + instance._reverse_transform_helper.assert_called_once_with(data) assert caplog.messages[0] == str({ 'EVENT': 'Sample', 'TIMESTAMP': '2024-04-19 16:20:10.037183', diff --git a/tests/unit/test__utils.py b/tests/unit/test__utils.py index 1cbcf3416..2f9b280d0 100644 --- a/tests/unit/test__utils.py +++ b/tests/unit/test__utils.py @@ -17,6 +17,7 @@ _get_datetime_format, _get_root_tables, _is_datetime_type, + _is_numerical, _validate_foreign_keys_not_null, check_sdv_versions_and_warn, check_synthesizer_version, @@ -240,6 +241,30 @@ def test__is_datetime_type_with_datetime_str(): assert is_datetime +def test__is_datetime_type_with_datetime_str_nanoseconds(): + """Test it for a datetime string with nanoseconds.""" + # Setup + value = '2011-10-15 20:11:03.498707' + + # Run + is_datetime = _is_datetime_type(value) + + # Assert + assert is_datetime + + +def test__is_datetime_type_with_str_int(): + """Test it for a string with an integer.""" + # Setup + value = '123' + + # Run + is_datetime = _is_datetime_type(value) + + # Assert + assert is_datetime is False + + def test__is_datetime_type_with_invalid_str(): """Test the ``_is_datetime_type`` function when an invalid string is passed. @@ -713,3 +738,33 @@ def test_get_possible_chars(): nums = [str(i) for i in range(10)] lowercase_letters = list(string.ascii_lowercase) assert possible_chars == prefix + nums + ['_'] + lowercase_letters + + +def test__is_numerical(): + """Test that ensures that if passed any numerical data type we will get a ``True``.""" + # Setup + np_int = np.int16(10) + np_nan = np.nan + + # Run + np_int_result = _is_numerical(np_int) + np_nan_result = _is_numerical(np_nan) + + # Assert + assert np_int_result + assert np_nan_result + + +def test__is_numerical_string(): + """Test that ensures that if passed any other value but numerical it will return `False`.""" + # Setup + str_value = 'None' + datetime_value = pd.to_datetime('2012-01-01') + + # Run + str_result = _is_numerical(str_value) + datetime_result = _is_numerical(datetime_value) + + # Assert + assert str_result is False + assert datetime_result is False