Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Apr 8, 2024
1 parent a4e0d9a commit bc431b9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 42 deletions.
37 changes: 22 additions & 15 deletions sdv/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Miscellaneous utility functions."""
import operator
import warnings
from collections import defaultdict
from collections.abc import Iterable
Expand Down Expand Up @@ -280,9 +281,11 @@ def check_sdv_versions_and_warn(synthesizer):
warnings.warn(message, SDVVersionWarning)


def _check_is_lower_version(current_version, synthesizer_version,
check_synthesizer_is_greater=False):
"""Check if the current version is lower than the synthesizer version.
def _compare_versions(current_version, synthesizer_version, compare_operator=operator.gt):
"""Compare two versions.
Given a ``compare_operator`` compare two versions using that operator to determine if one is
greater than the other or vice-versa.
Args:
current_version (str):
Expand All @@ -291,8 +294,8 @@ def _check_is_lower_version(current_version, synthesizer_version,
synthesizer_version (str):
The synthesizer version to compare, formatted as a string with major, minor, and
revision parts separated by periods (e.g., "1.0.0")
check_synthesizer_is_greater (bool):
If ``True`` invert the check.
compare_operator (operator):
Operator function to evaluate with. Defaults to ``operator.gt``.
Returns:
bool:
Expand All @@ -304,24 +307,28 @@ def _check_is_lower_version(current_version, synthesizer_version,
try:
current_v = int(current_v)
synth_v = int(synth_v)
if current_v > synth_v:
return False if not check_synthesizer_is_greater else True
if compare_operator(current_v, synth_v):
return False

if compare_operator(synth_v, current_v):
return True

if synth_v > current_v:
return True if not check_synthesizer_is_greater else False
except Exception:
pass

return False


def check_synthesizer_version(synthesizer, is_fit_method=False,
check_synthesizer_is_greater=False):
def check_synthesizer_version(synthesizer, is_fit_method=False, compare_operator=operator.gt):
"""Check if the current synthesizer version is greater than the package version.
Args:
synthesizer (BaseSynthesizer or BaseMultiTableSynthesizer):
An SDV model instance to check versions against.
is_fit_method (bool):
Whether or not this function is being called by a ``fit`` function.
compare_operator (operator):
Operator function to evaluate with. Defaults to ``operator.gt``.
Raises:
VersionError:
Expand All @@ -341,18 +348,18 @@ def check_synthesizer_version(synthesizer, is_fit_method=False,

is_public_lower = False
if None not in (current_public_version, fit_public_version):
is_public_lower = _check_is_lower_version(
is_public_lower = _compare_versions(
current_public_version,
fit_public_version,
check_synthesizer_is_greater
compare_operator
)

is_enterprise_lower = False
if None not in (current_enterprise_version, fit_enterprise_version):
is_enterprise_lower = _check_is_lower_version(
is_enterprise_lower = _compare_versions(
current_enterprise_version,
fit_enterprise_version,
check_synthesizer_is_greater
compare_operator
)

if is_public_lower and is_enterprise_lower:
Expand Down
5 changes: 3 additions & 2 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import datetime
import inspect
import operator
import warnings
from collections import defaultdict
from copy import deepcopy
Expand Down Expand Up @@ -368,7 +369,7 @@ def fit_processed_data(self, processed_data):
processed_data (dict):
Dictionary mapping each table name to a preprocessed ``pandas.DataFrame``.
"""
check_synthesizer_version(self, is_fit_method=True, check_synthesizer_is_greater=True)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
augmented_data = self._augment_tables(processed_data)
self._model_tables(augmented_data)
self._fitted = True
Expand All @@ -384,7 +385,7 @@ def fit(self, data):
Dictionary mapping each table name to a ``pandas.DataFrame`` in the raw format
(before any transformations).
"""
check_synthesizer_version(self, is_fit_method=True, check_synthesizer_is_greater=True)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
_validate_foreign_keys_not_null(self.metadata, data)
self._check_metadata_updated()
self._fitted = False
Expand Down
5 changes: 3 additions & 2 deletions sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import logging
import math
import operator
import os
import uuid
import warnings
Expand Down Expand Up @@ -386,7 +387,7 @@ def fit_processed_data(self, processed_data):
processed_data (pandas.DataFrame):
The transformed data used to fit the model to.
"""
check_synthesizer_version(self, is_fit_method=True, check_synthesizer_is_greater=True)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
if not processed_data.empty:
self._fit(processed_data)

Expand All @@ -402,7 +403,7 @@ def fit(self, data):
data (pandas.DataFrame):
The raw data (before any transformations) to fit the model to.
"""
check_synthesizer_version(self, is_fit_method=True, check_synthesizer_is_greater=True)
check_synthesizer_version(self, is_fit_method=True, compare_operator=operator.lt)
self._check_metadata_updated()
self._fitted = False
self._data_processor.reset_sampling()
Expand Down
39 changes: 16 additions & 23 deletions tests/unit/test__utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
import re
from datetime import datetime
from unittest.mock import Mock, patch
Expand All @@ -8,7 +9,7 @@

from sdv import version
from sdv._utils import (
_check_is_lower_version, _convert_to_timedelta, _create_unique_name, _get_datetime_format,
_compare_versions, _convert_to_timedelta, _create_unique_name, _get_datetime_format,
_get_root_tables, _is_datetime_type, _validate_foreign_keys_not_null,
check_sdv_versions_and_warn, check_synthesizer_version)
from sdv.errors import SDVVersionWarning, SynthesizerInputError, VersionError
Expand Down Expand Up @@ -409,40 +410,40 @@ def test_check_sdv_versions_and_warn_public_and_enterprise_missmatch(mock_versio
check_sdv_versions_and_warn(synthesizer)


def test__check_is_lower_version():
"""Test that _check_is_lower_version returns True if synthesizer version is greater."""
def test__compare_versions():
"""Test that _compare_versions returns True if synthesizer version is greater."""
# Setup
synthesizer_version = '1.2.3'
current_version = '1.2.1'
synthesizer_version = '1.2.3'

# Run
result = _check_is_lower_version(current_version, synthesizer_version)
result = _compare_versions(current_version, synthesizer_version)

# Assert
assert result is True


def test__check_is_lower_version_equal():
"""Test that _check_is_lower_version returns False if synthesizer version is equal."""
def test__compare_versions_equal():
"""Test that _compare_versions returns False if synthesizer version is equal."""
# Setup
synthesizer_version = '1.2.3'
current_version = '1.2.3'

# Run
result = _check_is_lower_version(current_version, synthesizer_version)
result = _compare_versions(current_version, synthesizer_version)

# Assert
assert result is False


def test__check_is_lower_version_lower():
"""Test that _check_is_lower_version returns False if synthesizer version is lower."""
def test__compare_versions_lower():
"""Test that _compare_versions returns False if synthesizer version is lower."""
# Setup
synthesizer_version = '1.0.3'
current_version = '1.2.1'

# Run
result = _check_is_lower_version(current_version, synthesizer_version)
result = _compare_versions(current_version, synthesizer_version)

# Assert
assert result is False
Expand Down Expand Up @@ -570,7 +571,7 @@ def test_check_synthesizer_version_check_synthesizer_is_greater(mock_version):
check_synthesizer_version(
synthesizer,
is_fit_method=True,
check_synthesizer_is_greater=True
compare_operator=operator.lt
)


Expand All @@ -587,7 +588,7 @@ def test_check_synthesizer_version_check_synthesizer_is_greater_equal(mock_versi
mock_version.enterprise = '1.3.0'

# Run and Assert
check_synthesizer_version(synthesizer, is_fit_method=True, check_synthesizer_is_greater=True)
check_synthesizer_version(synthesizer, is_fit_method=True, compare_operator=operator.lt)


@patch('sdv._utils.version')
Expand All @@ -613,11 +614,7 @@ def test_check_synthesizer_version_check_synthesizer_is_greater_public_missmatch
'Please create a new synthesizer.'
)
with pytest.raises(VersionError, match=message):
check_synthesizer_version(
synthesizer,
is_fit_method=True,
check_synthesizer_is_greater=True
)
check_synthesizer_version(synthesizer, is_fit_method=True, compare_operator=operator.lt)


@patch('sdv._utils.version')
Expand All @@ -643,8 +640,4 @@ def test_check_synthesizer_version_check_synthesizer_is_greater_both_missmatch(m
'Fitting this synthesizer again is not supported. Please create a new synthesizer.'
)
with pytest.raises(VersionError, match=message):
check_synthesizer_version(
synthesizer,
is_fit_method=True,
check_synthesizer_is_greater=True
)
check_synthesizer_version(synthesizer, is_fit_method=True, compare_operator=operator.lt)

0 comments on commit bc431b9

Please sign in to comment.