
Fix bugs for conditional sampling #236

Closed

Conversation


@AndresAlgaba AndresAlgaba commented Jul 20, 2022

Hi everyone, this PR fixes issues #169 and #235, which report bugs in sampling from the conditional generator after training, i.e., the sample method of CTGAN. The proposed changes are described and discussed in detail in the issues, but here is a summary:

  • Issue discrete_column_matrix_st from data_sampler class is always 0 #169 concerns the _discrete_column_matrix_st attribute of the DataSampler in CTGAN, which affects the sample_original_condvec and generate_cond_from_condition_column_info methods. Adding self._discrete_column_matrix_st[current_id] = st fixes the issue for sample_original_condvec. To fix generate_cond_from_condition_column_info, I have replaced _discrete_column_matrix_st with _discrete_column_cond_st. The two fixes differ because one constructs a conditional vector while the other selects a conditional vector from the data, which also contains continuous variables and therefore requires different indices. A sketch of the first change is shown after this list.
  • Issue Conditional sampling and cross-entropy loss #235 was only partially fixed by replacing _discrete_column_matrix_st with _discrete_column_cond_st. Sampling was still off because the generator contains batchnorm layers and remained in train mode during inference. Calling self._generator.eval() before sampling fixes this, and for performance I also wrapped the forward pass in with torch.no_grad() (see the second sketch below).
  • I have written test_synthesizer_sampling to test the sampling methods. I noticed that test_log_frequency was failing, but after looking into it in more detail, it seems this test is outdated (Expose log_frequency parameter for conditional sampling #20). During inference, the conditions are always sampled with the empirical frequency (I am not sure whether this is intentional; perhaps a feature request to sample with log frequency would be appropriate). During training, the default is the log frequency, but that is not what the test was assessing. I have therefore changed this test, but it can also be removed.
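
For context, here is a minimal sketch of the first fix. It assumes the DataSampler constructor iterates over the transformer's output_info while tracking the running offset st into the transformed data matrix; the loop is simplified and only the added line is the actual change:

```python
# Simplified sketch of the discrete-column loop in DataSampler.__init__
# (not the full constructor). The substantive change is recording the
# matrix start index, which previously stayed at its initial value of 0.
st = 0               # running offset into the transformed data matrix
current_id = 0       # index of the current discrete column
current_cond_st = 0  # running offset into the conditional vector
for column_info in output_info:
    if is_discrete_column(column_info):
        span_info = column_info[0]
        self._discrete_column_cond_st[current_id] = current_cond_st
        self._discrete_column_matrix_st[current_id] = st  # new line: fix for #169
        self._discrete_column_n_category[current_id] = span_info.dim
        current_cond_st += span_info.dim
        current_id += 1
        st += span_info.dim
    else:
        st += sum(span_info.dim for span_info in column_info)
```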
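
Similarly, a sketch of the inference-time fix in CTGAN's sample method, assuming the method builds batches of noise fakez and decodes them with the generator (condvec handling and batching details are omitted):

```python
# Simplified sketch of the sampling loop in CTGAN.sample (not the actual code).
self._generator.eval()  # put batchnorm layers in inference mode: fix for #235
data = []
for _ in range(steps):
    fakez = torch.normal(mean=mean, std=std)  # noise batch; condvec concatenation omitted
    with torch.no_grad():  # no gradients needed at sampling time
        fake = self._generator(fakez)
        fakeact = self._apply_activate(fake)
    data.append(fakeact.cpu().numpy())
```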
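
For completeness, this is the kind of end-to-end call that test_synthesizer_sampling exercises. A usage sketch, assuming the top-level CTGAN class (CTGANSynthesizer in older releases) and an illustrative toy DataFrame:

```python
import pandas as pd
from ctgan import CTGAN

# Toy data with one continuous and one discrete column (names are illustrative).
data = pd.DataFrame({
    'age': [23, 35, 41, 29, 52] * 20,
    'category': ['a', 'b', 'a', 'c', 'b'] * 20,
})

model = CTGAN(epochs=1)
model.fit(data, discrete_columns=['category'])

# Conditional sampling: request rows where category == 'a'.
samples = model.sample(10, condition_column='category', condition_value='a')
```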

@AndresAlgaba AndresAlgaba requested a review from a team as a code owner July 20, 2022 12:27
@AndresAlgaba AndresAlgaba requested review from pvk-developer and removed request for a team July 20, 2022 12:27
@AndresAlgaba AndresAlgaba marked this pull request as draft July 20, 2022 12:27