Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interact2d cmap options #1073

Merged
merged 10 commits into from
Jun 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v2
Expand Down
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).

## [Unreleased]

## Added
### Added
- `artists.interact2D` supports `cmap` kwarg.
- iPython integration: autocomplete includes axis, variable, and channel names

### Changed
- `artists.interact2D` uses matplotlib norm objects to control colormap scaling

### Fixed
- `kit.fft`: fixed bug where Fourier coefficients were off by a scalar factor.

Expand Down
94 changes: 54 additions & 40 deletions WrightTools/artists/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,42 @@ def get_channel(data, channel):
return channel


def get_colormap(channel):
if channel.signed:
cmap = "signed"
else:
cmap = "default"
def get_colormap(signed):
cmap = "signed" if signed else "default"
cmap = colormaps[cmap]
cmap.set_bad([0.75] * 3, 1.0)
cmap.set_under([0.75] * 3, 1.0)
return cmap


def get_clim(channel, current_state):
if current_state.local:
arr = current_state.dat[channel.natural_name][:]
if channel.signed:
mag = np.nanmax(np.abs(arr))
clim = [-mag, mag]
def get_norm(channel, current_state) -> object:
if channel.signed:
if not current_state.local:
norm = mpl.colors.CenteredNorm(vcenter=channel.null, halfrange=channel.mag())
else:
clim = [0, np.nanmax(arr)]
norm = mpl.colors.CenteredNorm(vcenter=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.halfrange == 0:
norm.halfrange = 1
else:
if channel.signed:
clim = [-channel.mag(), channel.mag()]
if not current_state.local:
norm = mpl.colors.Normalize(vmin=channel.null, vmax=channel.max())
else:
clim = [0, channel.max()]
return clim
norm = mpl.colors.Normalize(vmin=channel.null)
norm.autoscale_None(current_state.dat[channel.natural_name][:])
if norm.vmax == norm.vmin:
norm.vmax += 1
return norm


def norm_to_ticks(norm) -> np.array:
if type(norm) == mpl.colors.CenteredNorm:
vmin = norm.vcenter - norm.halfrange
vmax = norm.vcenter + norm.halfrange
else: # mpl.colors.Normalize
vmin = norm.vmin
vmax = norm.vmax
return np.linspace(vmin, vmax, 11)


def gen_ticklabels(points, signed=None):
Expand Down Expand Up @@ -133,7 +144,14 @@ def norm(arr, signed, ignore_zero=True):


def interact2D(
data: wt_data.Data, xaxis=0, yaxis=1, channel=0, local=False, use_imshow=False, verbose=True
data: wt_data.Data,
xaxis=0,
yaxis=1,
channel=0,
cmap=None,
local=False,
use_imshow=False,
verbose=True,
):
"""Interactive 2D plot of the dataset.
Side plots show x and y projections of the slice (shaded gray).
Expand All @@ -151,10 +169,12 @@ def interact2D(
Expression or index of y axis. Default is 1.
channel : string, integer, or data.Channel object (optional)
Name or index of channel to plot. Default is 0.
cmap : string or cm object (optional)
Name of colormap, or explicit colormap object. Defaults to channel default.
local : boolean (optional)
Toggle plotting locally. Default is False.
use_imshow : boolean (optional)
If true, matplotlib imshow is used to render the 2D slice.
If True, matplotlib imshow is used to render the 2D slice.
Can give better performance, but is only accurate for
uniform grids. Default is False.
verbose : boolean (optional)
Expand All @@ -163,10 +183,10 @@ def interact2D(
# avoid changing passed data object
data = data.copy()
# unpack
data.prune(keep_channels=channel)
data.prune(keep_channels=channel, verbose=False)
channel = get_channel(data, channel)
xaxis, yaxis = get_axes(data, [xaxis, yaxis])
cmap = get_colormap(channel)
cmap = cmap if cmap is not None else get_colormap(channel.signed)
current_state = SimpleNamespace()
# create figure
nsliders = data.ndim - 2
Expand Down Expand Up @@ -229,7 +249,7 @@ def interact2D(
*slider.ax.get_ylim(),
colors="k",
linestyle=":",
alpha=0.5
alpha=0.5,
)
slider.valtext.set_text(gen_ticklabels(axis.points)[0])
current_state.focus = Focus([ax0] + [slider.ax for slider in sliders.values()])
Expand All @@ -240,25 +260,21 @@ def interact2D(
at=_at_dict(data, sliders, xaxis, yaxis),
verbose=False,
)[0]
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
norm = get_norm(channel, current_state)

gen_mesh = ax0.pcolormesh if not use_imshow else ax0.imshow
obj2D = gen_mesh(
current_state.dat,
cmap=cmap,
vmin=clim[0],
vmax=clim[1],
norm=norm,
ylabel=yaxis.label,
xlabel=xaxis.label,
)
ax0.grid(b=True)
# colorbar
colorbar = plot_colorbar(
cax, cmap=cmap, label=channel.natural_name, ticks=np.linspace(clim[0], clim[1], 11)
)
ticks = norm_to_ticks(norm)
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar = plot_colorbar(cax, cmap=cmap, label=channel.natural_name, ticks=ticks)
colorbar.set_ticklabels(ticklabels)
fig.canvas.draw_idle()

Expand Down Expand Up @@ -384,12 +400,10 @@ def update_local(index):
if verbose:
print("normalization:", index)
current_state.local = radio.value_selected[1:] == "local"
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)
ticklabels = gen_ticklabels(np.linspace(norm.vmin, norm.vmax, 11), channel.signed)
colorbar.set_ticklabels(ticklabels)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
obj2D.set_clim(*clim)
fig.canvas.draw_idle()

def update_slider(info, use_imshow=use_imshow):
Expand All @@ -416,11 +430,11 @@ def update_slider(info, use_imshow=use_imshow):
obj2D.set_data(current_state.dat[channel.natural_name][:].transpose(transpose))
else:
obj2D.set_array(current_state.dat[channel.natural_name][:].ravel())
clim = get_clim(channel, current_state)
ticklabels = gen_ticklabels(np.linspace(*clim, 11), channel.signed)
if clim[0] == clim[1]:
clim = [-1 if channel.signed else 0, 1]
obj2D.set_clim(*clim)
norm = get_norm(channel, current_state)
obj2D.set_norm(norm)

ticks = norm_to_ticks(norm)
ticklabels = gen_ticklabels(ticks, channel.signed)
colorbar.set_ticklabels(ticklabels)
sp_x.collections.clear()
sp_y.collections.clear()
Expand Down
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def read(fname):
name="WrightTools",
packages=find_packages(exclude=("tests", "tests.*")),
package_data=extra_files,
python_requires=">=3.6",
python_requires=">=3.7",
install_requires=[
"h5py",
"imageio",
"matplotlib>=3.3.0",
"matplotlib>=3.4.0",
"numexpr",
"numpy>=1.15.0",
"pint",
Expand Down Expand Up @@ -74,8 +74,9 @@ def read(fname):
"Framework :: Matplotlib",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering",
],
)