Skip to content

Commit

Permalink
Merge pull request #253 from flatironinstitute/basis_examples_section
Browse files Browse the repository at this point in the history
Started basis examples
  • Loading branch information
BalzaniEdoardo authored Oct 31, 2024
2 parents 6b3a1e8 + e901856 commit d4f6524
Showing 1 changed file with 214 additions and 2 deletions.
216 changes: 214 additions & 2 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,19 @@ def fit(self, X: FeatureMatrix, y=None):
-------
self :
The transformer object.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(100, 2))
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> transformer_fitted = transformer.fit(X)
"""
self._basis._set_kernel(*self._unpack_inputs(X))
return self
Expand All @@ -223,6 +236,28 @@ def transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
-------
:
The data transformed by the basis functions.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(10000, 2))
>>> # Define and fit tranformation basis
>>> basis = MSplineBasis(10, mode="conv", window_size=200)
>>> transformer = TransformerBasis(basis)
>>> # Before calling `fit` the convolution kernel is not set
>>> transformer.kernel_
>>> transformer_fitted = transformer.fit(X)
>>> # Now the convolution kernel is initialized and has shape (window_size, n_basis_funcs)
>>> transformer_fitted.kernel_.shape
(200, 10)
>>> # Transform basis
>>> feature_transformed = transformer.transform(X[:, 0:1])
"""
# transpose does not work with pynapple
# can't use func(*X.T) to unwrap
Expand All @@ -248,6 +283,21 @@ def fit_transform(self, X: FeatureMatrix, y=None) -> FeatureMatrix:
array-like
The data transformed by the basis functions, after fitting the basis
functions to the data.
Examples
--------
>>> import numpy as np
>>> from nemos.basis import MSplineBasis, TransformerBasis
>>> # Example input
>>> X = np.random.normal(size=(100, 1))
>>> # Define tranformation basis
>>> basis = MSplineBasis(10)
>>> transformer = TransformerBasis(basis)
>>> # Fit and transform basis
>>> feature_transformed = transformer.fit_transform(X)
"""
return self._basis.compute_features(*self._unpack_inputs(X))

Expand Down Expand Up @@ -705,6 +755,19 @@ def compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
input samples with the basis functions. The output shape varies based on
the subclass and mode.
Examples
-------
>>> import numpy as np
>>> from nemos.basis import BSplineBasis
>>> # Generate data
>>> num_samples = 10000
>>> X = np.random.normal(size=(num_samples, )) # raw time series
>>> basis = BSplineBasis(10)
>>> features = basis.compute_features(X) # basis transformed time series
>>> features.shape
(10000, 10)
Notes
-----
Subclasses should implement how to handle the transformation specific to their
Expand Down Expand Up @@ -882,6 +945,19 @@ def evaluate_on_grid(self, *n_samples: int) -> Tuple[Tuple[NDArray], NDArray]:
This differs from the numpy.meshgrid default, which uses Cartesian indexing.
For the same input, Cartesian indexing would return an output of shape $(M_2, M_1, M_3, ....,M_N)$.
Examples
--------
>>> # Evaluate and visualize 4 M-spline basis functions of order 3:
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import MSplineBasis
>>> mspline_basis = MSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = mspline_basis.evaluate_on_grid(100)
>>> p = plt.plot(sample_points, basis_values)
>>> _ = plt.title('M-Spline Basis Functions')
>>> _ = plt.xlabel('Domain')
>>> _ = plt.ylabel('Basis Function Value')
>>> _ = plt.legend([f'Function {i+1}' for i in range(4)]);
"""
self._check_input_dimensionality(n_samples)

Expand Down Expand Up @@ -1071,7 +1147,22 @@ class AdditiveBasis(Basis):
n_basis_funcs : int
Number of basis functions.
Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 2))
>>> # define two basis objects and add them
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> additive_basis = basis_1 + basis_2
>>> # can add another basis to the AdditiveBasis object
>>> X = np.random.normal(size=(30, 3))
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> additive_basis_2 = additive_basis + basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1183,6 +1274,22 @@ class MultiplicativeBasis(Basis):
n_basis_funcs : int
Number of basis functions.
Examples
--------
>>> # Generate sample data
>>> import numpy as np
>>> import nemos as nmo
>>> X = np.random.normal(size=(30, 3))
>>> # define two basis and multiply
>>> basis_1 = nmo.basis.BSplineBasis(10)
>>> basis_2 = nmo.basis.RaisedCosineBasisLinear(15)
>>> multiplicative_basis = basis_1 * basis_2
>>> # Can multiply or add another basis to the AdditiveBasis object
>>> # This will cause the number of output features of the result basis to grow accordingly
>>> basis_3 = nmo.basis.RaisedCosineBasisLog(100)
>>> multiplicative_basis_2 = multiplicative_basis * basis_3
"""

def __init__(self, basis1: Basis, basis2: Basis) -> None:
Expand Down Expand Up @@ -1298,7 +1405,6 @@ class SplineBasis(Basis, abc.ABC):
----------
order : int
Spline order.
"""

def __init__(
Expand Down Expand Up @@ -1614,6 +1720,14 @@ class BSplineBasis(SplineBasis):
[1] Prautzsch, H., Boehm, W., Paluszny, M. (2002). B-spline representation. In: Bézier and B-Spline Techniques.
Mathematics and Visualization. Springer, Berlin, Heidelberg. https://doi.org/10.1007/978-3-662-04919-8_5
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import BSplineBasis
>>> bspline_basis = BSplineBasis(n_basis_funcs=5, order=3)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = bspline_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1693,6 +1807,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import BSplineBasis
>>> bspline_basis = BSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = bspline_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1728,6 +1850,16 @@ class CyclicBSplineBasis(SplineBasis):
Number of basis functions.
order : int
Order of the splines used in basis functions.
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import CyclicBSplineBasis
>>> X = np.random.normal(size=(1000, 1))
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=5, order=3, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cyclic_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -1835,6 +1967,14 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
-----
The evaluation is performed by looping over each element and using `splev` from
SciPy to compute the basis values.
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import CyclicBSplineBasis
>>> cyclic_basis = CyclicBSplineBasis(n_basis_funcs=4, order=3)
>>> sample_points, basis_values = cyclic_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -1864,6 +2004,16 @@ class RaisedCosineBasisLinear(Basis):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLinear
>>> X = np.random.normal(size=(1000, 1))
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2003,6 +2153,13 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
basis_funcs :
Raised cosine basis functions, shape (n_samples, n_basis_funcs)
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import RaisedCosineBasisLinear
>>> cosine_basis = RaisedCosineBasisLinear(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points, basis_values = cosine_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand Down Expand Up @@ -2057,6 +2214,16 @@ class RaisedCosineBasisLog(RaisedCosineBasisLinear):
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import RaisedCosineBasisLog
>>> X = np.random.normal(size=(1000, 1))
>>> cosine_basis = RaisedCosineBasisLog(n_basis_funcs=5, mode="conv", window_size=10)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = cosine_basis(sample_points)
# References
------------
[1] Pillow, J. W., Paninski, L., Uzzel, V. J., Simoncelli, E. P., & J.,
Expand Down Expand Up @@ -2210,6 +2377,18 @@ class OrthExponentialBasis(Basis):
**kwargs :
Only used in "conv" mode. Additional keyword arguments that are passed to
`nemos.convolve.create_convolutional_predictor`
Examples
--------
>>> from numpy import linspace
>>> from nemos.basis import OrthExponentialBasis
>>> X = np.random.normal(size=(1000, 1))
>>> n_basis_funcs = 5
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points = linspace(0, 1, 100)
>>> basis_functions = ortho_basis(sample_points)
"""

def __init__(
Expand Down Expand Up @@ -2365,6 +2544,16 @@ def evaluate_on_grid(self, n_samples: int) -> Tuple[NDArray, NDArray]:
Evaluated exponentially decaying basis functions, numerically
orthogonalized, shape (n_samples, n_basis_funcs)
Examples
--------
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from nemos.basis import OrthExponentialBasis
>>> n_basis_funcs = 5
>>> decay_rates = [0.01, 0.02, 0.03, 0.04, 0.05] # sample decay rates
>>> window_size=10
>>> ortho_basis = OrthExponentialBasis(n_basis_funcs, decay_rates, "conv", window_size)
>>> sample_points, basis_values = ortho_basis.evaluate_on_grid(100)
"""
return super().evaluate_on_grid(n_samples)

Expand All @@ -2387,6 +2576,17 @@ def mspline(x: NDArray, k: int, i: int, T: NDArray) -> NDArray:
-------
spline
M-spline basis function, shape (n_sample_points, ).
Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import mspline
>>> sample_points = linspace(0, 1, 100)
>>> mspline_eval = mspline(x=sample_points, k=3, i=2, T=np.random.rand(7)) # define a cubic M-spline
>>> mspline_eval.shape
(100,)
"""
# Boundary conditions.
if (T[i + k] - T[i]) < 1e-6:
Expand Down Expand Up @@ -2453,6 +2653,18 @@ def bspline(
Notes
-----
The function uses splev function from scipy.interpolate library for the basis evaluation.
Examples
--------
>>> import numpy as np
>>> from numpy import linspace
>>> from nemos.basis import bspline
>>> sample_points = linspace(0, 1, 100)
>>> knots = np.array([0, 0, 0, 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1, 1, 1, 1])
>>> bspline_eval = bspline(sample_points, knots) # define a cubic B-spline
>>> bspline_eval.shape
(100, 10)
"""
knots.sort()
nk = knots.shape[0]
Expand Down

0 comments on commit d4f6524

Please sign in to comment.