Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Repeat and Unique Ops in PyTorch #890

Merged
merged 5 commits into from
Jul 12, 2024

Conversation

twaclaw
Copy link
Contributor

@twaclaw twaclaw commented Jul 6, 2024

Description

Implements Repeat and Unique Ops in PyTorch.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment on lines 51 to 54
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need these optional kwargs. I know jax implementations were doing it but I don't see why. These functions are never called with them changed.

Its enough to rely on the scope to access them

Copy link

codecov bot commented Jul 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.37%. Comparing base (ee4d4f7) to head (62e453d).
Report is 111 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #890   +/-   ##
=======================================
  Coverage   81.36%   81.37%           
=======================================
  Files         171      171           
  Lines       46811    46828   +17     
  Branches    11420    11421    +1     
=======================================
+ Hits        38088    38105   +17     
  Misses       6539     6539           
  Partials     2184     2184           
Files with missing lines Coverage Δ
pytensor/link/pytorch/dispatch/extra_ops.py 78.78% <100.00%> (+22.53%) ⬆️

@ricardoV94 ricardoV94 added enhancement New feature or request torch PyTorch backend labels Jul 6, 2024
@ricardoV94 ricardoV94 requested a review from jessegrabowski July 8, 2024 15:14
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me. I'd definitely parameterize these tests to get coverage over the whole power set of options (and to shorten the tests), but I won't make it a blocker.

@twaclaw
Copy link
Contributor Author

twaclaw commented Jul 8, 2024

This looks good to me. I'd definitely parameterize these tests to get coverage over the whole power set of options (and to shorten the tests), but I won't make it a blocker.

I think you are right. Should I parametrize the tests?

@jessegrabowski
Copy link
Member

Yeah if you agree definitely go ahead.

@twaclaw twaclaw force-pushed the implement_repeat_unique_ops_torch branch from a50c103 to d561a1d Compare July 9, 2024 17:45
Comment on lines 70 to 86
@pytest.mark.parametrize(
"return_index, return_inverse, return_counts",
[
(False, True, False),
(False, True, True),
pytest.param(
True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
),
],
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be distinict parametrizes to test all combinations?

Suggested change
@pytest.mark.parametrize(
"return_index, return_inverse, return_counts",
[
(False, True, False),
(False, True, True),
pytest.param(
True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
),
],
)
@pytest.mark.parametrize("return_index", (False, pytest.param(True, marks=...)))
@pytest.mark.parametrize("return_inverse", (False, True))
@pytest.mark.parametrize("return_counts", (False, True))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to test all the combinations. return_index=True should always fail. I wanna test that once instead of 4 times. The combination return_inverse=False and return_counts=False is tested somewhere else, etc.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason not to test all the combinations? How fast is the test running with 3 vs the 9?


test_value = np.arange(6, dtype="float64").reshape((3, 2))

out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, do we allow axis=None like numpy? It may fail because we don't yet have Reshape implemented but we should test

[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
)

out = pt.unique(a, axis=axis)
Copy link
Member

@ricardoV94 ricardoV94 Jul 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does PyTensor only allows integer axis at the moment? No None or partial multiple axis? If we do, we should test

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All good with this one?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check the code, but we need to see if multiple (but not all) axis is supported by PyTensor and if so test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean multiple axis like in the case of ArgMax and Max?
I don't think so:

axis : int, optional

@twaclaw twaclaw force-pushed the implement_repeat_unique_ops_torch branch from a3f478c to 62e453d Compare July 10, 2024 19:41
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good thanks!

@ricardoV94 ricardoV94 merged commit 4c78408 into pymc-devs:main Jul 12, 2024
59 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request torch PyTorch backend
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants