fix split_by_feature bug #289

Merged
merged 18 commits into from
Jan 14, 2025

billbrod
Member

Small bugfix: in split_by_feature for the additive basis, the jax tree map sorted the dictionary it operated on alphabetically by the labels of basis1 and basis2. Thus, if those two basis objects had different n_basis_input values and were passed in an order that differed from the alphabetical order of their labels, split_by_feature would fail. This fixes that by using an OrderedDict.

It doesn't look like this was an issue for the MultiplicativeBasis, but I added that test anyway, let me know if you want me to remove it.
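
The ordering problem described above can be illustrated without jax or nemos. The following is only a minimal sketch: the labels "z" and "a", the feature counts, and the helper `split` are all hypothetical, and `sorted` stands in for whatever key-sorting traversal caused the bug. It shows that cutting the concatenated feature axis in alphabetical label order, rather than in the order the components were added, assigns the wrong columns to each component.

```python
# Hypothetical labels and per-component feature counts, listed in the order
# the components were added ("z" first, then "a").
n_features = {"z": 2, "a": 3}

# One row of the concatenated feature axis: 2 columns for "z", then 3 for "a".
row = [0, 1, 2, 3, 4]


def split(row, sizes):
    """Cut `row` into consecutive chunks, one per label, following dict order."""
    out, start = {}, 0
    for label, size in sizes.items():
        out[label] = row[start:start + size]
        start += size
    return out


# Insertion order matches the layout of the feature axis.
correct = split(row, n_features)

# Alphabetical order (mimicking a key-sorting traversal) does not.
wrong = split(row, dict(sorted(n_features.items())))

print(correct)  # {'z': [0, 1], 'a': [2, 3, 4]}
print(wrong)    # {'a': [0, 1, 2], 'z': [3, 4]}
```

This toy version silently mis-assigns columns; according to the description above, the real split_by_feature failed outright when the sizes differed. An order-preserving mapping such as OrderedDict keeps the chunk order aligned with the feature axis.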

codecov-commenter commented Jan 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.30%. Comparing base (a510ef3) to head (588ef3d).
Report is 91 commits behind head on development.

Additional details and impacted files
@@               Coverage Diff               @@
##           development     #289      +/-   ##
===============================================
+ Coverage        96.13%   97.30%   +1.16%     
===============================================
  Files               34       35       +1     
  Lines             2642     2779     +137     
===============================================
+ Hits              2540     2704     +164     
+ Misses             102       75      -27     


Collaborator

@BalzaniEdoardo BalzaniEdoardo left a comment

Great catch! If you add a couple of different input shapes to the test, this is good to go for me.

@@ -304,6 +304,20 @@ def test_expected_output_split_by_feature(basis_instance, super_class):
np.testing.assert_array_equal(xx[~nans], x[~nans])


@pytest.mark.parametrize("composite_op", ["add", "multiply"])
Collaborator

This looks good. Can you parametrize a few input shapes, like (n,), (n, 1), (n, 2), and (n, 1, 2)?

@BalzaniEdoardo
Collaborator

Changed in this PR:

  • split_by_feature: the method splits the feature axis using the shape of the provided input. For example, if a basis with 5 elements processed an input of shape (n_samples, 1, 2, 3), the feature axis, which has length 1*2*3*5, is reshaped to (1, 2, 3, 5).
  • __iter__: method for iterating over the additive components, which usually corresponds to different task variables.
  • __len__: returns the number of components.
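
The reshaping in the first bullet can be sketched with plain NumPy. This is an illustration under stated assumptions, not the nemos implementation: the variable names are made up, and the sketch assumes the feature axis is laid out so that a C-order reshape recovers (1, 2, 3, 5), with the basis index varying fastest.

```python
import numpy as np

n_samples = 4
input_shape = (1, 2, 3)   # shape of the input after the sample axis
n_basis_funcs = 5         # "a basis with 5 elements"

# Length of the flat feature axis: 1*2*3*5 = 30.
n_features = int(np.prod(input_shape)) * n_basis_funcs

# Stand-in for the output of compute_features: (n_samples, n_features).
features = np.arange(n_samples * n_features, dtype=float).reshape(
    n_samples, n_features
)

# Split the feature axis back into the input shape plus a basis axis.
split = features.reshape(n_samples, *input_shape, n_basis_funcs)

print(n_features)   # 30
print(split.shape)  # (4, 1, 2, 3, 5)
```

The sample axis is left untouched; only the feature axis is unfolded.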

Member Author

@billbrod billbrod left a comment

Can you write out the attribute renaming that you did? Otherwise this looks good to me!

Member Author

billbrod commented Jan 14, 2025

To be explicit: this PR changes the behavior of split_by_feature. Previously, it would always return a 3d array.

Now, basis.input_shape is set when compute_features or set_input_shape is called on an N-d array X, storing X.shape[1:] (so it can be empty if X is 1d). Then, basis.split_by_feature, applied to the output of compute_features, returns a dictionary keyed by basis label whose arrays have shape (n_samples, *basis.input_shape, basis.n_basis_funcs).

Some examples:

import numpy as np
import nemos as nmo

# N-d input: input_shape is X.shape[1:]
n_samples = 100
basis = nmo.basis.RaisedCosineLinearEval(7)
input_shape = (4, 5, 10) # or some other tuple of ints
X = np.random.rand(n_samples, *input_shape)
split = basis.split_by_feature(basis.compute_features(X))
split[basis.label].shape
# (100, 4, 5, 10, 7)

# 1d input: input_shape is the empty tuple
n_samples = 100
basis = nmo.basis.RaisedCosineLinearEval(7)
input_shape = ()
X = np.random.rand(n_samples, *input_shape)
split = basis.split_by_feature(basis.compute_features(X))
split[basis.label].shape
# (100, 7)
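
The shape rule behind these examples can also be written down without nemos. The helper below is hypothetical (the name `expected_split_shape` is not part of the library); it only encodes the rule (n_samples, *input_shape, n_basis_funcs) stated above.

```python
def expected_split_shape(x_shape, n_basis_funcs):
    """Return (n_samples, *input_shape, n_basis_funcs) for an input of shape x_shape."""
    # input_shape corresponds to X.shape[1:]; it is empty when X is 1d.
    n_samples, *input_shape = x_shape
    return (n_samples, *input_shape, n_basis_funcs)


print(expected_split_shape((100, 4, 5, 10), 7))  # (100, 4, 5, 10, 7)
print(expected_split_shape((100,), 7))           # (100, 7)
```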

Collaborator

BalzaniEdoardo commented Jan 14, 2025

Additional changes to _basis.py:

  • Renamed _n_basis_input_ to _input_shape_product, a more descriptive name.
  • Removed the property n_basis_input_, since the attribute is only used internally.
  • New public property input_shape which returns the input.shape[1:] for atomic bases and multiplicative bases, and a list of all input shapes for additive bases.
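
A toy model of these attributes may make the behavior easier to picture. Everything below is a hypothetical sketch, not the nemos classes: `ToyBasis` and `ToyAdditive` are invented names, `set_input_shape` here takes a shape tuple rather than an array, and treating `_input_shape_product` as the product of the entries of `input_shape` is an assumption based on the attribute's name.

```python
import math


class ToyBasis:
    """Hypothetical stand-in for an atomic basis (not the nemos class)."""

    def __init__(self, n_basis_funcs):
        self.n_basis_funcs = n_basis_funcs
        self._input_shape = None

    def set_input_shape(self, x_shape):
        # Store everything after the sample axis, i.e. X.shape[1:].
        self._input_shape = tuple(x_shape[1:])
        return self

    @property
    def input_shape(self):
        return self._input_shape

    @property
    def _input_shape_product(self):
        # Assumed meaning: number of inputs per sample; math.prod(()) == 1.
        return math.prod(self._input_shape)


class ToyAdditive:
    """Hypothetical additive composite: one input shape per component."""

    def __init__(self, *components):
        self.components = components

    @property
    def input_shape(self):
        return [c.input_shape for c in self.components]

    def __iter__(self):
        return iter(self.components)

    def __len__(self):
        return len(self.components)


a = ToyBasis(5).set_input_shape((100, 2, 3))
b = ToyBasis(7).set_input_shape((100,))
add = ToyAdditive(a, b)

print(a.input_shape, a._input_shape_product)  # (2, 3) 6
print(b.input_shape, b._input_shape_product)  # () 1
print(add.input_shape, len(add))              # [(2, 3), ()] 2
```

Iterating over `add` yields the components in the order they were passed in, which is the ordering the original bugfix is concerned with.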

@BalzaniEdoardo BalzaniEdoardo merged commit f5d3fde into development Jan 14, 2025
13 checks passed
@BalzaniEdoardo BalzaniEdoardo deleted the composite_ordering branch January 14, 2025 22:30