diff --git a/retinal_rl/analysis/plot.py b/retinal_rl/analysis/plot.py index ff2a3d1..ff25e69 100644 --- a/retinal_rl/analysis/plot.py +++ b/retinal_rl/analysis/plot.py @@ -6,7 +6,6 @@ import networkx as nx import numpy as np import seaborn as sns -import torch from matplotlib import gridspec, patches from matplotlib.axes import Axes from matplotlib.figure import Figure @@ -14,8 +13,6 @@ from matplotlib.patches import Circle, Wedge from matplotlib.ticker import MaxNLocator from numpy import fft -from torch import Tensor -from torchvision.utils import make_grid from retinal_rl.models.brain import Brain from retinal_rl.models.objective import ContextT, Objective @@ -23,22 +20,18 @@ def plot_transforms( - source_transforms: Dict[str, Dict[float, List[torch.Tensor]]], - noise_transforms: Dict[str, Dict[float, List[torch.Tensor]]], + source_transforms: Dict[str, Dict[float, List[FloatArray]]], + noise_transforms: Dict[str, Dict[float, List[FloatArray]]], ) -> Figure: - """Use the result of the transform_base_images function to plot the effects of source and noise transforms on images. + """Plot effects of source and noise transforms on images. Args: - ---- - source_transforms: A dictionary of source transforms and their effects on images. - noise_transforms: A dictionary of noise transforms and their effects on images. + source_transforms: Dictionary of source transforms (numpy arrays) + noise_transforms: Dictionary of noise transforms (numpy arrays) Returns: - ------- - Figure: A matplotlib Figure containing the plotted transforms. - + Figure containing the plotted transforms """ - # Determine the number of transforms and images num_source_transforms = len(source_transforms) num_noise_transforms = len(noise_transforms) num_transforms = num_source_transforms + num_noise_transforms @@ -48,13 +41,32 @@ def plot_transforms( ] ) - # Create a figure with subplots for each transform fig, axs = plt.subplots(num_transforms, 1, figsize=(20, 5 * num_transforms)) if num_transforms == 1: axs = [axs] transform_index = 0 + def make_image_grid(arrays: List[FloatArray], nrow: int) -> FloatArray: + """Create a grid of images from a list of numpy arrays.""" + # Assuming arrays are [C, H, W] + n = len(arrays) + if not n: + return np.array([]) + + ncol = nrow + nrow = (n - 1) // ncol + 1 + + nchns, hght, wdth = arrays[0].shape + grid = np.zeros((nchns, hght * nrow, wdth * ncol)) + + for idx, array in enumerate(arrays): + i = idx // ncol + j = idx % ncol + grid[:, i * hght : (i + 1) * hght, j * wdth : (j + 1) * wdth] = array + + return grid + # Plot source transforms for transform_name, transform_data in source_transforms.items(): ax = axs[transform_index] @@ -62,19 +74,20 @@ def plot_transforms( # Create a grid of images for each step images = [ - make_grid( - torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]), + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], nrow=num_images, ) for step in steps ] - grid = make_grid(images, nrow=len(steps)) + grid = make_image_grid(images, nrow=len(steps)) - # Display the grid - ax.imshow(grid.permute(1, 2, 0)) + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) ax.set_title(f"Source Transform: {transform_name}") ax.set_xticks( - [(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))] + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] ) ax.set_xticklabels([f"{step:.2f}" for step in steps]) ax.set_yticks([]) @@ -88,19 +101,20 @@ def plot_transforms( # Create a grid of images for each step images = [ - make_grid( - torch.stack([img * 0.5 + 0.5 for img in transform_data[step]]), + make_image_grid( + [(img * 0.5 + 0.5) for img in transform_data[step]], nrow=num_images, ) for step in steps ] - grid = make_grid(images, nrow=len(steps)) + grid = make_image_grid(images, nrow=len(steps)) - # Display the grid - ax.imshow(grid.permute(1, 2, 0)) + # Move channels last for imshow + grid_display = np.transpose(grid, (1, 2, 0)) + ax.imshow(grid_display) ax.set_title(f"Noise Transform: {transform_name}") ax.set_xticks( - [(i + 0.5) * grid.shape[2] / len(steps) for i in range(len(steps))] + [(i + 0.5) * grid_display.shape[1] / len(steps) for i in range(len(steps))] ) ax.set_xticklabels([f"{step:.2f}" for step in steps]) ax.set_yticks([]) @@ -237,16 +251,17 @@ def plot_brain_and_optimizers(brain: Brain, objective: Objective[ContextT]) -> F return fig -def plot_receptive_field_sizes(results: Dict[str, Dict[str, FloatArray]]) -> Figure: +def plot_receptive_field_sizes( + input_shape: Tuple[int, ...], layers: Dict[str, Dict[str, FloatArray]] +) -> Figure: """Plot the receptive field sizes for each layer of the convolutional part of the network.""" # Get visual field size from the input shape - input_shape = results["input"]["shape"] [_, height, width] = list(input_shape) # Calculate receptive field sizes for each layer rf_sizes: List[Tuple[int, int]] = [] layer_names: List[str] = [] - for name, layer_data in results.items(): + for name, layer_data in layers.items(): if name == "input": continue rf = layer_data["receptive_fields"] @@ -357,22 +372,26 @@ def plot_histories(histories: Dict[str, List[float]]) -> Figure: def plot_channel_statistics( - layer_data: Dict[str, FloatArray], layer_name: str, channel: int + receptive_fields: FloatArray, + spectral: Dict[str, FloatArray], + histogram: Dict[str, FloatArray], + layer_name: str, + channel: int, ) -> Figure: """Plot receptive fields, pixel histograms, and autocorrelation plots for a single channel in a layer.""" fig, axs = plt.subplots(2, 3, figsize=(20, 10)) fig.suptitle(f"Layer: {layer_name}, Channel: {channel}", fontsize=16) # Receptive Fields - rf = layer_data["receptive_fields"][channel] + rf = receptive_fields[channel] _plot_receptive_fields(axs[0, 0], rf) axs[0, 0].set_title("Receptive Field") axs[0, 0].set_xlabel("X") axs[0, 0].set_ylabel("Y") # Pixel Histograms - hist = layer_data["pixel_histograms"][channel] - bin_edges = layer_data["histogram_bin_edges"] + hist = histogram["channel_histograms"][channel] + bin_edges = histogram["bin_edges"] axs[1, 0].bar( bin_edges[:-1], hist, @@ -387,7 +406,7 @@ def plot_channel_statistics( # Autocorrelation plots # Plot average 2D autocorrelation and variance - autocorr = fft.fftshift(layer_data["mean_autocorr"][channel]) + autocorr = fft.fftshift(spectral["mean_autocorr"][channel]) h, w = autocorr.shape extent = [-w // 2, w // 2, -h // 2, h // 2] im = axs[0, 1].imshow( @@ -399,7 +418,7 @@ def plot_channel_statistics( fig.colorbar(im, ax=axs[0, 1]) _set_integer_ticks(axs[0, 1]) - autocorr_sd = fft.fftshift(np.sqrt(layer_data["var_autocorr"][channel])) + autocorr_sd = fft.fftshift(np.sqrt(spectral["var_autocorr"][channel])) im = axs[0, 2].imshow( autocorr_sd, cmap="inferno", origin="lower", extent=extent, vmin=0 ) @@ -411,7 +430,7 @@ def plot_channel_statistics( # Plot average 2D power spectrum log_power_spectrum = fft.fftshift( - np.log1p(layer_data["mean_power_spectrum"][channel]) + np.log1p(spectral["mean_power_spectrum"][channel]) ) h, w = log_power_spectrum.shape @@ -425,7 +444,7 @@ def plot_channel_statistics( _set_integer_ticks(axs[1, 1]) log_power_spectrum_sd = fft.fftshift( - np.log1p(np.sqrt(layer_data["var_power_spectrum"][channel])) + np.log1p(np.sqrt(spectral["var_power_spectrum"][channel])) ) im = axs[1, 2].imshow( log_power_spectrum_sd, @@ -450,16 +469,15 @@ def _set_integer_ticks(ax: Axes): ax.yaxis.set_major_locator(MaxNLocator(integer=True)) -# Function to plot the original and reconstructed images def plot_reconstructions( normalization_mean: List[float], normalization_std: List[float], - train_sources: List[Tuple[Tensor, int]], - train_inputs: List[Tuple[Tensor, int]], - train_estimates: List[Tuple[Tensor, int]], - test_sources: List[Tuple[Tensor, int]], - test_inputs: List[Tuple[Tensor, int]], - test_estimates: List[Tuple[Tensor, int]], + train_sources: List[Tuple[FloatArray, int]], + train_inputs: List[Tuple[FloatArray, int]], + train_estimates: List[Tuple[FloatArray, int]], + test_sources: List[Tuple[FloatArray, int]], + test_inputs: List[Tuple[FloatArray, int]], + test_estimates: List[Tuple[FloatArray, int]], num_samples: int, ) -> Figure: """Plot original and reconstructed images for both training and test sets, including the classes.""" @@ -474,27 +492,28 @@ def plot_reconstructions( test_recon, test_pred = test_estimates[i] # Unnormalize the original images using the normalization lists + # Arrays are already [C, H, W], need to move channels to last dimension train_source = ( - train_source.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_source, (1, 2, 0)) * normalization_std + normalization_mean ) train_input = ( - train_input.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_input, (1, 2, 0)) * normalization_std + normalization_mean ) train_recon = ( - train_recon.permute(1, 2, 0).numpy() * normalization_std + np.transpose(train_recon, (1, 2, 0)) * normalization_std + normalization_mean ) test_source = ( - test_source.permute(1, 2, 0).numpy() * normalization_std + np.transpose(test_source, (1, 2, 0)) * normalization_std + normalization_mean ) test_input = ( - test_input.permute(1, 2, 0).numpy() * normalization_std + normalization_mean + np.transpose(test_input, (1, 2, 0)) * normalization_std + normalization_mean ) test_recon = ( - test_recon.permute(1, 2, 0).numpy() * normalization_std + normalization_mean + np.transpose(test_recon, (1, 2, 0)) * normalization_std + normalization_mean ) axes[0, i].imshow(np.clip(train_source, 0, 1)) @@ -522,7 +541,6 @@ def plot_reconstructions( axes[5, i].set_title(f"Pred: {test_pred}") # Set y-axis labels for each row - fig.text( 0.02, 0.90, diff --git a/retinal_rl/analysis/statistics.py b/retinal_rl/analysis/statistics.py index fae7753..3aa9c59 100644 --- a/retinal_rl/analysis/statistics.py +++ b/retinal_rl/analysis/statistics.py @@ -1,7 +1,8 @@ """Functions for analysis and statistics on a Brain model.""" import logging -from typing import Dict, List, Tuple, cast +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -21,9 +22,76 @@ logger = logging.getLogger(__name__) +### Dataclasses ### + + +@dataclass +class TransformStatistics: + """Results of applying transformations to images.""" + + source_transforms: Dict[str, Dict[float, List[FloatArray]]] + noise_transforms: Dict[str, Dict[float, List[FloatArray]]] + + +@dataclass +class Reconstructions: + """Set of source images, inputs, and their reconstructions.""" + + sources: List[Tuple[FloatArray, int]] + inputs: List[Tuple[FloatArray, int]] + estimates: List[Tuple[FloatArray, int]] + + +@dataclass +class ReconstructionStatistics: + """Results of image reconstruction for both training and test sets.""" + + train: Reconstructions + test: Reconstructions + + +@dataclass +class SpectralAnalysis: + """Results of spectral analysis for a layer.""" + + mean_power_spectrum: FloatArray + var_power_spectrum: FloatArray + mean_autocorr: FloatArray + var_autocorr: FloatArray + + +@dataclass +class HistogramAnalysis: + """Results of histogram analysis for a layer.""" + + channel_histograms: FloatArray + bin_edges: FloatArray + + +@dataclass +class LayerStatistics: + """Statistics for a single layer.""" + + receptive_fields: FloatArray + num_channels: int + spectral: Optional[SpectralAnalysis] = None + histogram: Optional[HistogramAnalysis] = None + + +@dataclass +class CNNStatistics: + """Complete statistics for a CNN model.""" + + input_shape: Tuple[int, ...] # nclrs, hght, wdth + layers: Dict[str, LayerStatistics] + + +### Functions ### + + def transform_base_images( imageset: Imageset, num_steps: int, num_images: int -) -> Dict[str, Dict[str, Dict[float, List[Tensor]]]]: +) -> TransformStatistics: """Apply transformations to a set of images from an Imageset.""" images: List[Image.Image] = [] @@ -34,7 +102,7 @@ def transform_base_images( src, _ = base_dataset[np.random.randint(base_len)] images.append(src) - results: Dict[str, Dict[str, Dict[float, List[Tensor]]]] = { + results: Dict[str, Dict[str, Dict[float, List[FloatArray]]]] = { "source_transforms": {}, "noise_transforms": {}, } @@ -56,10 +124,10 @@ def transform_base_images( results[category][transform.name][step] = [] for img in images: results[category][transform.name][step].append( - imageset.to_tensor(transform.transform(img, step)) + imageset.to_tensor(transform.transform(img, step)).cpu().numpy() ) - return results + return TransformStatistics(**results) def reconstruct_images( @@ -69,19 +137,17 @@ def reconstruct_images( test_set: Imageset, train_set: Imageset, sample_size: int, -) -> Dict[str, List[Tuple[Tensor, int]]]: +) -> ReconstructionStatistics: """Compute reconstructions of a set of training and test images using a Brain model.""" brain.eval() # Set the model to evaluation mode def collect_reconstructions( imageset: Imageset, sample_size: int - ) -> Tuple[ - List[Tuple[Tensor, int]], List[Tuple[Tensor, int]], List[Tuple[Tensor, int]] - ]: + ) -> Reconstructions: """Collect reconstructions for a subset of a dataset.""" - source_subset: List[Tuple[Tensor, int]] = [] - input_subset: List[Tuple[Tensor, int]] = [] - estimates: List[Tuple[Tensor, int]] = [] + source_subset: List[Tuple[FloatArray, int]] = [] + input_subset: List[Tuple[FloatArray, int]] = [] + estimates: List[Tuple[FloatArray, int]] = [] indices = torch.randperm(imageset.epoch_len())[:sample_size] with torch.no_grad(): # Disable gradient computation @@ -93,28 +159,17 @@ def collect_reconstructions( response = brain(stimulus) rec_img = response[decoder].squeeze(0) pred_k = response["classifier"].argmax().item() - source_subset.append((src.cpu(), k)) - input_subset.append((img.cpu(), k)) - estimates.append((rec_img.cpu(), pred_k)) + source_subset.append((src.cpu().numpy(), k)) + input_subset.append((img.cpu().numpy(), k)) + estimates.append((rec_img.cpu().numpy(), pred_k)) - return source_subset, input_subset, estimates + return Reconstructions(source_subset, input_subset, estimates) - train_source, train_input, train_estimates = collect_reconstructions( - train_set, sample_size - ) - test_source, test_input, test_estimates = collect_reconstructions( - test_set, sample_size + return ReconstructionStatistics( + collect_reconstructions(train_set, sample_size), + collect_reconstructions(test_set, sample_size), ) - return { - "train_sources": train_source, - "train_inputs": train_input, - "train_estimates": train_estimates, - "test_sources": test_source, - "test_inputs": test_input, - "test_estimates": test_estimates, - } - def cnn_statistics( device: torch.device, @@ -122,20 +177,8 @@ def cnn_statistics( brain: Brain, channel_analysis: bool, max_sample_size: int = 0, -) -> Dict[str, Dict[str, FloatArray]]: - """Compute statistics for a convolutional encoder model. - - Args: - device: The device to run computations on. - imageset: The dataset to analyze. - brain: The trained Brain model. - channel_analysis: Whether to compute channel-wise statistics (histograms, spectra). - max_sample_size: Maximum number of samples to use. If 0, use all samples. - - Returns: - A nested dictionary containing statistics for the input and each layer. - When channel_analysis is False, only receptive_fields and num_channels are computed. - """ +) -> CNNStatistics: + """Compute statistics for a convolutional encoder model.""" brain.eval() brain.to(device) input_shape, cnn_layers = get_cnn_circuit(brain) @@ -144,7 +187,7 @@ def cnn_statistics( dataloader = _prepare_dataset(imageset, max_sample_size) # Initialize results - results: Dict[str, Dict[str, FloatArray]] = { + results = { "input": _analyze_input(device, dataloader, input_shape, channel_analysis) } @@ -168,7 +211,7 @@ def cnn_statistics( device, dataloader, head_layers, input_shape, out_channels, channel_analysis ) - return results + return CNNStatistics(input_shape, results) def _prepare_dataset( @@ -225,34 +268,22 @@ def _analyze_layer( input_shape: Tuple[int, ...], out_channels: int, channel_analysis: bool = True, -) -> Dict[str, FloatArray]: +) -> LayerStatistics: """Analyze statistics for a single layer.""" head_model = nn.Sequential(*head_layers) - results: Dict[str, FloatArray] = {} # Always compute receptive fields - results["receptive_fields"] = _compute_receptive_fields( - device, head_layers, input_shape, out_channels - ) - results["num_channels"] = np.array(out_channels, dtype=np.float64) + rfs = _compute_receptive_fields(device, head_layers, input_shape, out_channels) + + layer_spectral = None + layer_histograms = None # Compute channel-wise statistics only if requested if channel_analysis: layer_spectral = _layer_spectral_analysis(device, dataloader, head_model) layer_histograms = _layer_pixel_histograms(device, dataloader, head_model) - results.update( - { - "pixel_histograms": layer_histograms["channel_histograms"], - "histogram_bin_edges": layer_histograms["bin_edges"], - "mean_power_spectrum": layer_spectral["mean_power_spectrum"], - "var_power_spectrum": layer_spectral["var_power_spectrum"], - "mean_autocorr": layer_spectral["mean_autocorr"], - "var_autocorr": layer_spectral["var_autocorr"], - } - ) - - return results + return LayerStatistics(rfs, out_channels, layer_spectral, layer_histograms) def _analyze_input( @@ -260,31 +291,22 @@ def _analyze_input( dataloader: DataLoader[Tuple[Tensor, Tensor, int]], input_shape: Tuple[int, ...], channel_analysis: bool, -) -> Dict[str, FloatArray]: +) -> LayerStatistics: """Analyze statistics for the input layer.""" - nclrs = input_shape[0] - results: Dict[str, FloatArray] = { - "receptive_fields": np.eye(nclrs)[:, :, np.newaxis, np.newaxis], - "shape": np.array(input_shape, dtype=np.float64), - "num_channels": np.array(nclrs, dtype=np.float64), - } + + input_spectral = None + input_histograms = None if channel_analysis: input_spectral = _layer_spectral_analysis(device, dataloader, nn.Identity()) input_histograms = _layer_pixel_histograms(device, dataloader, nn.Identity()) - results.update( - { - "pixel_histograms": input_histograms["channel_histograms"], - "histogram_bin_edges": input_histograms["bin_edges"], - "mean_power_spectrum": input_spectral["mean_power_spectrum"], - "var_power_spectrum": input_spectral["var_power_spectrum"], - "mean_autocorr": input_spectral["mean_autocorr"], - "var_autocorr": input_spectral["var_autocorr"], - } - ) - - return results + return LayerStatistics( + np.eye(input_shape[0])[:, :, np.newaxis, np.newaxis], + input_shape[0], + input_spectral, + input_histograms, + ) def _layer_pixel_histograms( @@ -292,7 +314,7 @@ def _layer_pixel_histograms( dataloader: DataLoader[Tuple[Tensor, Tensor, int]], model: nn.Module, num_bins: int = 20, -) -> Dict[str, FloatArray]: +) -> HistogramAnalysis: """Compute histograms of pixel/activation values for each channel across all data in an imageset.""" _, first_batch, _ = next(iter(dataloader)) with torch.no_grad(): @@ -335,19 +357,17 @@ def _layer_pixel_histograms( bin_width = (hist_range[1] - hist_range[0]) / num_bins normalized_histograms = histograms / (total_elements * bin_width / num_channels) - return { - "bin_edges": np.linspace( - hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64 - ), - "channel_histograms": normalized_histograms.cpu().numpy(), - } + return HistogramAnalysis( + normalized_histograms.cpu().numpy(), + np.linspace(hist_range[0], hist_range[1], num_bins + 1, dtype=np.float64), + ) def _layer_spectral_analysis( device: torch.device, dataloader: DataLoader[Tuple[Tensor, Tensor, int]], model: nn.Module, -) -> Dict[str, FloatArray]: +) -> SpectralAnalysis: """Compute spectral analysis statistics for each channel across all data in an imageset.""" _, first_batch, _ = next(iter(dataloader)) with torch.no_grad(): @@ -392,9 +412,9 @@ def _layer_spectral_analysis( var_power_spectrum = m2_power_spectrum / count - (mean_power_spectrum / count) ** 2 var_autocorr = m2_autocorr / count - (mean_autocorr / count) ** 2 - return { - "mean_power_spectrum": mean_power_spectrum.cpu().numpy(), - "var_power_spectrum": var_power_spectrum.cpu().numpy(), - "mean_autocorr": mean_autocorr.cpu().numpy(), - "var_autocorr": var_autocorr.cpu().numpy(), - } + return SpectralAnalysis( + mean_power_spectrum.cpu().numpy(), + var_power_spectrum.cpu().numpy(), + mean_autocorr.cpu().numpy(), + var_autocorr.cpu().numpy(), + ) diff --git a/runner/frameworks/classification/analyze.py b/runner/frameworks/classification/analyze.py index 66c9781..013ba78 100644 --- a/runner/frameworks/classification/analyze.py +++ b/runner/frameworks/classification/analyze.py @@ -1,9 +1,12 @@ +import json import logging -import os import shutil -from typing import Dict, List +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict, List import matplotlib.pyplot as plt +import numpy as np import torch import wandb from matplotlib.figure import Figure @@ -19,6 +22,8 @@ plot_transforms, ) from retinal_rl.analysis.statistics import ( + CNNStatistics, + LayerStatistics, cnn_statistics, reconstruct_images, transform_base_images, @@ -27,13 +32,27 @@ from retinal_rl.models.brain import Brain from retinal_rl.models.loss import ReconstructionLoss from retinal_rl.models.objective import ContextT, Objective -from retinal_rl.util import FloatArray + +### Infrastructure ### + logger = logging.getLogger(__name__) init_dir = "initialization_analysis" +class NumpyEncoder(json.JSONEncoder): + """JSON encoder that handles numpy arrays.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, np.ndarray): + return obj.tolist() + return super().default(obj) + + +### Analysis ### + + def analyze( cfg: DictConfig, device: torch.device, @@ -45,116 +64,259 @@ def analyze( epoch: int, copy_checkpoint: bool = False, ): - if not cfg.logging.use_wandb: - _plot_and_save_histories(cfg, histories) + ## DictConfig + + # Path creation + run_dir = Path(cfg.path.run_dir) + run_dir.mkdir(exist_ok=True) + + plot_dir = Path(cfg.path.plot_dir) + plot_dir.mkdir(exist_ok=True) + + checkpoint_plot_dir = Path(cfg.path.checkpoint_plot_dir) + checkpoint_plot_dir.mkdir(exist_ok=True) + + analyses_dir = Path(cfg.path.data_dir) / "analyses" + analyses_dir.mkdir(exist_ok=True) + + # Variables + use_wandb = cfg.logging.use_wandb + channel_analysis = cfg.logging.channel_analysis + plot_sample_size = cfg.logging.plot_sample_size + + ## Analysis - cnn_analysis = cnn_statistics( + if not use_wandb: + _plot_and_save_histories(plot_dir, histories) + + # Get CNN statistics and save them + cnn_stats = cnn_statistics( device, test_set, brain, - cfg.logging.channel_analysis, - cfg.logging.plot_sample_size, + channel_analysis, + plot_sample_size, ) + # Save CNN statistics + with open(analyses_dir / f"cnn_stats_epoch_{epoch}.json", "w") as f: + json.dump(asdict(cnn_stats), f, cls=NumpyEncoder) + if epoch == 0: - _perform_initialization_analysis(cfg, brain, objective, train_set, cnn_analysis) + _perform_initialization_analysis( + channel_analysis, + use_wandb, + analyses_dir, + plot_dir, + checkpoint_plot_dir, + run_dir, + brain, + objective, + train_set, + cnn_stats, + ) - _analyze_layers(cfg, cnn_analysis, epoch, copy_checkpoint) + _analyze_layers( + channel_analysis, + use_wandb, + plot_dir, + checkpoint_plot_dir, + cnn_stats, + epoch, + copy_checkpoint, + ) _perform_reconstruction_analysis( - cfg, device, brain, objective, train_set, test_set, epoch, copy_checkpoint + use_wandb, + analyses_dir, + plot_dir, + checkpoint_plot_dir, + device, + brain, + objective, + train_set, + test_set, + epoch, + copy_checkpoint, ) + hist_fig = plot_histories(histories) + _save_figure(plot_dir, "", "histories", hist_fig) + plt.close(hist_fig) + -def _plot_and_save_histories(cfg: DictConfig, histories: Dict[str, List[float]]): +def _plot_and_save_histories(plot_dir: Path, histories: Dict[str, List[float]]): hist_fig = plot_histories(histories) - _save_figure(cfg, "", "histories", hist_fig) + _save_figure(plot_dir, "", "histories", hist_fig) plt.close(hist_fig) def _perform_initialization_analysis( - cfg: DictConfig, + channel_analysis: bool, + use_wandb: bool, + analyses_dir: Path, + plot_dir: Path, + checkpoint_plot_dir: Path, + run_dir: Path, brain: Brain, objective: Objective[ContextT], train_set: Imageset, - cnn_analysis: Dict[str, Dict[str, FloatArray]], + cnn_stats: CNNStatistics, ): summary = brain.scan() - filepath = os.path.join(cfg.path.run_dir, "brain_summary.txt") - - with open(filepath, "w") as f: - f.write(summary) + filepath = run_dir / "brain_summary.txt" + filepath.write_text(summary) - if cfg.logging.use_wandb: - wandb.save(filepath, base_path=cfg.path.run_dir, policy="now") + if use_wandb: + wandb.save(str(filepath), base_path=run_dir, policy="now") - rf_sizes_fig = plot_receptive_field_sizes(cnn_analysis) - _process_figure(cfg, False, rf_sizes_fig, init_dir, "receptive_field_sizes", 0) + # TODO: This is a bit of a hack, we should refactor this to get the relevant information out of cnn_stats + rf_sizes_fig = plot_receptive_field_sizes(**asdict(cnn_stats)) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + rf_sizes_fig, + init_dir, + "receptive_field_sizes", + 0, + ) graph_fig = plot_brain_and_optimizers(brain, objective) - _process_figure(cfg, False, graph_fig, init_dir, "brain_graph", 0) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + graph_fig, + init_dir, + "brain_graph", + 0, + ) transforms = transform_base_images(train_set, num_steps=5, num_images=2) - transforms_fig = plot_transforms(**transforms) - _process_figure(cfg, False, transforms_fig, init_dir, "transforms", 0) + # Save transform statistics + transform_path = analyses_dir / "transforms.json" + with open(transform_path, "w") as f: + json.dump(asdict(transforms), f, cls=NumpyEncoder) - _analyze_input_layer(cfg, cnn_analysis["input"], cfg.logging.channel_analysis) + transforms_fig = plot_transforms(**asdict(transforms)) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + transforms_fig, + init_dir, + "transforms", + 0, + ) + + _analyze_input_layer( + use_wandb, + plot_dir, + checkpoint_plot_dir, + cnn_stats.layers["input"], + channel_analysis, + ) def _analyze_layers( - cfg: DictConfig, - cnn_analysis: Dict[str, Dict[str, FloatArray]], + channel_analysis: bool, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, + cnn_stats: CNNStatistics, epoch: int, copy_checkpoint: bool, ): - for layer_name, layer_data in cnn_analysis.items(): + for layer_name, layer_data in cnn_stats.layers.items(): if layer_name != "input": _analyze_regular_layer( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, layer_name, layer_data, epoch, copy_checkpoint, - cfg.logging.channel_analysis, + channel_analysis, ) def _analyze_input_layer( - cfg: DictConfig, - layer_data: Dict[str, FloatArray], + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, + layer_statistics: LayerStatistics, channel_analysis: bool, ): - layer_rfs = layer_receptive_field_plots(layer_data["receptive_fields"]) - _process_figure(cfg, False, layer_rfs, init_dir, "input_rfs", 0) + layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) + _process_figure( + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + layer_rfs, + init_dir, + "input_rfs", + 0, + ) if channel_analysis: - num_channels = int(layer_data["num_channels"]) + layer_dict = asdict(layer_statistics) + num_channels = int(layer_dict.pop("num_channels")) for channel in range(num_channels): - channel_fig = plot_channel_statistics(layer_data, "input", channel) + channel_fig = plot_channel_statistics( + **layer_dict, layer_name="input", channel=channel + ) _process_figure( - cfg, False, channel_fig, init_dir, f"input_channel_{channel}", 0 + use_wandb, + plot_dir, + checkpoint_plot_dir, + False, + channel_fig, + init_dir, + f"input_channel_{channel}", + 0, ) def _analyze_regular_layer( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, layer_name: str, - layer_data: Dict[str, FloatArray], + layer_statistics: LayerStatistics, epoch: int, copy_checkpoint: bool, channel_analysis: bool, ): - layer_rfs = layer_receptive_field_plots(layer_data["receptive_fields"]) + layer_rfs = layer_receptive_field_plots(layer_statistics.receptive_fields) _process_figure( - cfg, copy_checkpoint, layer_rfs, "receptive_fields", f"{layer_name}", epoch + use_wandb, + plot_dir, + checkpoint_plot_dir, + copy_checkpoint, + layer_rfs, + "receptive_fields", + f"{layer_name}", + epoch, ) if channel_analysis: - num_channels = int(layer_data["num_channels"]) + layer_dict = asdict(layer_statistics) + num_channels = int(layer_dict.pop("num_channels")) for channel in range(num_channels): - channel_fig = plot_channel_statistics(layer_data, layer_name, channel) + channel_fig = plot_channel_statistics( + **layer_dict, layer_name=layer_name, channel=channel + ) + _process_figure( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, copy_checkpoint, channel_fig, f"{layer_name}_layer_channel_analysis", @@ -164,7 +326,10 @@ def _analyze_regular_layer( def _perform_reconstruction_analysis( - cfg: DictConfig, + use_wandb: bool, + analyses_dir: Path, + plot_dir: Path, + checkpoint_plot_dir: Path, device: torch.device, brain: Brain, objective: Objective[ContextT], @@ -181,12 +346,25 @@ def _perform_reconstruction_analysis( for decoder in reconstruction_decoders: norm_means, norm_stds = train_set.normalization_stats - rec_dict = reconstruct_images(device, brain, decoder, train_set, test_set, 5) + rec_dict = asdict( + reconstruct_images(device, brain, decoder, train_set, test_set, 5) + ) + # Save the reconstructions + rec_path = analyses_dir / f"{decoder}_reconstructions_epoch_{epoch}.json" + with open(rec_path, "w") as f: + json.dump(rec_dict, f, cls=NumpyEncoder) + recon_fig = plot_reconstructions( - norm_means, norm_stds, **rec_dict, num_samples=5 + norm_means, + norm_stds, + *rec_dict["train"].values(), + *rec_dict["test"].values(), + num_samples=5, ) _process_figure( - cfg, + use_wandb, + plot_dir, + checkpoint_plot_dir, copy_checkpoint, recon_fig, "reconstruction", @@ -195,19 +373,24 @@ def _perform_reconstruction_analysis( ) -def _save_figure(cfg: DictConfig, sub_dir: str, file_name: str, fig: Figure) -> None: - dir = os.path.join(cfg.path.plot_dir, sub_dir) - os.makedirs(dir, exist_ok=True) - file_name = os.path.join(dir, f"{file_name}.png") - fig.savefig(file_name) +### Helper Functions ### + +def _save_figure(plot_dir: Path, sub_dir: str, file_name: str, fig: Figure) -> None: + dir = plot_dir / sub_dir + dir.mkdir(exist_ok=True) + file_path = dir / f"{file_name}.png" + fig.savefig(file_path) -def _checkpoint_copy(cfg: DictConfig, sub_dir: str, file_name: str, epoch: int) -> None: - src_path = os.path.join(cfg.path.plot_dir, sub_dir, f"{file_name}.png") - dest_dir = os.path.join(cfg.path.checkpoint_plot_dir, f"epoch_{epoch}", sub_dir) - os.makedirs(dest_dir, exist_ok=True) - dest_path = os.path.join(dest_dir, f"{file_name}.png") +def _checkpoint_copy( + plot_dir: Path, checkpoint_plot_dir: Path, sub_dir: str, file_name: str, epoch: int +) -> None: + src_path = plot_dir / sub_dir / f"{file_name}.png" + + dest_dir = checkpoint_plot_dir / f"epoch_{epoch}" / sub_dir + dest_dir.mkdir(parents=True, exist_ok=True) + dest_path = dest_dir / f"{file_name}.png" shutil.copy2(src_path, dest_path) @@ -230,19 +413,21 @@ def capitalize_part(part: str) -> str: def _process_figure( - cfg: DictConfig, + use_wandb: bool, + plot_dir: Path, + checkpoint_plot_dir: Path, copy_checkpoint: bool, fig: Figure, sub_dir: str, file_name: str, epoch: int, ) -> None: - if cfg.logging.use_wandb: + if use_wandb: title = f"{_wandb_title(sub_dir)}/{_wandb_title(file_name)}" img = wandb.Image(fig) wandb.log({title: img}, commit=False) else: - _save_figure(cfg, sub_dir, file_name, fig) + _save_figure(plot_dir, sub_dir, file_name, fig) if copy_checkpoint: - _checkpoint_copy(cfg, sub_dir, file_name, epoch) + _checkpoint_copy(plot_dir, checkpoint_plot_dir, sub_dir, file_name, epoch) plt.close(fig) diff --git a/runner/frameworks/classification/initialize.py b/runner/frameworks/classification/initialize.py index 5956c26..4213eec 100644 --- a/runner/frameworks/classification/initialize.py +++ b/runner/frameworks/classification/initialize.py @@ -2,7 +2,9 @@ ### Imports ### import logging -import os +from dataclasses import dataclass +from os import getenv +from pathlib import Path from typing import Any, Dict, List, Tuple, cast import omegaconf @@ -15,68 +17,115 @@ from retinal_rl.models.brain import Brain from runner.util import save_checkpoint +### Infrastructure ### + + # Initialize the logger logger = logging.getLogger(__name__) +@dataclass +class InitConfig: + """Configuration for initialization.""" + + # Paths + data_dir: Path + checkpoint_dir: Path + plot_dir: Path + wandb_dir: Path + + # WandB settings + use_wandb: bool + wandb_project: str + wandb_entity: str | None + wandb_preempt: bool + + # Run settings + run_name: str + max_checkpoints: int + + @classmethod + def from_dict_config(cls, cfg: DictConfig) -> "InitConfig": + """Create InitConfig from a DictConfig.""" + return cls( + data_dir=Path(cfg.path.data_dir), + checkpoint_dir=Path(cfg.path.checkpoint_dir), + plot_dir=Path(cfg.path.plot_dir), + wandb_dir=Path(cfg.path.wandb_dir), + use_wandb=cfg.logging.use_wandb, + wandb_project=cfg.logging.wandb_project, + wandb_entity=None + if cfg.logging.wandb_entity == "default" + else cfg.logging.wandb_entity, + wandb_preempt=cfg.logging.wandb_preempt, + run_name=cfg.run_name, + max_checkpoints=cfg.logging.max_checkpoints, + ) + + +### Initialization ### + + def initialize( - cfg: DictConfig, + dict_cfg: DictConfig, brain: Brain, optimizer: Optimizer, ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: """Initialize the Brain, Optimizers, and training histories. Checks whether the experiment directory exists and loads the model and history if it does. Otherwise, initializes a new model and history.""" - wandb_sweep_id = os.getenv("WANDB_SWEEP_ID", "local") + + cfg = InitConfig.from_dict_config(dict_cfg) + wandb_sweep_id = getenv("WANDB_SWEEP_ID", "local") logger.info(f"Run Name: {cfg.run_name}") logger.info(f"(WANDB) Sweep ID: {wandb_sweep_id}") # If continuing from a previous run, load the model and history - if os.path.exists(cfg.path.data_dir): + if cfg.data_dir.exists(): return _initialize_reload(cfg, brain, optimizer) # else, initialize a new model and history - return _initialize_create(cfg, brain, optimizer) + logger.info( + f"Experiment data path {cfg.data_dir} does not exist. Initializing {cfg.run_name}." + ) + + cfg_backup = omegaconf.OmegaConf.to_container( + dict_cfg, resolve=True, throw_on_missing=True + ) + cfg_backup = cast(Dict[str, Any], cfg_backup) + + return _initialize_create(cfg, cfg_backup, brain, optimizer) def _initialize_create( - cfg: DictConfig, + cfg: InitConfig, + cfg_backup: dict[Any, Any], brain: Brain, optimizer: Optimizer, ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: epoch = 0 - logger.info( - f"Experiment path {cfg.path.run_dir} does not exist. Initializing {cfg.run_name}." - ) - # initialize the training histories histories: Dict[str, List[float]] = {} - # create the directories - os.makedirs(cfg.path.data_dir) - os.makedirs(cfg.path.checkpoint_dir) - if not cfg.logging.use_wandb: - os.makedirs(cfg.path.plot_dir) - + cfg.data_dir.mkdir(parents=True, exist_ok=True) + cfg.checkpoint_dir.mkdir(parents=True, exist_ok=True) + if not cfg.use_wandb: + cfg.plot_dir.mkdir(parents=True, exist_ok=True) else: - os.makedirs(cfg.path.wandb_dir) + cfg.wandb_dir.mkdir(parents=True, exist_ok=True) # convert DictConfig to dict - dict_conf = omegaconf.OmegaConf.to_container( - cfg, resolve=True, throw_on_missing=True - ) - dict_conf = cast(Dict[str, Any], dict_conf) - entity = cfg.logging.wandb_entity + entity = cfg.wandb_entity if entity == "default": entity = None wandb.init( - project=cfg.logging.wandb_project, + project=cfg.wandb_project, entity=entity, group=HydraConfig.get().runtime.choices.experiment, job_type=HydraConfig.get().runtime.choices.brain, - config=dict_conf, + config=cfg_backup, name=cfg.run_name, id=cfg.run_name, - dir=cfg.path.wandb_dir, + dir=cfg.wandb_dir, ) - if cfg.logging.wandb_preempt: + if cfg.wandb_preempt: wandb.mark_preempting() wandb.define_metric("Epoch") @@ -84,9 +133,9 @@ def _initialize_create( wandb.define_metric("Test/*", step_metric="Epoch") save_checkpoint( - cfg.path.data_dir, - cfg.path.checkpoint_dir, - cfg.logging.max_checkpoints, + cfg.data_dir, + cfg.checkpoint_dir, + cfg.max_checkpoints, brain, optimizer, histories, @@ -97,15 +146,15 @@ def _initialize_create( def _initialize_reload( - cfg: DictConfig, brain: Brain, optimizer: Optimizer + cfg: InitConfig, brain: Brain, optimizer: Optimizer ) -> Tuple[Brain, Optimizer, Dict[str, List[float]], int]: logger.info( - f"Experiment dir {cfg.path.run_dir} exists. Loading existing model and history." + f"Experiment data dir {cfg.data_dir} exists. Loading existing model and history." ) - checkpoint_file = os.path.join(cfg.path.data_dir, "current_checkpoint.pt") + checkpoint_file = cfg.data_dir / "current_checkpoint.pt" # check if files don't exist - if not os.path.exists(checkpoint_file): + if not checkpoint_file.exists(): logger.error(f"File not found: {checkpoint_file}") raise FileNotFoundError("Checkpoint file does not exist.") @@ -115,22 +164,22 @@ def _initialize_reload( optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) completed_epochs = checkpoint["completed_epochs"] history = checkpoint["histories"] - entity = cfg.logging.wandb_entity + entity = cfg.wandb_entity if entity == "default": entity = None - if cfg.logging.use_wandb: + if cfg.use_wandb: wandb.init( - project=cfg.logging.wandb_project, + project=cfg.wandb_project, entity=entity, group=HydraConfig.get().runtime.choices.experiment, job_type=HydraConfig.get().runtime.choices.brain, name=cfg.run_name, id=cfg.run_name, resume="must", - dir=cfg.path.wandb_dir, + dir=cfg.wandb_dir, ) - if cfg.logging.wandb_preempt: + if cfg.wandb_preempt: wandb.mark_preempting() return brain, optimizer, history, completed_epochs diff --git a/runner/frameworks/classification/train.py b/runner/frameworks/classification/train.py index dcf187c..1c52195 100644 --- a/runner/frameworks/classification/train.py +++ b/runner/frameworks/classification/train.py @@ -2,6 +2,7 @@ import logging import time +from pathlib import Path from typing import Dict, List import torch @@ -47,11 +48,23 @@ def train( history (Dict[str, List[float]]): The training history. """ + + use_wandb = cfg.logging.use_wandb + + data_dir = Path(cfg.path.data_dir) + checkpoint_dir = Path(cfg.path.checkpoint_dir) + + max_checkpoints = cfg.logging.max_checkpoints + checkpoint_step = cfg.logging.checkpoint_step + + num_epochs = cfg.optimizer.num_epochs + num_workers = cfg.system.num_workers + trainloader = DataLoader( - train_set, batch_size=64, shuffle=True, num_workers=cfg.system.num_workers + train_set, batch_size=64, shuffle=True, num_workers=num_workers ) testloader = DataLoader( - test_set, batch_size=64, shuffle=False, num_workers=cfg.system.num_workers + test_set, batch_size=64, shuffle=False, num_workers=num_workers ) wall_time = time.time() @@ -103,7 +116,7 @@ def train( wall_time = new_wall_time logger.info(f"Initialization complete. Wall Time: {epoch_wall_time:.2f}s.") - if cfg.logging.use_wandb: + if use_wandb: _wandb_log_statistics(initial_epoch, epoch_wall_time, history) else: @@ -111,7 +124,7 @@ def train( f"Reloading complete. Resuming training from epoch {initial_epoch}." ) - for epoch in range(initial_epoch + 1, cfg.optimizer.num_epochs + 1): + for epoch in range(initial_epoch + 1, num_epochs + 1): brain, history = run_epoch( device, brain, @@ -128,13 +141,13 @@ def train( wall_time = new_wall_time logger.info(f"Epoch {epoch} complete. Wall Time: {epoch_wall_time:.2f}s.") - if epoch % cfg.logging.checkpoint_step == 0: + if epoch % checkpoint_step == 0: logger.info("Saving checkpoint and plots.") save_checkpoint( - cfg.path.data_dir, - cfg.path.checkpoint_dir, - cfg.logging.max_checkpoints, + data_dir, + checkpoint_dir, + max_checkpoints, brain, optimizer, history, @@ -154,7 +167,7 @@ def train( ) logger.info("Analysis complete.") - if cfg.logging.use_wandb: + if use_wandb: _wandb_log_statistics(epoch, epoch_wall_time, history) diff --git a/runner/util.py b/runner/util.py index fbfea1b..1ae4e4f 100644 --- a/runner/util.py +++ b/runner/util.py @@ -5,6 +5,7 @@ import logging import os import shutil +from pathlib import Path from typing import Any, Dict, List, Tuple, cast import networkx as nx @@ -25,8 +26,8 @@ def save_checkpoint( - data_dir: str, - checkpoint_dir: str, + data_dir: Path, + checkpoint_dir: Path, max_checkpoints: int, brain: nn.Module, optimizer: Optimizer, @@ -34,8 +35,8 @@ def save_checkpoint( completed_epochs: int, ) -> None: """Save a checkpoint of the model and optimizer state.""" - current_file = os.path.join(data_dir, "current_checkpoint.pt") - checkpoint_file = os.path.join(checkpoint_dir, f"epoch_{completed_epochs}.pt") + current_file = data_dir / "current_checkpoint.pt" + checkpoint_file = checkpoint_dir / f"epoch_{completed_epochs}.pt" checkpoint_dict: Dict[str, Any] = { "completed_epochs": completed_epochs, "brain_state_dict": brain.state_dict(), @@ -59,11 +60,10 @@ def save_checkpoint( os.remove(os.path.join(checkpoint_dir, checkpoints.pop())) -def delete_results(cfg: DictConfig) -> None: +def delete_results(run_dir: Path) -> None: """Delete the results directory.""" - run_dir: str = cfg.path.run_dir - if not os.path.exists(run_dir): + if not run_dir.exists(): print(f"Directory {run_dir} does not exist.") return