Skip to content

Commit

Permalink
Add future warnings to all public functions of SingleTablePreset
Browse files Browse the repository at this point in the history
  • Loading branch information
lajohn4747 committed Apr 4, 2024
1 parent 6750191 commit 3433e5e
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions sdv/lite/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import logging
import sys
import warnings

import cloudpickle

Expand All @@ -14,6 +15,10 @@
PRESETS = {
FAST_ML_PRESET: 'Use this preset to minimize the time needed to create a synthetic data model.'
}
DEPRECATION_MSG = (
"The 'SingleTablePreset' is deprecated. For equivalent Fast ML "
"functionality, please use the 'GaussianCopulaSynthesizer'."
)


class SingleTablePreset:
Expand Down Expand Up @@ -41,6 +46,7 @@ def _setup_fast_preset(self, metadata, locales):
)

def __init__(self, metadata, name, locales=['en_US']):
warnings.warn(DEPRECATION_MSG, FutureWarning)
self.locales = locales
if name not in PRESETS:
raise ValueError(f"'name' must be one of {PRESETS}.")
Expand All @@ -58,14 +64,17 @@ def add_constraints(self, constraints):
* ``constraint_class``: Name of the constraint to apply.
* ``constraint_parameters``: A dictionary with the constraint parameters.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
self._synthesizer.add_constraints(constraints)

def get_metadata(self):
"""Return the ``SingleTableMetadata`` for this synthesizer."""
warnings.warn(DEPRECATION_MSG, FutureWarning)
return self._synthesizer.get_metadata()

def get_parameters(self):
"""Return the parameters used to instantiate the synthesizer."""
warnings.warn(DEPRECATION_MSG, FutureWarning)
parameters = inspect.signature(self.__init__).parameters
instantiated_parameters = {}
for parameter_name in parameters:
Expand All @@ -81,6 +90,7 @@ def fit(self, data):
data (pandas.DataFrame):
Data to fit the model to.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
self._synthesizer.fit(data)

def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file_path=None):
Expand All @@ -101,6 +111,7 @@ def sample(self, num_rows, max_tries_per_batch=100, batch_size=None, output_file
pandas.DataFrame:
Sampled data.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
sampled = self._synthesizer.sample(
num_rows,
max_tries_per_batch,
Expand Down Expand Up @@ -132,6 +143,7 @@ def sample_from_conditions(self, conditions, max_tries_per_batch=100,
pandas.DataFrame:
Sampled data.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
sampled = self._synthesizer.sample_from_conditions(
conditions,
max_tries_per_batch,
Expand Down Expand Up @@ -163,6 +175,7 @@ def sample_remaining_columns(self, known_columns, max_tries_per_batch=100,
pandas.DataFrame:
Sampled data.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
sampled = self._synthesizer.sample_remaining_columns(
known_columns,
max_tries_per_batch,
Expand All @@ -179,6 +192,7 @@ def save(self, filepath):
filepath (str):
Path where the SDV instance will be serialized.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
with open(filepath, 'wb') as output:
cloudpickle.dump(self, output)

Expand All @@ -194,13 +208,15 @@ def load(cls, filepath):
SingleTableSynthesizer:
The loaded synthesizer.
"""
warnings.warn(DEPRECATION_MSG, FutureWarning)
with open(filepath, 'rb') as f:
model = cloudpickle.load(f)
return model

@classmethod
def list_available_presets(cls, out=sys.stdout):
"""List the available presets and their descriptions."""
warnings.warn(DEPRECATION_MSG, FutureWarning)
out.write(f'Available presets:\n{PRESETS}\n\n'
'Supply the desired preset using the `name` parameter.\n\n'
'Have any requests for custom presets? Contact the SDV team to learn '
Expand Down

0 comments on commit 3433e5e

Please sign in to comment.