Skip to content

Commit

Permalink
Merge pull request #296 from flatironinstitute/hotfix_repr
Browse files Browse the repository at this point in the history
Hotfix repr
  • Loading branch information
BalzaniEdoardo authored Jan 21, 2025
2 parents ec5b381 + d329376 commit c355d28
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nemos"
version = "0.2.0"
version = "0.2.1"
authors = [{name = "nemos authors"}]
description = "NEural MOdelS, a statistical modeling framework for neuroscience."
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
__version__ = "0.2.0"
__version__ = "0.2.1"

from . import (
basis,
Expand Down
8 changes: 7 additions & 1 deletion src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,12 @@ def _get_optimal_solver_params_config(self):
def __repr__(self):
return format_repr(self, multiline=True)

def __sklearn_clone__(self) -> GLM:
"""Clone the PopulationGLM, dropping feature_mask"""
params = self.get_params(deep=False)
klass = self.__class__(**params)
return klass


class PopulationGLM(GLM):
"""
Expand Down Expand Up @@ -1633,7 +1639,7 @@ def _predict(
+ bs
)

def __sklearn_clone__(self) -> GLM:
def __sklearn_clone__(self) -> PopulationGLM:
"""Clone the PopulationGLM, dropping feature_mask"""
params = self.get_params(deep=False)
params.pop("feature_mask")
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def format_repr(
)
if repr_param:
if k in use_name_keys:
v = v.__name__
v = getattr(v, "__name__", repr(v))
elif isinstance(v, str):
v = repr(v)
disp_params.append(f"{k}={v}")
Expand Down
16 changes: 12 additions & 4 deletions tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,12 +1478,16 @@ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation):
def test_compatibility_with_sklearn_cv(self, poissonGLM_model_instantiation):
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_compatibility_with_sklearn_cv_gamma(self, gammaGLM_model_instantiation):
X, y, model, true_params, firing_rate = gammaGLM_model_instantiation
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

@pytest.mark.parametrize(
"regr_setup, glm_class",
Expand Down Expand Up @@ -3572,12 +3576,16 @@ def test_deviance_against_statsmodels(self, poisson_population_GLM_model):
def test_compatibility_with_sklearn_cv(self, poisson_population_GLM_model):
X, y, model, true_params, firing_rate = poisson_population_GLM_model
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_compatibility_with_sklearn_cv_gamma(self, gamma_population_GLM_model):
X, y, model, true_params, firing_rate = gamma_population_GLM_model
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_sklearn_clone(self, poisson_population_GLM_model):
X, y, model, true_params, firing_rate = poisson_population_GLM_model
Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from contextlib import nullcontext as does_not_raise
from copy import deepcopy

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -595,6 +596,8 @@ def __repr__(self):
(Example(a=0, b=False, c=None), None, [], "Example(a=0, b=False, d=1)"),
# Falsey values excluded2
(Example(a=0, b=[], c={}), None, [], "Example(a=0, d=1)"),
# function without the __name__
(nmo.observation_models.PoissonObservations(deepcopy(jax.numpy.exp)),None, [], "PoissonObservations(inverse_link_function=<PjitFunction>)")
],
)
def test_format_repr(obj, exclude_keys, use_name_keys, expected):
Expand Down

0 comments on commit c355d28

Please sign in to comment.