Skip to content

Commit

Permalink
Merge pull request #2138 from pavlin-policar/fitter-sklearn-quick
Browse files Browse the repository at this point in the history
[FIX] Fitter: Fix used_vals and params not being set
  • Loading branch information
janezd authored Mar 31, 2017
2 parents 407db85 + efab929 commit 99ca084
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 51 deletions.
4 changes: 2 additions & 2 deletions Orange/modelling/ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from Orange.ensembles import (
SklAdaBoostClassificationLearner, SklAdaBoostRegressionLearner
)
from Orange.modelling import Fitter
from Orange.modelling import SklFitter

__all__ = ['SklAdaBoostLearner']


class SklAdaBoostLearner(Fitter):
class SklAdaBoostLearner(SklFitter):
__fits__ = {'classification': SklAdaBoostClassificationLearner,
'regression': SklAdaBoostRegressionLearner}

Expand Down
30 changes: 14 additions & 16 deletions Orange/modelling/base.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,9 @@
from Orange.base import Learner, Model
import numpy as np

from Orange.base import Learner, Model, SklLearner

class FitterMeta(type):
"""Ensure that each subclass of the `Fitter` class overrides the `__fits__`
attribute with a valid value."""
def __new__(mcs, name, bases, attrs):
# Check that a fitter implementation defines a valid `__fits__`
if any(cls.__name__ == 'Fitter' for cls in bases):
fits = attrs.get('__fits__')
assert isinstance(fits, dict), '__fits__ must be dict instance'
assert fits.get('classification') and fits.get('regression'), \
('`__fits__` property does not define classification '
'or regression learner. Use a simple learner if you don\'t '
'need the functionality provided by Fitter.')
return super().__new__(mcs, name, bases, attrs)


class Fitter(Learner, metaclass=FitterMeta):
class Fitter(Learner):
"""Handle multiple types of target variable with one learner.
Subclasses of this class serve as a sort of dispatcher. When subclassing,
Expand Down Expand Up @@ -119,3 +106,14 @@ def params(self):
def get_params(self, problem_type):
"""Access the specific learner params of a given learner."""
return self.get_learner(problem_type).params


class SklFitter(Fitter):
def _fit_model(self, data):
model = super()._fit_model(data)
model.used_vals = [np.unique(y) for y in data.Y[:, None].T]
if data.domain.has_discrete_class:
model.params = self.get_params(self.CLASSIFICATION)
else:
model.params = self.get_params(self.REGRESSION)
return model
4 changes: 2 additions & 2 deletions Orange/modelling/knn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from Orange.classification import KNNLearner as KNNClassification
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import KNNRegressionLearner

__all__ = ['KNNLearner']


class KNNLearner(Fitter):
class KNNLearner(SklFitter):
__fits__ = {'classification': KNNClassification,
'regression': KNNRegressionLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/linear.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from Orange.classification.sgd import SGDClassificationLearner
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import SGDRegressionLearner

__all__ = ['SGDLearner']


class SGDLearner(Fitter):
class SGDLearner(SklFitter):
name = 'sgd'

__fits__ = {'classification': SGDClassificationLearner,
Expand Down
4 changes: 2 additions & 2 deletions Orange/modelling/neural_network.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from Orange.classification import NNClassificationLearner
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import NNRegressionLearner

__all__ = ['NNLearner']


class NNLearner(Fitter):
class NNLearner(SklFitter):
__fits__ = {'classification': NNClassificationLearner,
'regression': NNRegressionLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/randomforest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from Orange.base import RandomForestModel
from Orange.classification import RandomForestLearner as RFClassification
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import RandomForestRegressionLearner as RFRegression

__all__ = ['RandomForestLearner']


class RandomForestLearner(Fitter):
class RandomForestLearner(SklFitter):
name = 'random forest'

__fits__ = {'classification': RFClassification,
Expand Down
8 changes: 4 additions & 4 deletions Orange/modelling/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@
LinearSVMLearner as LinearSVCLearner,
NuSVMLearner as NuSVCLearner,
)
from Orange.modelling import Fitter
from Orange.modelling import SklFitter
from Orange.regression import SVRLearner, LinearSVRLearner, NuSVRLearner

__all__ = ['SVMLearner', 'LinearSVMLearner', 'NuSVMLearner']


class SVMLearner(Fitter):
class SVMLearner(SklFitter):
__fits__ = {'classification': SVCLearner, 'regression': SVRLearner}


class LinearSVMLearner(Fitter):
class LinearSVMLearner(SklFitter):
__fits__ = {'classification': LinearSVCLearner, 'regression': LinearSVRLearner}


class NuSVMLearner(Fitter):
class NuSVMLearner(SklFitter):
__fits__ = {'classification': NuSVCLearner, 'regression': NuSVRLearner}
4 changes: 2 additions & 2 deletions Orange/modelling/tree.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from Orange.classification import SklTreeLearner
from Orange.classification import TreeLearner as ClassificationTreeLearner
from Orange.modelling import Fitter
from Orange.modelling import Fitter, SklFitter
from Orange.regression import TreeLearner as RegressionTreeLearner
from Orange.regression.tree import SklTreeRegressionLearner
from Orange.tree import TreeModel

__all__ = ['SklTreeLearner', 'TreeLearner']


class SklTreeLearner(Fitter):
class SklTreeLearner(SklFitter):
name = 'tree'

__fits__ = {'classification': SklTreeLearner,
Expand Down
19 changes: 0 additions & 19 deletions Orange/tests/test_fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,6 @@ def setUpClass(cls):
cls.heart_disease = Table('heart_disease')
cls.housing = Table('housing')

def test_throws_if_fits_property_is_invalid(self):
"""The `__fits__` attribute must be an instance of `LearnerTypes`"""
with self.assertRaises(Exception):
class DummyFitter(Fitter):
name = 'dummy'
__fits__ = (DummyClassificationLearner, DummyRegressionLearner)

fitter = DummyFitter()
fitter(self.heart_disease)

def test_dispatches_to_correct_learner(self):
"""Based on the input data, it should dispatch the fitting process to
the appropriate learner"""
Expand Down Expand Up @@ -102,15 +92,6 @@ class DummyFitter(Fitter):
except TypeError:
self.fail('Fitter did not properly distribute params to learners')

def test_error_for_data_type_with_no_learner(self):
"""If we attempt to define a fitter which only handles one data type
it makes more sense to simply use a Learner."""
with self.assertRaises(AssertionError):
class DummyFitter(Fitter):
name = 'dummy'
__fits__ = {'classification': None,
'regression': DummyRegressionLearner}

def test_correctly_sets_preprocessors_on_learner(self):
"""Fitters have to be able to pass the `use_default_preprocessors` and
preprocessors down to individual learners"""
Expand Down

0 comments on commit 99ca084

Please sign in to comment.