From c78f70fda4a8b54483b5af55c12a27ac19a19f63 Mon Sep 17 00:00:00 2001 From: Richard C Gerkin Date: Fri, 14 Jun 2024 16:26:27 -0700 Subject: [PATCH] Added constant for float relative tolerance --- sdv/single_table/base.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index f0d669d2a..1bb661f32 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -33,6 +33,7 @@ COND_IDX = str(uuid.uuid4()) FIXED_RNG_SEED = 73251 +FLOAT_RTOL = 0.01 class BaseSynthesizer: @@ -576,7 +577,7 @@ def _filter_conditions(sampled, conditions, float_rtol): return sampled def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, - float_rtol=0.1, previous_rows=None, keep_extra_columns=False): + float_rtol=FLOAT_RTOL, previous_rows=None, keep_extra_columns=False): """Sample rows with the given conditions. Input conditions is taken both in the raw input format, which will be used @@ -654,7 +655,7 @@ def _sample_rows(self, num_rows, conditions=None, transformed_conditions=None, return sampled, num_rows def _sample_batch(self, batch_size, max_tries=100, - conditions=None, transformed_conditions=None, float_rtol=0.01, + conditions=None, transformed_conditions=None, float_rtol=FLOAT_RTOL, progress_bar=None, output_file_path=None, keep_extra_columns=False): """Sample a batch of rows with the given conditions. @@ -774,7 +775,7 @@ def _make_condition_dfs(conditions): ] def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditions=None, - transformed_conditions=None, float_rtol=0.01, progress_bar=None, + transformed_conditions=None, float_rtol=FLOAT_RTOL, progress_bar=None, output_file_path=None): sampled = [] batch_size = batch_size if num_rows > batch_size else num_rows @@ -794,7 +795,7 @@ def _sample_in_batches(self, num_rows, batch_size, max_tries_per_batch, conditio return sampled.head(num_rows) def _conditionally_sample_rows(self, dataframe, condition, transformed_condition, - max_tries_per_batch=None, batch_size=None, float_rtol=0.01, + max_tries_per_batch=None, batch_size=None, float_rtol=FLOAT_RTOL, graceful_reject_sampling=True, progress_bar=None, output_file_path=None): batch_size = batch_size or len(dataframe)