-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix split_by_feature bug #289
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## development #289 +/- ##
===============================================
+ Coverage 96.13% 97.30% +1.16%
===============================================
Files 34 35 +1
Lines 2642 2779 +137
===============================================
+ Hits 2540 2704 +164
+ Misses 102 75 -27 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, if you add a couple of different input shapes in the test, this is good to go for me!
@@ -304,6 +304,20 @@ def test_expected_output_split_by_feature(basis_instance, super_class): | |||
np.testing.assert_array_equal(xx[~nans], x[~nans]) | |||
|
|||
|
|||
@pytest.mark.parametrize("composite_op", ["add", "multiply"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good, can you parametrize a few input shapes? Like (n,) (n,1) (n, 2), (n,1,2)
currently failing, need to modify behavior
Changed in this PR:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you write out the attribute renaming that you did? Otherwise this looks good to me!
To be explicit: this PR changes the behavior of Now, Some examples: n_samples = 100
basis = nmo.basis.RaisedCosineLinearEval(7)
input_shape = (4, 5, 10) # or some other tuple of ints
X = np.random.rand(n_samples, *input_shape)
split = basis.split_by_feature(b.compute_features(X))
split[basis.label].shape
>>> (100, 4, 5, 10, 7) n_samples = 100
basis = nmo.basis.RaisedCosineLinearEval(7)
input_shape = () # or some other tuple of ints
X = np.random.rand(n_samples, *input_shape)
split = basis.split_by_feature(b.compute_features(X))
split[basis.label].shape
>>> (100, 7) |
Co-authored-by: William F. Broderick <[email protected]>
Co-authored-by: William F. Broderick <[email protected]>
Additional changes to
|
Small bugfix: in
split_by_feature
for the additive basis, jax tree map was sorting the dictionary that was being used to be alphabetical with respect to the labels ofbasis1
andbasis2
. Thus, if those two basis objects had differentn_basis_input
values and were passed in using a different order than their alphabetical sorting,split_by_feature
would fail. This fixes that by using an OrderedDictIt doesn't look like this was an issue for the
MultiplicativeBasis
, but I added that test anyway, let me know if you want me to remove it.