Skip to content

Commit

Permalink
widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
apchytr committed Oct 7, 2024
1 parent d5b347d commit ed1dbb6
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 99 deletions.
16 changes: 2 additions & 14 deletions mrmustard/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def fock(rep):

header_widget = widgets.HTML("<h1 class=h1-fock>Fock Representation</h1>")
table_widget = widgets.HTML(
TABLE + "<table class=table-fock>"
f"<tr><th>Ansatz</th><td>{rep.ansatz.__class__.__qualname__}</td></tr>"
f"<tr><th>Shape</th><td>{shape}</td></tr>"
"</table>"
TABLE + "<table class=table-fock>" f"<tr><th>Shape</th><td>{shape}</td></tr>" "</table>"
)
left_widget = widgets.VBox(children=[header_widget, table_widget])
plot_widget.layout.padding = "10px"
Expand Down Expand Up @@ -120,15 +117,6 @@ def get_abc_str(A, b, c, round_val):
round_w = widgets.IntText(value=round_default, description="Rounding (negative -> none):")
round_w.style.description_width = "230px"
header_w = widgets.HTML("<h1>Bargmann Representation</h1>")
sub_w = widgets.HBox(
[
widgets.HTML(
'<div style="font-weight: bold; font-size: 18px">Ansatz:</div>'
f"{rep.ansatz.__class__.__qualname__}</br>"
),
round_w,
]
)
triple_w = widgets.HTML(TABLE + triple_fstr.format(*get_abc_str(A, b, c, round_default)))
eigs_header_w = widgets.HTML("<h2>Eigenvalues of A</h2>")
eigvals_w = go.FigureWidget(
Expand Down Expand Up @@ -175,7 +163,7 @@ def on_value_change(change):

eigs_vbox = widgets.VBox([eigs_header_w, eigvals_w])
return widgets.Box(
[widgets.VBox([header_w, sub_w, triple_w]), eigs_vbox],
[widgets.VBox([header_w, round_w, triple_w]), eigs_vbox],
layout=widgets.Layout(max_width="50%", flex_flow="row wrap"),
)

Expand Down
91 changes: 43 additions & 48 deletions tests/test_physics/test_representations/test_bargmann.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,51 +388,46 @@ def test_trace(self):
assert np.allclose(bargmann.b, b)
assert np.allclose(bargmann.c, c)

# @patch("mrmustard.physics.representations.bargmann.display")
# def test_ipython_repr(self, mock_display):
# """Test the IPython repr function."""
# rep = Bargmann(*Abc_triple(2))
# rep._ipython_display_() # pylint:disable=protected-access
# [box] = mock_display.call_args.args
# assert isinstance(box, Box)
# assert box.layout.max_width == "50%"

# # data on left, eigvals on right
# [data_vbox, eigs_vbox] = box.children
# assert isinstance(data_vbox, VBox)
# assert isinstance(eigs_vbox, VBox)

# # data forms a stack: header, ansatz, triple data
# [header, sub, table] = data_vbox.children
# assert isinstance(header, HTML)
# assert isinstance(sub, HBox)
# assert isinstance(table, HTML)

# # ansatz goes beside button to modify rounding
# [ansatz, round_w] = sub.children
# assert isinstance(ansatz, HTML)
# assert isinstance(round_w, IntText)

# # eigvals have a header and a unit circle plot
# [eig_header, unit_circle] = eigs_vbox.children
# assert isinstance(eig_header, HTML)
# assert isinstance(unit_circle, FigureWidget)

# @patch("mrmustard.physics.representations.bargmann.display")
# def test_ipython_repr_batched(self, mock_display):
# """Test the IPython repr function for a batched repr."""
# A1, b1, c1 = Abc_triple(2)
# A2, b2, c2 = Abc_triple(2)
# rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2]))
# rep._ipython_display_() # pylint:disable=protected-access
# [vbox] = mock_display.call_args.args
# assert isinstance(vbox, VBox)

# [slider, stack] = vbox.children
# assert isinstance(slider, IntSlider)
# assert slider.max == 1 # the batch size - 1
# assert isinstance(stack, Stack)

# # max_width is spot-check that this is bargmann widget
# assert len(stack.children) == 2
# assert all(box.layout.max_width == "50%" for box in stack.children)
@patch("mrmustard.physics.representations.bargmann.display")
def test_ipython_repr(self, mock_display):
"""Test the IPython repr function."""
rep = Bargmann(*Abc_triple(2))
rep._ipython_display_() # pylint:disable=protected-access
[box] = mock_display.call_args.args
assert isinstance(box, Box)
assert box.layout.max_width == "50%"

# data on left, eigvals on right
[data_vbox, eigs_vbox] = box.children
assert isinstance(data_vbox, VBox)
assert isinstance(eigs_vbox, VBox)

# data forms a stack: header, ansatz, triple data
[header, sub, table] = data_vbox.children
assert isinstance(header, HTML)
assert isinstance(sub, IntText)
assert isinstance(table, HTML)

# eigvals have a header and a unit circle plot
[eig_header, unit_circle] = eigs_vbox.children
assert isinstance(eig_header, HTML)
assert isinstance(unit_circle, FigureWidget)

@patch("mrmustard.physics.representations.bargmann.display")
def test_ipython_repr_batched(self, mock_display):
"""Test the IPython repr function for a batched repr."""
A1, b1, c1 = Abc_triple(2)
A2, b2, c2 = Abc_triple(2)
rep = Bargmann(np.array([A1, A2]), np.array([b1, b2]), np.array([c1, c2]))
rep._ipython_display_() # pylint:disable=protected-access
[vbox] = mock_display.call_args.args
assert isinstance(vbox, VBox)

[slider, stack] = vbox.children
assert isinstance(slider, IntSlider)
assert slider.max == 1 # the batch size - 1
assert isinstance(stack, Stack)

# max_width is spot-check that this is bargmann widget
assert len(stack.children) == 2
assert all(box.layout.max_width == "50%" for box in stack.children)
74 changes: 37 additions & 37 deletions tests/test_physics/test_representations/test_fock.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,40 +222,40 @@ def test_truediv_a_scalar(self):
assert isinstance(aa1_scalar, Fock)
assert np.allclose(aa1_scalar.array, array / 6)

# @pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)])
# @patch("mrmustard.physics.representations.display")
# def test_ipython_repr(self, mock_display, shape):
# """Test the IPython repr function."""
# rep = Fock(np.random.random(shape), batched=True)
# rep._ipython_display_() # pylint:disable=protected-access
# [hbox] = mock_display.call_args.args
# assert isinstance(hbox, HBox)

# # the CSS, the header+ansatz, and the tabs of plots
# [css, left, plots] = hbox.children
# assert isinstance(css, HTML)
# assert isinstance(left, VBox)
# assert isinstance(plots, Tab)

# # left contains header and ansatz
# left = left.children
# assert len(left) == 2 and all(isinstance(w, HTML) for w in left)

# # one plot for magnitude, another for phase
# assert plots.titles == ("Magnitude", "Phase")
# plots = plots.children
# assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots)

# @patch("mrmustard.physics.representations.display")
# def test_ipython_repr_expects_batch_1(self, mock_display):
# """Test the IPython repr function does nothing with real batch."""
# rep = Fock(np.random.random((2, 8)), batched=True)
# rep._ipython_display_() # pylint:disable=protected-access
# mock_display.assert_not_called()

# @patch("mrmustard.physics.representations.display")
# def test_ipython_repr_expects_3_dims_or_less(self, mock_display):
# """Test the IPython repr function does nothing with 4+ dims."""
# rep = Fock(np.random.random((1, 4, 4, 4)), batched=True)
# rep._ipython_display_() # pylint:disable=protected-access
# mock_display.assert_not_called()
@pytest.mark.parametrize("shape", [(1, 8), (1, 8, 8)])
@patch("mrmustard.physics.representations.fock.display")
def test_ipython_repr(self, mock_display, shape):
"""Test the IPython repr function."""
rep = Fock(np.random.random(shape), batched=True)
rep._ipython_display_() # pylint:disable=protected-access
[hbox] = mock_display.call_args.args
assert isinstance(hbox, HBox)

# the CSS, the header+ansatz, and the tabs of plots
[css, left, plots] = hbox.children
assert isinstance(css, HTML)
assert isinstance(left, VBox)
assert isinstance(plots, Tab)

# left contains header and ansatz
left = left.children
assert len(left) == 2 and all(isinstance(w, HTML) for w in left)

# one plot for magnitude, another for phase
assert plots.titles == ("Magnitude", "Phase")
plots = plots.children
assert len(plots) == 2 and all(isinstance(p, FigureWidget) for p in plots)

@patch("mrmustard.physics.representations.fock.display")
def test_ipython_repr_expects_batch_1(self, mock_display):
"""Test the IPython repr function does nothing with real batch."""
rep = Fock(np.random.random((2, 8)), batched=True)
rep._ipython_display_() # pylint:disable=protected-access
mock_display.assert_not_called()

@patch("mrmustard.physics.representations.fock.display")
def test_ipython_repr_expects_3_dims_or_less(self, mock_display):
"""Test the IPython repr function does nothing with 4+ dims."""
rep = Fock(np.random.random((1, 4, 4, 4)), batched=True)
rep._ipython_display_() # pylint:disable=protected-access
mock_display.assert_not_called()

0 comments on commit ed1dbb6

Please sign in to comment.