diff --git a/Orange/widgets/data/owdatasampler.py b/Orange/widgets/data/owdatasampler.py index 15cbf3f53a3..6f14556cebe 100644 --- a/Orange/widgets/data/owdatasampler.py +++ b/Orange/widgets/data/owdatasampler.py @@ -107,7 +107,7 @@ def set_sampling_type_i(): ibox = gui.indentedBox(sampling) self.sampleSizeSpin = gui.spin( ibox, self, "sampleSizeNumber", label="Instances: ", - minv=1, maxv=self._MAX_SAMPLE_SIZE, + minv=0, maxv=self._MAX_SAMPLE_SIZE, callback=set_sampling_type(self.FixedSize), controlWidth=90) gui.checkBox( @@ -395,11 +395,15 @@ def __call__(self, table): o[sample] = 0 others = np.nonzero(o)[0] return others, sample - if self.n == len(table): + if self.n in (0, len(table)): rgen = np.random.RandomState(self.random_state) - sample = np.arange(self.n) - rgen.shuffle(sample) - return np.array([], dtype=int), sample + shuffled = np.arange(len(table)) + rgen.shuffle(shuffled) + empty = np.array([], dtype=int) + if self.n == 0: + return shuffled, empty + else: + return empty, shuffled elif self.stratified and table.domain.has_discrete_class: test_size = max(len(table.domain.class_var.values), self.n) splitter = skl.StratifiedShuffleSplit( diff --git a/Orange/widgets/data/tests/test_owdatasampler.py b/Orange/widgets/data/tests/test_owdatasampler.py index c59ba560014..b0fa8bf6ae4 100644 --- a/Orange/widgets/data/tests/test_owdatasampler.py +++ b/Orange/widgets/data/tests/test_owdatasampler.py @@ -131,6 +131,13 @@ def set_fixed_sample_size(self, sample_size, with_replacement=False): self.widget.commit() return self.widget.sampleSizeSpin.value() + def set_fixed_proportion(self, proportion): + """Set fixed sample proportion. + """ + self.select_sampling_type(self.widget.FixedProportion) + self.widget.sampleSizePercentageSlider.setValue(proportion) + self.widget.commit() + def assertNoIntersection(self, sample, other): self.assertFalse(bool(set(sample.ids) & set(other.ids))) @@ -170,6 +177,26 @@ def test_cv_output_migration(self): self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 15) self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 135) + def test_empty_sample(self): + w = self.widget + self.send_signal(w.Inputs.data, self.iris) + + self.set_fixed_sample_size(150) + self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 150) + self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 0) + + self.set_fixed_sample_size(0) + self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 0) + self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 150) + + self.set_fixed_proportion(100) + self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 150) + self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 0) + + self.set_fixed_proportion(0) + self.assertEqual(len(self.get_output(w.Outputs.data_sample)), 0) + self.assertEqual(len(self.get_output(w.Outputs.remaining_data)), 150) + def test_send_report(self): w = self.widget self.send_signal(w.Inputs.data, self.iris)