diff --git a/HISTORY.md b/HISTORY.md
index 5d56ad95..cf4f356c 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,16 @@
 # History
 
+## v0.4.1 - 2021-03-30
+
+This release exposes all the hyperparameters which the user may find useful for both `CTGAN`
+and `TVAE`. Also, `TVAE` can now be fitted on datasets that are shorter than the batch
+size, since the last incomplete batch is no longer dropped.
+
+### Issues closed
+
+* `TVAE`: Adapt `batch_size` to data size - Issue [#135](https://github.com/sdv-dev/CTGAN/issues/135) by @fealho and @csala
+* `ValueError` from `validate_discrete_columns` with `UniqueCombinationConstraint` - Issue [#133](https://github.com/sdv-dev/CTGAN/issues/133) by @fealho and @MLjungg
+
 ## v0.4.0 - 2021-02-24
 
 Maintenance release to upgrade dependencies to ensure compatibility with the rest
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 2f9768cd..6bb7643e 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -1,5 +1,5 @@
 {% set name = 'ctgan' %}
-{% set version = '0.4.0' %}
+{% set version = '0.4.1.dev2' %}
 
 package:
   name: "{{ name|lower }}"
diff --git a/ctgan/__init__.py b/ctgan/__init__.py
index d9f0d19e..a7669f0a 100644
--- a/ctgan/__init__.py
+++ b/ctgan/__init__.py
@@ -4,7 +4,7 @@
 
 __author__ = 'MIT Data To AI Lab'
 __email__ = 'dailabmit@gmail.com'
-__version__ = '0.4.0'
+__version__ = '0.4.1.dev2'
 
 from ctgan.demo import load_demo
 from ctgan.synthesizers.ctgan import CTGANSynthesizer
diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
index c34a67ab..e2b85612 100644
--- a/ctgan/synthesizers/ctgan.py
+++ b/ctgan/synthesizers/ctgan.py
@@ -14,11 +14,11 @@
 
 class Discriminator(Module):
 
-    def __init__(self, input_dim, discriminator_dim, pack=10):
+    def __init__(self, input_dim, discriminator_dim, pac=10):
         super(Discriminator, self).__init__()
-        dim = input_dim * pack
-        self.pack = pack
-        self.packdim = dim
+        dim = input_dim * pac
+        self.pac = pac
+        self.pacdim = dim
         seq = []
         for item in list(discriminator_dim):
             seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
@@ -49,8 +49,8 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         return gradient_penalty
 
     def forward(self, input):
-        assert input.size()[0] % self.pack == 0
-        return self.seq(input.view(-1, self.packdim))
+        assert input.size()[0] % self.pac == 0
+        return self.seq(input.view(-1, self.pacdim))
 
 
 class Residual(Module):
@@ -122,12 +122,19 @@ class CTGANSynthesizer(BaseSynthesizer):
             Whether to have print statements for progress results. Defaults to ``False``.
         epochs (int):
             Number of training epochs. Defaults to 300.
+        pac (int):
+            Number of samples to group together when applying the discriminator.
+            Defaults to 10.
+        cuda (bool):
+            Whether to attempt to use cuda for GPU computation.
+            If this is False or CUDA is not available, CPU will be used.
+            Defaults to ``True``.
""" def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256), generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4, discriminator_decay=0, batch_size=500, discriminator_steps=1, log_frequency=True, - verbose=False, epochs=300): + verbose=False, epochs=300, pac=10, cuda=True): assert batch_size % 2 == 0 @@ -145,8 +152,20 @@ def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_di self._log_frequency = log_frequency self._verbose = verbose self._epochs = epochs - self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.trained_epochs = 0 + self.pac = pac + + if not cuda or not torch.cuda.is_available(): + device = 'cpu' + elif isinstance(cuda, str): + device = cuda + else: + device = 'cuda' + + self._device = torch.device(device) + + self._transformer = None + self._data_sampler = None + self._generator = None @staticmethod def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1): @@ -289,18 +308,19 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): data_dim ).to(self._device) - self._discriminator = Discriminator( + discriminator = Discriminator( data_dim + self._data_sampler.dim_cond_vec(), - self._discriminator_dim + self._discriminator_dim, + pac=self.pac ).to(self._device) - self._optimizerG = optim.Adam( + optimizerG = optim.Adam( self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9), weight_decay=self._generator_decay ) - self._optimizerD = optim.Adam( - self._discriminator.parameters(), lr=self._discriminator_lr, + optimizerD = optim.Adam( + discriminator.parameters(), lr=self._discriminator_lr, betas=(0.5, 0.9), weight_decay=self._discriminator_decay ) @@ -309,7 +329,6 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): steps_per_epoch = max(len(train_data) // self._batch_size, 1) for i in range(epochs): - self.trained_epochs += 1 for id_ in range(steps_per_epoch): for n in range(self._discriminator_steps): @@ -343,17 +362,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): real_cat = real fake_cat = fake - y_fake = self._discriminator(fake_cat) - y_real = self._discriminator(real_cat) + y_fake = discriminator(fake_cat) + y_real = discriminator(real_cat) - pen = self._discriminator.calc_gradient_penalty( - real_cat, fake_cat, self._device) + pen = discriminator.calc_gradient_penalty( + real_cat, fake_cat, self._device, self.pac) loss_d = -(torch.mean(y_real) - torch.mean(y_fake)) - self._optimizerD.zero_grad() + optimizerD.zero_grad() pen.backward(retain_graph=True) loss_d.backward() - self._optimizerD.step() + optimizerD.step() fakez = torch.normal(mean=mean, std=std) condvec = self._data_sampler.sample_condvec(self._batch_size) @@ -370,9 +389,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): fakeact = self._apply_activate(fake) if c1 is not None: - y_fake = self._discriminator(torch.cat([fakeact, c1], dim=1)) + y_fake = discriminator(torch.cat([fakeact, c1], dim=1)) else: - y_fake = self._discriminator(fakeact) + y_fake = discriminator(fakeact) if condvec is None: cross_entropy = 0 @@ -381,9 +400,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None): loss_g = -torch.mean(y_fake) + cross_entropy - self._optimizerG.zero_grad() + optimizerG.zero_grad() loss_g.backward() - self._optimizerG.step() + optimizerG.step() if self._verbose: print(f"Epoch {i+1}, Loss G: {loss_g.detach().cpu(): .4f}," @@ -444,7 +463,5 @@ def sample(self, n, condition_column=None, 
 
     def set_device(self, device):
         self._device = device
-        if hasattr(self, '_generator'):
+        if self._generator is not None:
             self._generator.to(self._device)
-        if hasattr(self, '_discriminator'):
-            self._discriminator.to(self._device)
diff --git a/ctgan/synthesizers/tvae.py b/ctgan/synthesizers/tvae.py
index fee2a47e..6b85f827 100644
--- a/ctgan/synthesizers/tvae.py
+++ b/ctgan/synthesizers/tvae.py
@@ -82,7 +82,9 @@ def __init__(
         decompress_dims=(128, 128),
         l2scale=1e-5,
         batch_size=500,
-        epochs=300
+        epochs=300,
+        loss_factor=2,
+        cuda=True
     ):
 
         self.embedding_dim = embedding_dim
@@ -91,17 +93,24 @@ def __init__(
 
         self.l2scale = l2scale
         self.batch_size = batch_size
-        self.loss_factor = 2
+        self.loss_factor = loss_factor
         self.epochs = epochs
 
-        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        if not cuda or not torch.cuda.is_available():
+            device = 'cpu'
+        elif isinstance(cuda, str):
+            device = cuda
+        else:
+            device = 'cuda'
+
+        self._device = torch.device(device)
 
     def fit(self, train_data, discrete_columns=tuple()):
        self.transformer = DataTransformer()
        self.transformer.fit(train_data, discrete_columns)
        train_data = self.transformer.transform(train_data)
        dataset = TensorDataset(torch.from_numpy(train_data.astype('float32')).to(self._device))
-       loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
+       loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=False)
 
        data_dim = self.transformer.output_dimensions
        encoder = Encoder(data_dim, self.compress_dims, self.embedding_dim).to(self._device)
diff --git a/setup.cfg b/setup.cfg
index 1dcc7d6a..f33623f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.4.0
+current_version = 0.4.1.dev2
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
diff --git a/setup.py b/setup.py
index ad86163c..8b74de5e 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@
     'scikit-learn>=0.23,<1',
     'torch>=1.4,<2',
     'torchvision>=0.5.0,<1',
-    'rdt>=0.4.0,<0.5',
+    'rdt>=0.4.1,<0.5',
 ]
 
 setup_requires = [
@@ -99,6 +99,6 @@
     test_suite='tests',
     tests_require=tests_require,
     url='https://github.com/sdv-dev/CTGAN',
-    version='0.4.0',
+    version='0.4.1.dev2',
     zip_safe=False,
 )
diff --git a/tests/integration/test_tvae.py b/tests/integration/test_tvae.py
index afeaa26d..e504fd70 100644
--- a/tests/integration/test_tvae.py
+++ b/tests/integration/test_tvae.py
@@ -33,3 +33,21 @@ def test_tvae(tmpdir):
     assert isinstance(sampled, pd.DataFrame)
     assert set(sampled.columns) == set(data.columns)
     assert set(sampled.dtypes) == set(data.dtypes)
+
+
+def test_drop_last_false():
+    data = pd.DataFrame({
+        '1': ['a', 'b', 'c'] * 150,
+        '2': ['a', 'b', 'c'] * 150
+    })
+
+    tvae = TVAESynthesizer(epochs=300)
+    tvae.fit(data, ['1', '2'])
+
+    sampled = tvae.sample(100)
+    correct = 0
+    for _, row in sampled.iterrows():
+        if row['1'] == row['2']:
+            correct += 1
+
+    assert correct >= 95
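
For reference, a minimal usage sketch of the hyperparameters this diff exposes. It is an illustration, not part of the change: the toy dataset mirrors the one in test_drop_last_false, the import paths are assumed from the module layout shown above, and epochs=10 / cuda=False are arbitrary choices to keep the run small (the diff's defaults are 300 and True).

import pandas as pd

from ctgan.synthesizers.ctgan import CTGANSynthesizer
from ctgan.synthesizers.tvae import TVAESynthesizer

# 450 rows, i.e. smaller than the default batch_size of 500 -- exactly the
# case that drop_last=False now lets TVAE train on.
data = pd.DataFrame({
    '1': ['a', 'b', 'c'] * 150,
    '2': ['a', 'b', 'c'] * 150,
})

# CTGAN now exposes `pac` and `cuda` in the constructor; note that
# batch_size (500 by default) must stay divisible by pac.
ctgan = CTGANSynthesizer(epochs=10, pac=10, cuda=False)
ctgan.fit(data, discrete_columns=['1', '2'])
print(ctgan.sample(10))

# TVAE now exposes `loss_factor` and `cuda` in the constructor.
tvae = TVAESynthesizer(epochs=10, loss_factor=2, cuda=False)
tvae.fit(data, ['1', '2'])
print(tvae.sample(10))

Per the new device-selection logic, `cuda` also accepts a device string such as 'cuda:0', which is passed straight through to `torch.device`.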