Skip to content

Commit

Permalink
Merge pull request #6208 from janezd/datasample-zero-size
Browse files Browse the repository at this point in the history
[FIX] Data Sampler: Fix crash when requesting an empty sample
  • Loading branch information
VesnaT authored Dec 1, 2022
2 parents f01cc80 + b2b6b77 commit 3423a8e
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
14 changes: 9 additions & 5 deletions Orange/widgets/data/owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def set_sampling_type_i():
ibox = gui.indentedBox(sampling)
self.sampleSizeSpin = gui.spin(
ibox, self, "sampleSizeNumber", label="Instances: ",
minv=1, maxv=self._MAX_SAMPLE_SIZE,
minv=0, maxv=self._MAX_SAMPLE_SIZE,
callback=set_sampling_type(self.FixedSize),
controlWidth=90)
gui.checkBox(
Expand Down Expand Up @@ -395,11 +395,15 @@ def __call__(self, table):
o[sample] = 0
others = np.nonzero(o)[0]
return others, sample
if self.n == len(table):
if self.n in (0, len(table)):
rgen = np.random.RandomState(self.random_state)
sample = np.arange(self.n)
rgen.shuffle(sample)
return np.array([], dtype=int), sample
shuffled = np.arange(len(table))
rgen.shuffle(shuffled)
empty = np.array([], dtype=int)
if self.n == 0:
return shuffled, empty
else:
return empty, shuffled
elif self.stratified and table.domain.has_discrete_class:
test_size = max(len(table.domain.class_var.values), self.n)
splitter = skl.StratifiedShuffleSplit(
Expand Down
27 changes: 27 additions & 0 deletions Orange/widgets/data/tests/test_owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def set_fixed_sample_size(self, sample_size, with_replacement=False):
self.widget.commit()
return self.widget.sampleSizeSpin.value()

def set_fixed_proportion(self, proportion):
    """Switch the widget to fixed-proportion sampling and commit.

    Selects the ``FixedProportion`` sampling type, writes `proportion`
    to the widget's sample-size percentage slider and then commits the
    widget.
    """
    widget = self.widget
    self.select_sampling_type(widget.FixedProportion)
    widget.sampleSizePercentageSlider.setValue(proportion)
    widget.commit()

def assertNoIntersection(self, sample, other):
    """Assert that `sample` and `other` have no value of ``ids`` in common."""
    shared_ids = set(sample.ids).intersection(set(other.ids))
    # An empty set is falsy, so the assertion passes only when nothing
    # is shared between the two collections.
    self.assertFalse(bool(shared_ids))

Expand Down Expand Up @@ -170,6 +177,26 @@ def test_cv_output_migration(self):
self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 15)
self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 135)

def test_empty_sample(self):
    """Check output sizes when the sample is everything or nothing.

    Sends ``self.iris`` to the widget and, for both the fixed-size and
    the fixed-proportion settings, requests first the full data and then
    an empty sample, asserting the lengths of the sample and
    remaining-data outputs each time.
    """
    widget = self.widget
    self.send_signal(widget.Inputs.data, self.iris)

    # (setter, requested value, expected sample size, expected remainder)
    scenarios = (
        (self.set_fixed_sample_size, 150, 150, 0),
        (self.set_fixed_sample_size, 0, 0, 150),
        (self.set_fixed_proportion, 100, 150, 0),
        (self.set_fixed_proportion, 0, 0, 150),
    )
    for apply_setting, requested, n_sample, n_remaining in scenarios:
        apply_setting(requested)
        self.assertEqual(
            len(self.get_output(widget.Outputs.data_sample)), n_sample)
        self.assertEqual(
            len(self.get_output(widget.Outputs.remaining_data)), n_remaining)

def test_send_report(self):
w = self.widget
self.send_signal(w.Inputs.data, self.iris)
Expand Down

0 comments on commit 3423a8e

Please sign in to comment.