Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error if continuous training data contains null values #428

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ctgan/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Custom errors for CTGAN."""


class InvalidDataError(Exception):
    """Error to raise when data is not valid.

    Raised during validation of the training data, for example when the
    continuous columns contain null values.
    """
27 changes: 27 additions & 0 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from ctgan.data_sampler import DataSampler
from ctgan.data_transformer import DataTransformer
from ctgan.errors import InvalidDataError
from ctgan.synthesizers.base import BaseSynthesizer, random_state


Expand Down Expand Up @@ -289,6 +290,31 @@ def _validate_discrete_columns(self, train_data, discrete_columns):
if invalid_columns:
raise ValueError(f'Invalid columns found: {invalid_columns}')

def _validate_null_data(self, train_data, discrete_columns):
    """Check whether null values exist in continuous ``train_data``.

    Every column that is not listed in ``discrete_columns`` is treated as
    continuous and checked for null values.

    Args:
        train_data (numpy.ndarray or pandas.DataFrame):
            Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
        discrete_columns (list-like):
            List of discrete columns to be used to generate the Conditional
            Vector. If ``train_data`` is a Numpy array, this list should
            contain the integer indices of the columns. Otherwise, if it is
            a ``pandas.DataFrame``, this list should contain the column names.

    Raises:
        InvalidDataError:
            If any continuous column contains at least one null value.
    """
    # Build the lookup once so each column is classified in constant time.
    discrete = set(discrete_columns)

    if isinstance(train_data, pd.DataFrame):
        frame = train_data
    else:
        # Wrapping the array labels its columns 0..n-1, which matches the
        # integer indices expected in ``discrete_columns`` for numpy input.
        frame = pd.DataFrame(train_data)

    continuous_cols = [column for column in frame.columns if column not in discrete]
    any_nulls = frame[continuous_cols].isna().any().any()

    if any_nulls:
        raise InvalidDataError(
            'CTGAN does not support null values in the continuous training data. '
            'Please remove all null values from your continuous training data.'
        )

@random_state
def fit(self, train_data, discrete_columns=(), epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
Expand All @@ -303,6 +329,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
a ``pandas.DataFrame``, this list should contain the column names.
"""
self._validate_discrete_columns(train_data, discrete_columns)
self._validate_null_data(train_data, discrete_columns)

if epochs is None:
epochs = self._epochs
Expand Down
20 changes: 20 additions & 0 deletions tests/integration/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pandas as pd
import pytest

from ctgan.errors import InvalidDataError
from ctgan.synthesizers.ctgan import CTGAN


Expand Down Expand Up @@ -132,6 +133,25 @@ def test_categorical_nan():
assert {'b', 'c'}.issubset(values)


def test_continuous_nan():
    """Test the CTGAN with missing numerical values.

    Fitting data whose continuous column contains ``NaN`` must raise
    ``InvalidDataError`` with the documented message.
    """
    import re

    # Setup
    data = pd.DataFrame({
        'continuous': [np.nan, 1.0, 2.0] * 10,
        'discrete': ['a', 'b', 'c'] * 10,
    })
    discrete_columns = ['discrete']
    error_message = (
        'CTGAN does not support null values in the continuous training data. '
        'Please remove all null values from your continuous training data.'
    )

    # Run and Assert
    ctgan = CTGAN(epochs=1)
    # ``match`` is treated as a regular expression, so the message is escaped
    # to keep its periods literal instead of matching any character.
    with pytest.raises(InvalidDataError, match=re.escape(error_message)):
        ctgan.fit(data, discrete_columns)


def test_synthesizer_sample():
"""Test the CTGAN samples the correct datatype."""
data = pd.DataFrame({'discrete': np.random.choice(['a', 'b', 'c'], 100)})
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from unittest import TestCase
from unittest.mock import Mock

import numpy as np
import pandas as pd
import pytest
import torch

from ctgan.data_transformer import SpanInfo
from ctgan.errors import InvalidDataError
from ctgan.synthesizers.ctgan import CTGAN, Discriminator, Generator, Residual


Expand Down Expand Up @@ -289,3 +291,42 @@ def test__validate_discrete_columns(self):
ctgan = CTGAN(epochs=1)
with pytest.raises(ValueError, match=r'Invalid columns found: {\'doesnt exist\'}'):
ctgan.fit(data, discrete_columns)

def test__validate_null_data(self):
    """Test `_validate_null_data` with pandas and numpy data.

    Check the appropriate error is raised if null values are present in
    continuous columns, both for numpy arrays and dataframes.
    """
    import re

    # Setup
    discrete_df = pd.DataFrame({'discrete': ['a', 'b']})
    discrete_array = np.array([['a'], ['b']])
    continuous_no_nulls_df = pd.DataFrame({'continuous': [0, 1]})
    continuous_no_nulls_array = np.array([[0], [1]])
    continuous_with_null_df = pd.DataFrame({'continuous': [1, np.nan]})
    continuous_with_null_array = np.array([[1], [np.nan]])
    ctgan = CTGAN(epochs=1)
    error_message = (
        'CTGAN does not support null values in the continuous training data. '
        'Please remove all null values from your continuous training data.'
    )
    # ``match`` is treated as a regular expression, so the message is escaped
    # to keep its periods literal instead of matching any character.
    error_pattern = re.escape(error_message)

    # Test discrete DataFrame fits without error
    ctgan.fit(discrete_df, ['discrete'])

    # Test discrete array fits without error
    ctgan.fit(discrete_array, [0])

    # Test continuous DataFrame without nulls fits without error
    ctgan.fit(continuous_no_nulls_df)

    # Test continuous array without nulls fits without error
    ctgan.fit(continuous_no_nulls_array)

    # Test nulls in continuous columns DataFrame errors on fit
    with pytest.raises(InvalidDataError, match=error_pattern):
        ctgan.fit(continuous_with_null_df)

    # Test nulls in continuous columns array errors on fit
    with pytest.raises(InvalidDataError, match=error_pattern):
        ctgan.fit(continuous_with_null_array)
Loading