diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 3bcc3e63..1575a250 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -99,6 +99,10 @@ Lastly, you should make sure that the existing tests all run successfully and th
```bash
# run tests and make sure they all pass
pytest tests/
+
+# run doctest (run all examples in docstrings and match output)
+pytest --doctest-modules src/nemos/
+
# format the code base
black src/
isort src
@@ -184,38 +188,86 @@ properly documented as outlined below.
#### Adding documentation
-1) **Docstrings**
-
-All public-facing functions and classes should have complete docstrings, which start with a one-line short summary of the function,
-a medium-length description of the function / class and what it does, and a complete description of all arguments and return values.
-Math should be included in a `Notes` section when necessary to explain what the function is doing, and references to primary literature
-should be included in a `References` section when appropriate. Docstrings should be relatively short, providing the information necessary
-for a user to use the code.
-
-Private functions and classes should have sufficient explanation that other developers know what the function / class does and how to use it,
-but do not need to be as extensive.
-
-We follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/) conventions for docstring structure.
-
-2) **Examples/Tutorials**
-
-If your changes are significant (add a new functionality or drastically change the current codebase), then the current examples may need to be updated or
-a new example may need to be added.
-
-All examples live within the `docs/` subfolder of `nemos`. These are written as `.py` files but are converted to
-notebooks by [`mkdocs-gallery`](https://smarie.github.io/mkdocs-gallery/), and have a special syntax, as demonstrated in this [example
-gallery](https://smarie.github.io/mkdocs-gallery/generated/gallery/).
-
-We avoid using `.ipynb` notebooks directly because their JSON-based format makes them difficult to read, interpret, and resolve merge conflicts in version control.
-
-To see if changes you have made break the current documentation, you can build the documentation locally.
-
-```bash
-# Clear the cached documentation pages
-# This step is only necessary if your changes affected the src/ directory
-rm -r docs/generated
-# build the docs within the nemos repo
-mkdocs build
-```
-
-If the build fails, you will see line-specific errors that prompted the failure.
+1. **Docstrings**
+
+ All public-facing functions and classes should have complete docstrings, which start with a one-line short summary of the function, a medium-length description of the function/class and what it does, a complete description of all arguments and return values, and an example to illustrate usage. Math should be included in a `Notes` section when necessary to explain what the function is doing, and references to primary literature should be included in a `References` section when appropriate. Docstrings should be relatively short, providing the information necessary for a user to use the code.
+
+ Private functions and classes should have sufficient explanation that other developers know what the function/class does and how to use it, but do not need to be as extensive.
+
+ We follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/) conventions for docstring structure.
+
+2. **Examples/Tutorials**
+
+ If your changes are significant (add a new functionality or drastically change the current codebase), then the current examples may need to be updated or a new example may need to be added.
+
+ All examples live within the `docs/` subfolder of `nemos`. These are written as `.py` files but are converted to notebooks by [`mkdocs-gallery`](https://smarie.github.io/mkdocs-gallery/), and have a special syntax, as demonstrated in this [example gallery](https://smarie.github.io/mkdocs-gallery/generated/gallery/).
+
+ We avoid using `.ipynb` notebooks directly because their JSON-based format makes them difficult to read, interpret, and resolve merge conflicts in version control.
+
+ To see if changes you have made break the current documentation, you can build the documentation locally.
+
+ ```
+ # Clear the cached documentation pages
+ # This step is only necessary if your changes affected the src/ directory
+ rm -r docs/generated
+ # build the docs within the nemos repo
+ mkdocs build
+ ```
+
+ If the build fails, you will see line-specific errors that prompted the failure.
+
+3. **Doctest: Test the example code in your docs**
+
+ Doctests are a great way to ensure that code examples in your documentation remain accurate as the codebase evolves. With doctests, we will test any docstrings, Markdown files, or any other text-based documentation that contains code formatted as interactive Python sessions.
+
+ - **Docstrings:**
+ To include doctests in your function and class docstrings you must add an `Examples` section. The examples should be formatted as if you were typing them into a Python interactive session, with `>>>` used to indicate commands and expected outputs listed immediately below.
+
+ ```python
+ def add(a, b):
+ """
+ The sum of two numbers.
+
+ ...Other docstrings sections (Parameters, Returns...)
+
+ Examples
+ --------
+ An expected output is required.
+ >>> add(1, 2)
+ 3
+
+ Unless the output is captured.
+ >>> out = add(1, 2)
+
+ """
+ return a + b
+ ```
+
+ To validate all your docstrings examples, run pytest `--doctest-module` flag,
+
+ ```
+ pytest --doctest-modules src/nemos/
+ ```
+
+ This test is part of the Continuous Integration, every example must pass before we can merge a PR.
+
+ - **Documentation Pages:**
+ Doctests can also be included in Markdown files by using code blocks with the `python` language identifier and interactive Python examples. To enable this functionality, ensure that code blocks follow the standard Python doctest format:
+
+ ```markdown
+ ```python
+ >>> # Add any code
+ >>> x = 3 ** 2
+ >>> x + 1
+ 10
+
+ ```
+ ```
+
+ To run doctests on a text file, use the following command:
+
+ ```
+ python -m doctest -v path-to-your-text-file/file_name.md
+ ```
+
+ All MarkDown files will be tested as part of the Continuous Integration.
diff --git a/README.md b/README.md
index 40470e6e..e9abf895 100644
--- a/README.md
+++ b/README.md
@@ -177,3 +177,8 @@ We communicate via several channels on Github:
In all cases, we request that you respect our [code of
conduct](CODE_OF_CONDUCT.md).
+## Support
+
+This package is supported by the Center for Computational Neuroscience, in the Flatiron Institute of the Simons Foundation.
+
+
diff --git a/docs/assets/CCN-logo-wText.png b/docs/assets/CCN-logo-wText.png
new file mode 100644
index 00000000..206f7693
Binary files /dev/null and b/docs/assets/CCN-logo-wText.png differ
diff --git a/docs/how_to_guide/plot_04_population_glm.py b/docs/how_to_guide/plot_04_population_glm.py
index 84282477..70dac9cd 100644
--- a/docs/how_to_guide/plot_04_population_glm.py
+++ b/docs/how_to_guide/plot_04_population_glm.py
@@ -23,9 +23,10 @@
"""
import jax.numpy as jnp
-import nemos as nmo
-import numpy as np
import matplotlib.pyplot as plt
+import numpy as np
+
+import nemos as nmo
np.random.seed(123)
diff --git a/docs/how_to_guide/plot_05_batch_glm.py b/docs/how_to_guide/plot_05_batch_glm.py
index f9e758fc..84f64d98 100644
--- a/docs/how_to_guide/plot_05_batch_glm.py
+++ b/docs/how_to_guide/plot_05_batch_glm.py
@@ -6,10 +6,11 @@
"""
+import matplotlib.pyplot as plt
+import numpy as np
import pynapple as nap
+
import nemos as nmo
-import numpy as np
-import matplotlib.pyplot as plt
nap.nap_config.suppress_conversion_warnings = True
diff --git a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
index b7168e33..ca9b167a 100644
--- a/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
+++ b/docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.py
@@ -71,20 +71,19 @@
# ## Combining basis transformations and GLM in a pipeline
# Let's start by creating some toy data.
-import nemos as nmo
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
-import matplotlib.pyplot as plt
import seaborn as sns
-
-from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
+from sklearn.pipeline import Pipeline
+
+import nemos as nmo
# some helper plotting functions
from nemos import _documentation_utils as doc_plots
-
# predictors, shape (n_samples, n_features)
X = np.random.uniform(low=0, high=1, size=(1000, 1))
# observed counts, shape (n_samples,)
diff --git a/docs/index.md b/docs/index.md
index 7c840e1b..562491b5 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -103,3 +103,10 @@ We provide a **Poisson GLM** for analyzing spike counts, and a **Gamma GLM** for
## :material-scale-balance:{ .lg } License
Open source, [licensed under MIT](https://github.com/flatironinstitute/nemos/blob/main/LICENSE).
+
+
+## Support
+
+This package is supported by the Center for Computational Neuroscience, in the Flatiron Institute of the Simons Foundation.
+
+
diff --git a/pyproject.toml b/pyproject.toml
index a34bb780..5ee288b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,10 @@ dev = [
"pytest-cov", # Test coverage plugin for pytest
"statsmodels", # Used to compare model pseudo-r2 in testing
"scikit-learn", # Testing compatibility with CV & pipelines
+ "matplotlib>=3.7", # Needed by doctest to run docstrings examples
+ "pooch", # Required by doctest for fetch module
+ "dandi", # Required by doctest for fetch module
+ "seaborn", # Required by doctest for _documentation_utils module
]
docs = [
"mkdocs", # Documentation generator
@@ -112,7 +116,7 @@ testpaths = ["tests"] # Specify the directory where test files are l
[tool.coverage.run]
omit = [
"src/nemos/fetch/*",
- "src/nemos/_documentation_utils/*"
+ "src/nemos/_documentation_utils/*",
]
[tool.coverage.report]
diff --git a/src/nemos/basis.py b/src/nemos/basis.py
index 2cc48f95..f5907480 100644
--- a/src/nemos/basis.py
+++ b/src/nemos/basis.py
@@ -152,15 +152,15 @@ class TransformerBasis:
>>> # transformer can be used in pipelines
>>> transformer = TransformerBasis(basis)
>>> pipeline = Pipeline([ ("compute_features", transformer), ("glm", GLM()),])
- >>> pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API
- >>> print(pipeline.predict(np.random.normal(size=(10, 1)))) # predict rate from new data
-
+ >>> pipeline = pipeline.fit(x[:, None], y) # x need to be 2D for sklearn transformer API
+ >>> out = pipeline.predict(np.arange(10)[:, None]) # predict rate from new datas
>>> # TransformerBasis parameter can be cross-validated.
>>> # 5-fold cross-validate the number of basis
>>> param_grid = dict(compute_features__n_basis_funcs=[4, 10])
>>> grid_cv = GridSearchCV(pipeline, param_grid, cv=5)
- >>> grid_cv.fit(x[:, None], y)
+ >>> grid_cv = grid_cv.fit(x[:, None], y)
>>> print("Cross-validated number of basis:", grid_cv.best_params_)
+ Cross-validated number of basis: {'compute_features__n_basis_funcs': 10}
"""
def __init__(self, basis: Basis):
@@ -289,7 +289,7 @@ def __getattr__(self, name: str):
return getattr(self._basis, name)
def __setattr__(self, name: str, value) -> None:
- """
+ r"""
Allow setting _basis or the attributes of _basis with a convenient dot assignment syntax.
Setting any other attribute is not allowed.
@@ -312,10 +312,11 @@ def __setattr__(self, name: str, value) -> None:
>>> # allowed
>>> trans_bas.n_basis_funcs = 20
>>> # not allowed
- >>> tran_bas.random_attribute_name = "some value"
- Traceback (most recent call last):
- ...
- ValueError: Only setting _basis or existing attributes of _basis is allowed.
+ >>> try:
+ ... trans_bas.random_attribute_name = "some value"
+ ... except ValueError as e:
+ ... print(repr(e))
+ ValueError('Only setting _basis or existing attributes of _basis is allowed.')
"""
# allow self._basis = basis
if name == "_basis":
@@ -343,7 +344,7 @@ def __sklearn_clone__(self) -> TransformerBasis:
return cloned_obj
def set_params(self, **parameters) -> TransformerBasis:
- """
+ r"""
Set TransformerBasis parameters.
When used with `sklearn.model_selection`, users can set either the `_basis` attribute directly
@@ -357,12 +358,16 @@ def set_params(self, **parameters) -> TransformerBasis:
>>> # setting parameters of _basis is allowed
>>> print(transformer_basis.set_params(n_basis_funcs=8).n_basis_funcs)
-
+ 8
>>> # setting _basis directly is allowed
- >>> print(transformer_basis.set_params(_basis=BSplineBasis(10))._basis)
-
+ >>> print(type(transformer_basis.set_params(_basis=BSplineBasis(10))._basis))
+
>>> # mixing is not allowed, this will raise an exception
- >>> transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2)
+ >>> try:
+ ... transformer_basis.set_params(_basis=BSplineBasis(10), n_basis_funcs=2)
+ ... except ValueError as e:
+ ... print(repr(e))
+ ValueError('Set either new _basis object or parameters for existing _basis, not both.')
"""
new_basis = parameters.pop("_basis", None)
if new_basis is not None:
@@ -996,8 +1001,8 @@ def to_transformer(self) -> TransformerBasis:
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.model_selection import GridSearchCV
>>> # load some data
- >>> X, y = ... # X: features, y: neural activity
- >>> basis = nmo.basis.RaisedCosineBasisLinear(10)
+ >>> X, y = np.random.normal(size=(30, 1)), np.random.poisson(size=30)
+ >>> basis = nmo.basis.RaisedCosineBasisLinear(10).to_transformer()
>>> glm = nmo.glm.GLM(regularizer="Ridge")
>>> pipeline = Pipeline([("basis", basis), ("glm", glm)])
>>> param_grid = dict(
@@ -1009,7 +1014,7 @@ def to_transformer(self) -> TransformerBasis:
... param_grid=param_grid,
... cv=5,
... )
- >>> gridsearch.fit(X, y)
+ >>> gridsearch = gridsearch.fit(X, y)
"""
return TransformerBasis(copy.deepcopy(self))
@@ -1346,7 +1351,7 @@ def _check_n_basis_min(self) -> None:
class MSplineBasis(SplineBasis):
- r"""
+ """
M-spline[$^{[1]}$](#references) basis functions for modeling and data transformation.
M-splines are a type of spline basis function used for smooth curve fitting
@@ -1502,12 +1507,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> for i in range(4):
- ... plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}')
+ ... p = plt.plot(sample_points, basis_values[:, i], label=f'Function {i+1}')
>>> plt.title('M-Spline Basis Functions')
+ Text(0.5, 1.0, 'M-Spline Basis Functions')
>>> plt.xlabel('Domain')
+ Text(0.5, 0, 'Domain')
>>> plt.ylabel('Basis Function Value')
- >>> plt.legend()
- >>> plt.show()
+ Text(0, 0.5, 'Basis Function Value')
+ >>> l = plt.legend()
"""
return super().evaluate_on_grid(n_samples)
diff --git a/src/nemos/exceptions.py b/src/nemos/exceptions.py
index 4537aafb..8e3caa29 100644
--- a/src/nemos/exceptions.py
+++ b/src/nemos/exceptions.py
@@ -15,6 +15,5 @@ class NotFittedError(ValueError, AttributeError):
... GLM().predict([[[1, 2], [2, 3], [3, 4]]])
... except NotFittedError as e:
... print(repr(e))
- ... # NotFittedError("This GLM instance is not fitted yet. Call 'fit' with
- ... # appropriate arguments.")
+ NotFittedError("This GLM instance is not fitted yet. Call 'fit' with appropriate arguments.")
"""
diff --git a/src/nemos/fetch/fetch_data.py b/src/nemos/fetch/fetch_data.py
index a5b76b5c..4ee9bbb9 100644
--- a/src/nemos/fetch/fetch_data.py
+++ b/src/nemos/fetch/fetch_data.py
@@ -143,10 +143,21 @@ def download_dandi_data(dandiset_id: str, filepath: str) -> NWBHDF5IO:
Examples
--------
>>> import nemos as nmo
+ >>> import pynapple as nap
>>> io = nmo.fetch.download_dandi_data("000582",
- "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb")
+ ... "sub-11265/sub-11265_ses-07020602_behavior+ecephys.nwb")
>>> nwb = nap.NWBFile(io.read(), lazy_loading=False)
>>> print(nwb)
+ 07020602
+ ┍━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━┑
+ │ Keys │ Type │
+ ┝━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━┥
+ │ units │ TsGroup │
+ │ ElectricalSeriesLFP │ Tsd │
+ │ SpatialSeriesLED2 │ TsdFrame │
+ │ SpatialSeriesLED1 │ TsdFrame │
+ │ ElectricalSeries │ Tsd │
+ ┕━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━┙
"""
if dandi is None:
diff --git a/src/nemos/glm.py b/src/nemos/glm.py
index bec4304f..03684859 100644
--- a/src/nemos/glm.py
+++ b/src/nemos/glm.py
@@ -523,11 +523,11 @@ def _initialize_parameters(
>>> import numpy as np
>>> X = np.zeros((100, 5)) # Example input
>>> y = np.exp(np.random.normal(size=(100, ))) # Simulated firing rates
- >>> coeff, intercept = nmo.glm.GLM._initialize_parameters(X, y)
+ >>> coeff, intercept = nmo.glm.GLM()._initialize_parameters(X, y)
>>> coeff.shape
- (5, )
+ (5,)
>>> intercept.shape
- (1, )
+ (1,)
"""
if isinstance(X, FeaturePytree):
data = X.data
@@ -823,9 +823,12 @@ def initialize_params(
Examples
--------
- >>> X, y = load_data() # Hypothetical function to load data
+ >>> import numpy as np
+ >>> import nemos as nmo
+ >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
+ >>> model = nmo.glm.GLM()
>>> params = model.initialize_params(X, y)
- >>> opt_state = model.initialize_state(X, y)
+ >>> opt_state = model.initialize_state(X, y, params)
>>> # Now ready to run optimization or update steps
"""
if init_params is None:
@@ -950,7 +953,10 @@ def update(
Examples
--------
- >>> # Assume glm_instance is an instance of GLM that has been previously fitted.
+ >>> import nemos as nmo
+ >>> import numpy as np
+ >>> X, y = np.random.normal(size=(10, 2)), np.random.uniform(size=10)
+ >>> glm_instance = nmo.glm.GLM().fit(X, y)
>>> params = glm_instance.coef_, glm_instance.intercept_
>>> opt_state = glm_instance.solver_state_
>>> new_params, new_opt_state = glm_instance.update(params, opt_state, X, y)
@@ -1057,15 +1063,15 @@ class PopulationGLM(GLM):
>>> y = np.random.poisson(np.exp(X.dot(weights)))
>>> # Define a feature mask, shape (num_features, num_neurons)
>>> feature_mask = jnp.array([[1, 0], [1, 1], [0, 1]])
- >>> print("Feature mask:")
- >>> print(feature_mask)
+ >>> feature_mask
+ Array([[1, 0],
+ [1, 1],
+ [0, 1]], dtype=int32)
>>> # Create and fit the model
- >>> model = PopulationGLM(feature_mask=feature_mask)
- >>> model.fit(X, y)
- >>> # Check the fitted coefficients and intercepts
- >>> print("Model coefficients:")
- >>> print(model.coef_)
-
+ >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
+ >>> # Check the fitted coefficients
+ >>> print(model.coef_.shape)
+ (3, 2)
>>> # Example with a FeaturePytree mask
>>> from nemos.pytrees import FeaturePytree
>>> # Define two features
@@ -1078,14 +1084,17 @@ class PopulationGLM(GLM):
>>> rate = np.exp(X["feature_1"].dot(weights["feature_1"]) + X["feature_2"].dot(weights["feature_2"]))
>>> y = np.random.poisson(rate)
>>> # Define a feature mask with arrays of shape (num_neurons, )
+
>>> feature_mask = FeaturePytree(feature_1=jnp.array([0, 1]), feature_2=jnp.array([1, 0]))
- >>> print("Feature mask:")
>>> print(feature_mask)
+ feature_1: shape (2,), dtype int32
+ feature_2: shape (2,), dtype int32
+
>>> # Fit a PopulationGLM
- >>> model = PopulationGLM(feature_mask=feature_mask)
- >>> model.fit(X, y)
- >>> print("Model coefficients:")
- >>> print(model.coef_)
+ >>> model = PopulationGLM(feature_mask=feature_mask).fit(X, y)
+ >>> # Coefficients are stored in a dictionary with keys the feature labels
+ >>> print(model.coef_.keys())
+ dict_keys(['feature_1', 'feature_2'])
"""
def __init__(
diff --git a/src/nemos/observation_models.py b/src/nemos/observation_models.py
index ecb7e76b..9d683ae1 100644
--- a/src/nemos/observation_models.py
+++ b/src/nemos/observation_models.py
@@ -846,7 +846,7 @@ def estimate_scale(
def check_observation_model(observation_model):
- """
+ r"""
Check the attributes of an observation model for compliance.
This function ensures that the observation model has the required attributes and that each
@@ -877,10 +877,10 @@ def check_observation_model(observation_model):
... def _negative_log_likelihood(self, params, y_true, aggregate_sample_scores=jnp.mean):
... return -aggregate_sample_scores(y_true * jax.scipy.special.logit(params) + \
... (1 - y_true) * jax.scipy.special.logit(1 - params))
- ... def pseudo_r2(self, params, y_true, aggregate_sample_scores):
+ ... def pseudo_r2(self, params, y_true, aggregate_sample_scores=jnp.mean):
... return 1 - (self._negative_log_likelihood(y_true, params, aggregate_sample_scores) /
... jnp.sum((y_true - y_true.mean()) ** 2))
- ... def sample_generator(self, key, params):
+ ... def sample_generator(self, key, params, scale=1.):
... return jax.random.bernoulli(key, params)
>>> model = MyObservationModel()
>>> check_observation_model(model) # Should pass without error if the model is correctly implemented.
diff --git a/src/nemos/regularizer.py b/src/nemos/regularizer.py
index 91e59f51..bfe72a32 100644
--- a/src/nemos/regularizer.py
+++ b/src/nemos/regularizer.py
@@ -352,10 +352,11 @@ class GroupLasso(Regularizer):
>>> mask[2] = [0, 0, 1, 0, 1] # Group 2 includes features 2 and 4
>>> # Create the GroupLasso regularizer instance
- >>> group_lasso = GroupLasso(regularizer_strength=0.1, mask=mask)
+ >>> group_lasso = GroupLasso(mask=mask)
>>> # fit a group-lasso glm
>>> model = GLM(regularizer=group_lasso).fit(X, y)
- >>> print(f"coeff: {model.coef_}")
+ >>> print(f"coeff shape: {model.coef_.shape}")
+ coeff shape: (5,)
"""
_allowed_solvers = (
diff --git a/src/nemos/solvers.py b/src/nemos/solvers.py
index d1b2deeb..4c060609 100644
--- a/src/nemos/solvers.py
+++ b/src/nemos/solvers.py
@@ -80,11 +80,13 @@ class ProxSVRG:
Examples
--------
- >>> def loss_fn(params, X, y):
- >>> ...
- >>>
- >>> svrg = ProxSVRG(loss_fn, prox_fun)
- >>> params, state = svrg.run(init_params, hyperparams_prox, X, y)
+ >>> import numpy as np
+ >>> from jaxopt.prox import prox_lasso
+ >>> loss_fn = lambda params, X, y: ((X.dot(params) - y)**2).sum()
+ >>> svrg = ProxSVRG(loss_fn, prox_lasso)
+ >>> hyperparams_prox = 0.1
+ >>> params, state = svrg.run(np.zeros(2), hyperparams_prox, np.ones((10, 2)), np.zeros(10))
+
References
----------
@@ -615,11 +617,10 @@ class SVRG(ProxSVRG):
Examples
--------
- >>> def loss_fn(params, X, y):
- >>> ...
- >>>
+ >>> import numpy as np
+ >>> loss_fn = lambda params, X, y: ((X.dot(params) - y)**2).sum()
>>> svrg = SVRG(loss_fn)
- >>> params, state = svrg.run(init_params, X, y)
+ >>> params, state = svrg.run(np.zeros(2), np.ones((10, 2)), np.zeros(10))
References
----------
diff --git a/tox.ini b/tox.ini
index d7761318..36430d2f 100644
--- a/tox.ini
+++ b/tox.ini
@@ -19,6 +19,7 @@ commands =
isort docs/background --profile=black
isort docs/tutorials --profile=black
flake8 --config={toxinidir}/tox.ini src
+ pytest --doctest-modules src/nemos/
pytest --cov=nemos --cov-config=pyproject.toml --cov-report=xml
[gh-actions]