Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paths and file management #61

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 69 additions & 51 deletions retinal_rl/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,32 @@
import networkx as nx
import numpy as np
import seaborn as sns
import torch
from matplotlib import gridspec, patches
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.lines import Line2D
from matplotlib.patches import Circle, Wedge
from matplotlib.ticker import MaxNLocator
from numpy import fft
from torch import Tensor
from torchvision.utils import make_grid

from retinal_rl.models.brain import Brain
from retinal_rl.models.objective import ContextT, Objective
from retinal_rl.util import FloatArray


def plot_transforms(
source_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]],
source_transforms: Dict[str, Dict[float, List[FloatArray]]],
noise_transforms: Dict[str, Dict[float, List[FloatArray]]],
) -> Figure:
"""Use the result of the transform_base_images function to plot the effects of source and noise transforms on images.
"""Plot effects of source and noise transforms on images.

Args:
----
source_transforms: A dictionary of source transforms and their effects on images.
noise_transforms: A dictionary of noise transforms and their effects on images.
source_transforms: Dictionary of source transforms (numpy arrays)
noise_transforms: Dictionary of noise transforms (numpy arrays)

Returns:
-------
Figure: A matplotlib Figure containing the plotted transforms.

Figure containing the plotted transforms
"""
# Determine the number of transforms and images
num_source_transforms = len(source_transforms)
num_noise_transforms = len(noise_transforms)
num_transforms = num_source_transforms + num_noise_transforms
Expand All @@ -48,33 +41,53 @@ def plot_transforms(
]
)

# Create a figure with subplots for each transform
fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms))
if num_transforms == 1:
axs = [axs]

transform_index = 0

def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from looking at it, I think this function might be worth moving out of the local scope (but perhaps private to the module)?
Or do you think it's only of value within this method?

"""Create a grid of images from a list of numpy arrays."""
# Assuming arrays are [C, H, W]
n = len(arrays)
if not n:
return np.array([])

ncol = nrow
nrow = (n - 1) // ncol + 1

nchns, hght, wdth = arrays[0].shape
grid = np.zeros((nchns, hght * nrow, wdth * ncol))

for idx, array in enumerate(arrays):
i = idx // ncol
j = idx % ncol
grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array

return grid

# Plot source transforms
for transform_name, transform_data in source_transforms.items():
ax = axs[transform_index]
steps = sorted(transform_data.keys())

# Create a grid of images for each step
images = [
make_grid(
torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
make_image_grid(
[(img * 0.5 + 0.5) for img in transform_data[step]],
nrow=num_images,
)
for step in steps
]
grid = make_grid(images, nrow=len(steps))
grid = make_image_grid(images, nrow=len(steps))

# Display the grid
ax.imshow(grid.permute(1, 2, 0))
# Move channels last for imshow
grid_display = np.transpose(grid, (1, 2, 0))
ax.imshow(grid_display)
ax.set_title(f"Source Transform: {transform_name}")
ax.set_xticks(
[(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))]
[(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))]
)
ax.set_xticklabels([f"{step:.2f}" for step in steps])
ax.set_yticks([])
Expand All @@ -88,19 +101,20 @@ def plot_transforms(

# Create a grid of images for each step
images = [
make_grid(
torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]),
make_image_grid(
[(img * 0.5 + 0.5) for img in transform_data[step]],
nrow=num_images,
)
for step in steps
]
grid = make_grid(images, nrow=len(steps))
grid = make_image_grid(images, nrow=len(steps))

# Display the grid
ax.imshow(grid.permute(1, 2, 0))
# Move channels last for imshow
grid_display = np.transpose(grid, (1, 2, 0))
ax.imshow(grid_display)
ax.set_title(f"Noise Transform: {transform_name}")
ax.set_xticks(
[(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))]
[(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))]
)
ax.set_xticklabels([f"{step:.2f}" for step in steps])
ax.set_yticks([])
Expand Down Expand Up @@ -237,16 +251,17 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F
return fig


def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure:
def plot_receptive_field_sizes(
input_shape: Tuple[int, ...], layers: Dict[str, Dict[str, FloatArray]]
) -> Figure:
"""Plot the receptive field sizes for each layer of the convolutional part of the network."""
# Get visual field size from the input shape
input_shape = results["input"]["shape"]
[_, height, width] = list(input_shape)

# Calculate receptive field sizes for each layer
rf_sizes: List[Tuple[int, int]] = []
layer_names: List[str] = []
for name, layer_data in results.items():
for name, layer_data in layers.items():
if name == "input":
continue
rf = layer_data["receptive_fields"]
Expand Down Expand Up @@ -357,22 +372,26 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure:


def plot_channel_statistics(
layer_data: Dict[str, FloatArray], layer_name: str, channel: int
receptive_fields: FloatArray,
spectral: Dict[str, FloatArray],
histogram: Dict[str, FloatArray],
layer_name: str,
channel: int,
) -> Figure:
"""Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer."""
fig, axs = plt.subplots(2, 3, figsize=(20, 10))
fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16)

# Receptive Fields
rf = layer_data["receptive_fields"][channel]
rf = receptive_fields[channel]
_plot_receptive_fields(axs[0, 0], rf)
axs[0, 0].set_title("Receptive Field")
axs[0, 0].set_xlabel("X")
axs[0, 0].set_ylabel("Y")

# Pixel Histograms
hist = layer_data["pixel_histograms"][channel]
bin_edges = layer_data["histogram_bin_edges"]
hist = histogram["channel_histograms"][channel]
bin_edges = histogram["bin_edges"]
axs[1, 0].bar(
bin_edges[:-1],
hist,
Expand All @@ -387,7 +406,7 @@ def plot_channel_statistics(

# Autocorrelation plots
# Plot average 2D autocorrelation and variance
autocorr = fft.fftshift(layer_data["mean_autocorr"][channel])
autocorr = fft.fftshift(spectral["mean_autocorr"][channel])
h, w = autocorr.shape
extent = [-w // 2, w // 2, -h // 2, h // 2]
im = axs[0, 1].imshow(
Expand All @@ -399,7 +418,7 @@ def plot_channel_statistics(
fig.colorbar(im, ax=axs[0, 1])
_set_integer_ticks(axs[0, 1])

autocorr_sd = fft.fftshift(np.sqrt(layer_data["var_autocorr"][channel]))
autocorr_sd = fft.fftshift(np.sqrt(spectral["var_autocorr"][channel]))
im = axs[0, 2].imshow(
autocorr_sd, cmap="inferno", origin="lower", extent=extent, vmin=0
)
Expand All @@ -411,7 +430,7 @@ def plot_channel_statistics(

# Plot average 2D power spectrum
log_power_spectrum = fft.fftshift(
np.log1p(layer_data["mean_power_spectrum"][channel])
np.log1p(spectral["mean_power_spectrum"][channel])
)
h, w = log_power_spectrum.shape

Expand All @@ -425,7 +444,7 @@ def plot_channel_statistics(
_set_integer_ticks(axs[1, 1])

log_power_spectrum_sd = fft.fftshift(
np.log1p(np.sqrt(layer_data["var_power_spectrum"][channel]))
np.log1p(np.sqrt(spectral["var_power_spectrum"][channel]))
)
im = axs[1, 2].imshow(
log_power_spectrum_sd,
Expand All @@ -450,16 +469,15 @@ def _set_integer_ticks(ax: Axes):
ax.yaxis.set_major_locator(MaxNLocator(integer=True))


# Function to plot the original and reconstructed images
def plot_reconstructions(
normalization_mean: List[float],
normalization_std: List[float],
train_sources: List[Tuple[Tensor, int]],
train_inputs: List[Tuple[Tensor, int]],
train_estimates: List[Tuple[Tensor, int]],
test_sources: List[Tuple[Tensor, int]],
test_inputs: List[Tuple[Tensor, int]],
test_estimates: List[Tuple[Tensor, int]],
train_sources: List[Tuple[FloatArray, int]],
train_inputs: List[Tuple[FloatArray, int]],
train_estimates: List[Tuple[FloatArray, int]],
test_sources: List[Tuple[FloatArray, int]],
test_inputs: List[Tuple[FloatArray, int]],
test_estimates: List[Tuple[FloatArray, int]],
num_samples: int,
) -> Figure:
"""Plot original and reconstructed images for both training and test sets, including the classes."""
Expand All @@ -474,27 +492,28 @@ def plot_reconstructions(
test_recon, test_pred = test_estimates[i]

# Unnormalize the original images using the normalization lists
# Arrays are already [C, H, W], need to move channels to last dimension
train_source = (
train_source.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_source, (1, 2, 0)) * normalization_std
+ normalization_mean
)
train_input = (
train_input.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_input, (1, 2, 0)) * normalization_std
+ normalization_mean
)
train_recon = (
train_recon.permute(1, 2, 0).numpy() * normalization_std
np.transpose(train_recon, (1, 2, 0)) * normalization_std
+ normalization_mean
)
test_source = (
test_source.permute(1, 2, 0).numpy() * normalization_std
np.transpose(test_source, (1, 2, 0)) * normalization_std
+ normalization_mean
)
test_input = (
test_input.permute(1, 2, 0).numpy() * normalization_std + normalization_mean
np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean
)
test_recon = (
test_recon.permute(1, 2, 0).numpy() * normalization_std + normalization_mean
np.transpose(test_recon, (1, 2, 0)) * normalization_std + normalization_mean
)

axes[0, i].imshow(np.clip(train_source, 0, 1))
Expand Down Expand Up @@ -522,7 +541,6 @@ def plot_reconstructions(
axes[5, i].set_title(f"Pred: {test_pred}")

# Set y-axis labels for each row

fig.text(
0.02,
0.90,
Expand Down
Loading