Skip to content

Commit

Permalink
ENH simplification of hyperparam translator method as suggested by PR…
Browse files Browse the repository at this point in the history
… review
  • Loading branch information
dantegd committed Nov 18, 2024
1 parent 881ada1 commit b32495d
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 43 deletions.
10 changes: 5 additions & 5 deletions python/cuml/cuml/cluster/dbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,15 @@ class DBSCAN(UniversalBase,

_hyperparam_interop_translator = {
"metric": {
"manhattan": "dispatch",
"chebyshev": "dispatch",
"minkowski": "dispatch",
"manhattan": "NotImplemented",
"chebyshev": "NotImplemented",
"minkowski": "NotImplemented",
},

"algorithm": {
"auto": "brute",
"ball_tree": "dispatch",
"kd_tree": "dispatch",
"ball_tree": "NotImplemented",
"kd_tree": "NotImplemented",
},
}

Expand Down
4 changes: 0 additions & 4 deletions python/cuml/cuml/experimental/accel/estimator_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
from typing import Optional, Tuple, Dict


# currently we just use this dictionary for debugging purposes
patched_classes = {}


def intercept(
original_module: str,
accelerated_module: str,
Expand Down
42 changes: 12 additions & 30 deletions python/cuml/cuml/internals/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,6 @@ class Base(TagsMixin,
del base # optional!
"""

_base_hyperparam_interop_translator = {
"n_jobs": "accept"
}

_hyperparam_interop_translator = {}

def __init__(self, *,
Expand Down Expand Up @@ -484,37 +480,23 @@ class Base(TagsMixin,
This method is meant to do checks and translations of hyperparameters
at estimator creating time.
Each children estimator can override the method, returning either
modifier **kwargs with equivalent options, or
modifier **kwargs with equivalent options, or setting gpuaccel to False
for hyperaparameters not supported by cuML yet.
"""
gpu_hyperparams = cls._get_param_names()
kwargs.pop("self", None)
gpuaccel = True
for arg, value in kwargs.items():

if arg in cls._base_hyperparam_interop_translator:
if cls._base_hyperparam_interop_translator[arg] == "accept":
gpuaccel = gpuaccel and True

elif arg in cls._hyperparam_interop_translator:
if value in cls._hyperparam_interop_translator[arg]:
if cls._hyperparam_interop_translator[arg][value] == "accept":
gpuaccel = gpuaccel and True
elif cls._hyperparam_interop_translator[arg][value] == "dispatch":
# Copy it so we can modify it
translations = dict(cls.__bases__[0]._hyperparam_interop_translator)
# Allow the derived class to overwrite the base class
translations.update(cls._hyperparam_interop_translator)
for parameter_name, value in kwargs.items():
# maybe clean up using: translations.get(parameter_name, {}).get(value, None)?
if parameter_name in translations:
if value in translations[parameter_name]:
if translations[parameter_name][value] == "NotImplemented":
gpuaccel = False
else:
kwargs[arg] = cls._hyperparam_interop_translator[arg][value]
gpuaccel = gpuaccel and True
# todo (dgd): improve message
logger.warn("Value changed")

else:
gpuaccel = gpuaccel and True

# else:
# gpuaccel = False
kwargs[parameter_name] = translations[parameter_name][value]

# we need to enable this if we enable translation for regular cuML
# kwargs["_gpuaccel"] = gpuaccel
return kwargs, gpuaccel


Expand Down
2 changes: 0 additions & 2 deletions python/cuml/cuml/linear_model/elastic_net.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,9 @@ class ElasticNet(UniversalBase,
_hyperparam_interop_translator = {
"positive": {
True: "dispatch",
False: "accept",
},
"warm_start": {
True: "dispatch",
False: "accept",
},
}

Expand Down
3 changes: 1 addition & 2 deletions python/cuml/cuml/linear_model/linear_regression.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,7 @@ class LinearRegression(LinearPredictMixin,

_hyperparam_interop_translator = {
"positive": {
True: "dispatch",
False: "accept",
True: "NotImplemented",
},
}

Expand Down

0 comments on commit b32495d

Please sign in to comment.