Skip to content

Commit

Permalink
added a quick repr
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Oct 11, 2024
1 parent a7e21fd commit 7f8b2b1
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 48 deletions.
3 changes: 1 addition & 2 deletions docs/how_to_guide/plot_05_batch_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@
# Here we instantiate the basis. `ws` is 40 time bins. It corresponds to a 200 ms windows
ws = 40

# set the n_basis_input to the number of neuron in the population
basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=ws, n_basis_input=n_neurons)
basis = nmo.basis.RaisedCosineBasisLog(5, mode="conv", window_size=ws)

# %%
# ## Batch definition
Expand Down
15 changes: 12 additions & 3 deletions docs/tutorials/plot_02_head_direction.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@
# define a basis function that expects an input of shape (num_samples, num_neurons).
num_neurons = count.shape[1]
basis = nmo.basis.RaisedCosineBasisLog(
n_basis_funcs=8, mode="conv", window_size=window_size, n_basis_input=num_neurons
n_basis_funcs=8, mode="conv", window_size=window_size, label="convolved counts"
)

# convolve all the neurons
Expand Down Expand Up @@ -574,10 +574,19 @@
minmax=(0, 2 * np.pi))

# %%
# Extract the weights and store it in a (n_neurons, n_neurons, n_basis_funcs) array.
# Extract the weights and store it in a (n_neurons, n_basis_funcs, n_neurons) array.
# We can use `basis.split_by_feature` for this. The method will return a dictionary with an array
# for each feature, and keys the label we provided to the basis.
# In this case, "convolved counts" is the only feature.

weights = model.coef_.reshape(count.shape[1], basis.n_basis_funcs, count.shape[1])
# split the coefficients by feature
weights = basis.split_by_feature(model.coef_)

# the output is a dictionary containing an array of shape (n_neurons, n_basis_funcs, n_neurons)
print(f"{weights.keys()}: {weights['convolved counts'].shape}")

# get the array
weights = weights["convolved counts"]

# %%
# Multiply the weights by the basis, to get the history filters.
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/plot_06_calcium_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@

# define a basis that expect all the other neurons as predictors, i.e. shape (num_samples, num_neurons - 1)
num_neurons = Y.shape[1]
coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=10, n_basis_input=num_neurons - 1)
coupling_basis = nmo.basis.RaisedCosineBasisLog(3, mode="conv", window_size=10)

# %%
# We need to make sure the design matrix will be full-rank by applying identifiability constraints to the Cyclic Bspline, and then combine the bases (the resturned object will be an `AdditiveBasis` object).
Expand Down
133 changes: 91 additions & 42 deletions src/nemos/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import jax
import numpy as np
import scipy.linalg
from numba.core.ir_utils import raise_on_unsupported_feature
from numpy.typing import ArrayLike, NDArray
from pynapple import Tsd, TsdFrame
from scipy.interpolate import splev
Expand Down Expand Up @@ -492,10 +491,12 @@ def __init__(

self._mode = mode

self._n_basis_input = None#(1 if n_basis_input is None else int(n_basis_input),)
self._n_basis_input = (
None # (1 if n_basis_input is None else int(n_basis_input),)
)

# pre-compute the expected output feature dimensionality
self._num_output_features = None#self._n_basis_input[0] * n_basis_funcs
self._num_output_features = None # self._n_basis_input[0] * n_basis_funcs
self._label = str(label)
self.window_size = window_size
self.bounds = bounds
Expand Down Expand Up @@ -1214,7 +1215,62 @@ def _get_default_slicing(
start_slice += self._num_output_features
return split_dict, start_slice

def split_by_feature(self, x: NDArray, axis: int = 1) -> dict:
def _split_by_feature(
self,
x: NDArray,
axis: int = 1,
store_input_in: Literal["dict", "tensor"] = "tensor",
):
"""
Split x by feature returning.
Splits the array by feature; for each feature it returns a tensor, in which the feature dimension
is split into (n_basis_funcs, n_basis_input), or a dictionary with keys 1,...,n_basis_input,
of arrays of (n_basis_funcs,).
"""
split_by_input = store_input_in == "dict"

# Get the slice dictionary based on predefined feature slicing
slice_dict = self._get_feature_slicing(split_by_input=split_by_input)[0]

# Helper function to build index tuples for each slice
def build_index_tuple(slice_obj, axis: int, ndim: int):
"""Create an index tuple to apply a slice on the given axis."""
index = [slice(None)] * ndim # Initialize index for all dimensions
index[axis] = slice_obj # Replace the axis with the slice object
return tuple(index)

# Get the dict for slicing the correct axis
index_dict = jax.tree_util.tree_map(
lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict
)

# Custom leaf function to identify index tuples as leaves
def is_leaf(val):
# Check if it's a tuple, length matches ndim, and all elements are slice objects
if isinstance(val, tuple) and len(val) == x.ndim:
return all(isinstance(v, slice) for v in val)
return False

# Apply the slicing using the custom leaf function
out = jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf)

# reshape to array
if not split_by_input:
reshaped_out = dict()
for i, vals in enumerate(out.items()):
key, val = vals
shape = list(val.shape)
reshaped_out[key] = val.reshape(
shape[:axis] + [self._n_basis_input[i], -1] + shape[axis + 1 :]
)
return reshaped_out
return out

def split_by_feature(self, x: NDArray, axis: int = 0) -> dict:
r"""
Decompose a feature matrix along a specified axis into a dictionary of sub-arrays based on basis components.
Expand Down Expand Up @@ -1317,30 +1373,7 @@ def split_by_feature(self, x: NDArray, axis: int = 1) -> dict:
f" `x.shape[axis] == {x.shape[axis]}`, while the expected number "
f"of features is {self.num_output_features}"
)
# Get the slice dictionary based on predefined feature slicing
slice_dict = self._get_feature_slicing()[0]

# Helper function to build index tuples for each slice
def build_index_tuple(slice_obj, axis: int, ndim: int):
"""Create an index tuple to apply a slice on the given axis."""
index = [slice(None)] * ndim # Initialize index for all dimensions
index[axis] = slice_obj # Replace the axis with the slice object
return tuple(index)

# Get the dict for slicing the correct axis
index_dict = jax.tree_util.tree_map(
lambda sl: build_index_tuple(sl, axis, x.ndim), slice_dict
)

# Custom leaf function to identify index tuples as leaves
def is_leaf(val):
# Check if it's a tuple, length matches ndim, and all elements are slice objects
if isinstance(val, tuple) and len(val) == x.ndim:
return all(isinstance(v, slice) for v in val)
return False

# Apply the slicing and return the result using the custom leaf function
return jax.tree_util.tree_map(lambda sl: x[sl], index_dict, is_leaf=is_leaf)
return self._split_by_feature(x, axis=axis, store_input_in="tensor")

def _set_num_output_features(self, *xi: NDArray):
# this is reimplemented in AdditiveBasis and MultiplicativeBasis
Expand All @@ -1351,14 +1384,18 @@ def _set_num_output_features(self, *xi: NDArray):
axis = self._conv_kwargs.get("axis", 0)

# remove time axis & get the total input number
n_inputs = (1, ) if xi[0].ndim == 1 else (np.prod(shape[:axis] + shape[axis + 1:]), )
n_inputs = (
(1,) if xi[0].ndim == 1 else (np.prod(shape[:axis] + shape[axis + 1 :]),)
)

if self._n_basis_input is not None and self._n_basis_input != n_inputs:
raise ValueError(f"Input dimensionality mismatch. "
f"The basis {self.__class__.__name__} with label {self.label} was expecting "
f"{self._n_basis_input[0]} inputs, {n_inputs[0]} provided. "
"This happens when the basis is used to compute features multiple times, with"
"input of different shapes. If you need to compute features over")
raise ValueError(
f"Input dimensionality mismatch. "
f"The basis {self.__class__.__name__} with label {self.label} was expecting "
f"{self._n_basis_input[0]} inputs, {n_inputs[0]} provided. "
"This happens when the basis is used to compute features multiple times, with"
"input of different shapes. If you need to compute features over"
)

self._n_basis_input = n_inputs
self._num_output_features = self.n_basis_funcs * self._n_basis_input[0]
Expand Down Expand Up @@ -1390,7 +1427,7 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:
self._n_input_dimensionality = (
basis1._n_input_dimensionality + basis2._n_input_dimensionality
)
self._n_basis_input = None # (*basis1._n_basis_input, *basis2._n_basis_input)
self._n_basis_input = None # (*basis1._n_basis_input, *basis2._n_basis_input)
# self._num_output_features = (
# basis1._num_output_features + basis2._num_output_features
# )
Expand All @@ -1402,10 +1439,16 @@ def __init__(self, basis1: Basis, basis2: Basis) -> None:

def _set_num_output_features(self, *xi: NDArray) -> tuple:
self._n_basis_input = (
*self._basis1._set_num_output_features(*xi[: self._basis1._n_input_dimensionality])._n_basis_input,
*self._basis2._set_num_output_features(*xi[self._basis1._n_input_dimensionality: ])._n_basis_input,
*self._basis1._set_num_output_features(
*xi[: self._basis1._n_input_dimensionality]
)._n_basis_input,
*self._basis2._set_num_output_features(
*xi[self._basis1._n_input_dimensionality :]
)._n_basis_input,
)
self._num_output_features = (
self._basis1.num_output_features + self._basis2.num_output_features
)
self._num_output_features = self._basis1.num_output_features + self._basis2.num_output_features
return self

def _check_n_basis_min(self) -> None:
Expand Down Expand Up @@ -1578,7 +1621,6 @@ def __call__(self, *xi: ArrayLike) -> FeatureMatrix:
X = self._apply_identifiability_constraints(X)
return X


def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:
"""
Compute the features for the multiplied bases, and compute their outer product.
Expand Down Expand Up @@ -1607,12 +1649,19 @@ def _compute_features(self, *xi: ArrayLike) -> FeatureMatrix:

def _set_num_output_features(self, *xi: NDArray) -> Basis:
self._n_basis_input = (
*self._basis1._set_num_output_features(*xi[: self._basis1._n_input_dimensionality])._n_basis_input,
*self._basis2._set_num_output_features(*xi[self._basis1._n_input_dimensionality: ])._n_basis_input,
*self._basis1._set_num_output_features(
*xi[: self._basis1._n_input_dimensionality]
)._n_basis_input,
*self._basis2._set_num_output_features(
*xi[self._basis1._n_input_dimensionality :]
)._n_basis_input,
)
self._num_output_features = (
self._basis1.num_output_features * self._basis2.num_output_features
)
self._num_output_features = self._basis1.num_output_features * self._basis2.num_output_features
return self


class SplineBasis(Basis, abc.ABC):
"""
SplineBasis class inherits from the Basis class and represents spline basis functions.
Expand Down

0 comments on commit 7f8b2b1

Please sign in to comment.