Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Hi everyone, this PR fixes issues #169 and #235 which report bugs concerning the sampling from the conditional generator after training, i.e., the
sample
method ofCTGAN
. The details of the proposed changes are described and discussed in the issues, but I give a summary here:_discrete_column_matrix_st
of theDataSampler
inCTGAN
. It affects thesample_original_condvec
andgenerate_cond_from_condition_column_info
methods. Addingself._discrete_column_matrix_st[current_id] = st
fixes the issue forsample_original_condvec
. To fix the issue forgenerate_cond_from_condition_column_info
, I have replaced_discrete_column_matrix_st
with_discrete_column_cond_st
. The difference between both fixes is due to creating a conditional vector vs. selecting a conditional vector from the data (which also contains continuous variables and thus requires other indices)._discrete_column_matrix_st
to_discrete_column_cond_st
. There were still some issues as the generator contains batchnorm layers, and the model was still intrain
mode. Settingself._generator.eval()
fixed the issue here. For performance, I also added thewith torch.no_grad()
.test_synthesizer_sampling
to test the sampling methods. I noticed thattest_log_frequency
was failing, but after looking into more detail, it seems this test is outdated Exposelog_frequency
parameter for conditional sampling #20. The generator's sampling during inference time is always set to the empirical frequency (not sure whether this is intentional, and maybe an issue to request the feature to sample with log frequency may be appropriate?). In training, the default option is the log frequency, but this is not what the test is assessing. Therefore, I have changed this test, but it can also be removed.