Skip to content

Commit

Permalink
Merge pull request #79 from mila-iqia/checkpointing_on_structure_factor
Browse files Browse the repository at this point in the history
Checkpointing on structure factor
  • Loading branch information
rousseab authored Sep 26, 2024
2 parents 812996c + 9510afa commit e9f18cc
Show file tree
Hide file tree
Showing 33 changed files with 1,248 additions and 643 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
LANGEVIN_EXPLORATION_DIRECTORY
from crystal_diffusion.analysis.analytic_score.utils import (
get_exact_samples, get_silicon_supercell)
from crystal_diffusion.callbacks.sampling_callback import logger
from crystal_diffusion.callbacks.sampling_visualization_callback import logger
from crystal_diffusion.generators.langevin_generator import LangevinGenerator
from crystal_diffusion.generators.predictor_corrector_position_generator import \
PredictorCorrectorSamplingParameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_exact_samples, get_silicon_supercell)
from crystal_diffusion.callbacks.loss_monitoring_callback import \
LossMonitoringCallback
from crystal_diffusion.callbacks.sampling_callback import \
from crystal_diffusion.callbacks.sampling_visualization_callback import \
PredictorCorrectorDiffusionSamplingCallback
from crystal_diffusion.generators.predictor_corrector_position_generator import \
PredictorCorrectorSamplingParameters
Expand Down
8 changes: 4 additions & 4 deletions crystal_diffusion/callbacks/analysis_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from crystal_diffusion.analysis import PLOT_STYLE_PATH
from crystal_diffusion.analysis.analytic_score.utils import \
get_relative_harmonic_energy
from crystal_diffusion.callbacks.sampling_callback import \
DiffusionSamplingCallback
from crystal_diffusion.callbacks.sampling_visualization_callback import \
SamplingVisualizationCallback
from crystal_diffusion.generators.position_generator import SamplingParameters
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

Expand All @@ -22,7 +22,7 @@
plt.style.use(PLOT_STYLE_PATH)


class HarmonicEnergyDiffusionSamplingCallback(DiffusionSamplingCallback):
class HarmonicEnergyDiffusionSamplingCallback(SamplingVisualizationCallback):
"""Callback class to periodically generate samples and log their energies."""

def __init__(self, noise_parameters: NoiseParameters,
Expand Down Expand Up @@ -54,7 +54,7 @@ def _compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) ->
@staticmethod
def _plot_energy_histogram(sample_energies: np.ndarray, validation_dataset_energies: np.array,
epoch: int) -> plt.figure:
fig = DiffusionSamplingCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch)
fig = SamplingVisualizationCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch)

fig.suptitle(f'Sampling Unitless Harmonic Potential Energy Distributions\nEpoch {epoch}')
ax1 = fig.axes[0]
Expand Down
6 changes: 3 additions & 3 deletions crystal_diffusion/callbacks/callback_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from crystal_diffusion.callbacks.loss_monitoring_callback import \
instantiate_loss_monitoring_callback
from crystal_diffusion.callbacks.sampling_callback import \
instantiate_diffusion_sampling_callback
from crystal_diffusion.callbacks.sampling_visualization_callback import \
instantiate_sampling_visualization_callback
from crystal_diffusion.callbacks.standard_callbacks import (
CustomProgressBar, instantiate_early_stopping_callback,
instantiate_model_checkpoint_callbacks)

OPTIONAL_CALLBACK_DICTIONARY = dict(early_stopping=instantiate_early_stopping_callback,
model_checkpoint=instantiate_model_checkpoint_callbacks,
diffusion_sampling=instantiate_diffusion_sampling_callback,
sampling_visualization=instantiate_sampling_visualization_callback,
loss_monitoring=instantiate_loss_monitoring_callback)


Expand Down
414 changes: 0 additions & 414 deletions crystal_diffusion/callbacks/sampling_callback.py

This file was deleted.

297 changes: 297 additions & 0 deletions crystal_diffusion/callbacks/sampling_visualization_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AnyStr, Dict

import numpy as np
import torch
from matplotlib import pyplot as plt
from pytorch_lightning import Callback, LightningModule, Trainer

from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
from crystal_diffusion.loggers.logger_loader import log_figure

logger = logging.getLogger(__name__)

plt.style.use(PLOT_STYLE_PATH)


@dataclass(kw_only=True)
class SamplingVisualizationParameters:
"""Parameters to decide what to plot and write to disk."""
record_every_n_epochs: int = 1
first_record_epoch: int = 1
record_trajectories: bool = True
record_energies: bool = True
record_structure: bool = True


def instantiate_sampling_visualization_callback(
    callback_params: Dict[AnyStr, Any], output_directory: str, verbose: bool
) -> Dict[str, Callback]:
    """Build the sampling visualization callback from its configuration.

    Args:
        callback_params: keyword arguments forwarded unchanged to SamplingVisualizationParameters.
        output_directory: base directory handed to the callback for its output files.
        verbose: accepted as part of the signature, but not used by this function.

    Returns:
        A dictionary with the single key "sampling_visualization" mapped to the new callback.
    """
    parameters = SamplingVisualizationParameters(**callback_params)
    return {
        "sampling_visualization": SamplingVisualizationCallback(
            parameters, output_directory
        )
    }


class SamplingVisualizationCallback(Callback):
    """Callback class to periodically plot samples and write them to disk.

    At the end of a validation loop, and depending on the flags in its parameters, it
        - saves the sample energies and logs an energy histogram and an energy quantile plot;
        - saves the sample distances and logs a distance histogram;
        - asks the model's generator to write its recorded trajectories to disk.
    """

    def __init__(
        self,
        sampling_visualization_parameters: SamplingVisualizationParameters,
        output_directory: str,
    ):
        """Init method.

        Args:
            sampling_visualization_parameters: flags and epoch schedule deciding what is recorded.
            output_directory: base directory; one sub-directory is created under it for
                each kind of record that is turned on.
        """
        super().__init__()
        self.parameters = sampling_visualization_parameters
        self.output_directory = output_directory

        # Each output directory attribute only exists when its record flag is on.
        if self.parameters.record_energies:
            self.sample_energies_output_directory = os.path.join(
                output_directory, "energy_samples"
            )
            Path(self.sample_energies_output_directory).mkdir(
                parents=True, exist_ok=True
            )

        if self.parameters.record_structure:
            self.sample_distances_output_directory = os.path.join(
                output_directory, "distance_samples"
            )
            Path(self.sample_distances_output_directory).mkdir(
                parents=True, exist_ok=True
            )

        if self.parameters.record_trajectories:
            self.sample_trajectories_output_directory = os.path.join(
                output_directory, "trajectory_samples"
            )
            Path(self.sample_trajectories_output_directory).mkdir(
                parents=True, exist_ok=True
            )

    def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None:
        """On validation end.

        Does nothing unless the current epoch is a recording epoch; otherwise records
        the energies, the structure and the trajectories, for each flag that is on.

        Args:
            trainer: the trainer; provides the current epoch, the global step and the loggers.
            pl_model: the model; must expose energy_ks_metric, structure_ks_metric and
                generator for the corresponding records that are turned on.

        Raises:
            AssertionError: if a requested record needs a model attribute that is None.
        """
        if not self._compute_results_at_this_epoch(trainer.current_epoch):
            return

        if self.parameters.record_energies:
            self._record_energies(trainer, pl_model)

        if self.parameters.record_structure:
            self._record_structure(trainer, pl_model)

        if self.parameters.record_trajectories:
            self._record_trajectories(trainer, pl_model)

    def _record_energies(self, trainer: Trainer, pl_model: LightningModule) -> None:
        """Save the sample energies to disk and log the energy histogram and quantile figures."""
        assert (
            pl_model.energy_ks_metric is not None
        ), "The energy_ks_metric is absent. Energy calculation must be requested in order to be visualized!"
        reference_energies = (
            pl_model.energy_ks_metric.reference_samples_metric.compute()
        )
        sample_energies = (
            pl_model.energy_ks_metric.predicted_samples_metric.compute()
        )
        energy_output_path = os.path.join(
            self.sample_energies_output_directory,
            f"energies_sample_epoch={trainer.current_epoch}.pt",
        )
        torch.save(sample_energies, energy_output_path)

        sample_energies = sample_energies.cpu().numpy()
        reference_energies = reference_energies.cpu().numpy()

        fig1 = self._plot_energy_histogram(
            sample_energies, reference_energies, trainer.current_epoch
        )
        fig2 = self._plot_energy_quantiles(
            sample_energies, reference_energies, trainer.current_epoch
        )

        for pl_logger in trainer.loggers:
            log_figure(
                figure=fig1,
                global_step=trainer.global_step,
                dataset="validation",
                pl_logger=pl_logger,
                name="energy_distribution",
            )
            log_figure(
                figure=fig2,
                global_step=trainer.global_step,
                dataset="validation",
                pl_logger=pl_logger,
                name="energy_quantiles",
            )
        # Close the figures explicitly so they do not pile up across epochs.
        plt.close(fig1)
        plt.close(fig2)

    def _record_structure(self, trainer: Trainer, pl_model: LightningModule) -> None:
        """Save the sample distances to disk and log the distance histogram figure."""
        assert pl_model.structure_ks_metric is not None, (
            "The structure_ks_metric is absent. Structure factor calculation "
            "must be requested in order to be visualized!"
        )

        reference_distances = (
            pl_model.structure_ks_metric.reference_samples_metric.compute()
        )
        sample_distances = (
            pl_model.structure_ks_metric.predicted_samples_metric.compute()
        )

        distance_output_path = os.path.join(
            self.sample_distances_output_directory,
            f"distances_sample_epoch={trainer.current_epoch}.pt",
        )
        torch.save(sample_distances, distance_output_path)

        # Move to the CPU before converting, as is done for the energies:
        # Tensor.numpy() cannot convert a tensor that lives on an accelerator.
        fig = self._plot_distance_histogram(
            sample_distances.cpu().numpy(),
            reference_distances.cpu().numpy(),
            trainer.current_epoch,
        )

        for pl_logger in trainer.loggers:
            log_figure(
                figure=fig,
                global_step=trainer.global_step,
                dataset="validation",
                pl_logger=pl_logger,
                name="distances",
            )
        plt.close(fig)

    def _record_trajectories(self, trainer: Trainer, pl_model: LightningModule) -> None:
        """Ask the generator's trajectory recorder to write its content to disk."""
        assert (
            pl_model.generator is not None
        ), "Cannot record trajectories if a generator has not be created."

        pickle_output_path = os.path.join(
            self.sample_trajectories_output_directory,
            f"trajectories_sample_epoch={trainer.current_epoch}.pt",
        )
        pl_model.generator.sample_trajectory_recorder.write_to_pickle(
            pickle_output_path
        )

    def _compute_results_at_this_epoch(self, current_epoch: int) -> bool:
        """Check if results should be computed at this epoch.

        Args:
            current_epoch: the epoch to test.

        Returns:
            True if the epoch is a multiple of record_every_n_epochs and is not smaller
            than first_record_epoch, False otherwise.
        """
        return (
            current_epoch % self.parameters.record_every_n_epochs == 0
            and current_epoch >= self.parameters.first_record_epoch
        )

    @staticmethod
    def _plot_energy_quantiles(
        sample_energies: np.ndarray, validation_dataset_energies: np.ndarray, epoch: int
    ) -> plt.Figure:
        """Generate a plot of the energy quantiles.

        Args:
            sample_energies: energies of the generated samples.
            validation_dataset_energies: reference energies to compare against.
            epoch: epoch number, shown in the figure title.

        Returns:
            The figure, with both quantile curves drawn at every percent from 0 to 100.
        """
        list_q = np.linspace(0, 1, 101)
        sample_quantiles = np.quantile(sample_energies, list_q)
        dataset_quantiles = np.quantile(validation_dataset_energies, list_q)

        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)
        fig.suptitle(f"Sampling Energy Quantiles\nEpoch {epoch}")
        ax = fig.add_subplot(111)

        label = f"Samples \n(total count = {len(sample_energies)})"
        ax.plot(100 * list_q, sample_quantiles, "-", lw=5, color="red", label=label)

        label = f"Validation Data \n(count = {len(validation_dataset_energies)})"
        ax.plot(
            100 * list_q, dataset_quantiles, "--", lw=10, color="green", label=label
        )
        ax.set_xlabel("Quantile (%)")
        ax.set_ylabel("Energy (eV)")
        ax.set_xlim(-0.1, 100.1)
        ax.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6)
        fig.tight_layout()

        return fig

    @staticmethod
    def _plot_energy_histogram(
        sample_energies: np.ndarray, validation_dataset_energies: np.ndarray, epoch: int
    ) -> plt.Figure:
        """Generate a plot of the energy samples.

        Args:
            sample_energies: energies of the generated samples.
            validation_dataset_energies: reference energies; they set the histogram range.
            epoch: epoch number, shown in the figure title.

        Returns:
            The figure, with the two density histograms overlaid on a single axis.
        """
        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)

        minimum_energy = validation_dataset_energies.min()
        maximum_energy = validation_dataset_energies.max()
        energy_range = maximum_energy - minimum_energy

        # The bins cover the reference range widened by 20% on each side; samples
        # falling outside are left out of the histogram but counted in the legend.
        emin = minimum_energy - 0.2 * energy_range
        emax = maximum_energy + 0.2 * energy_range
        bins = np.linspace(emin, emax, 101)

        number_of_samples_in_range = np.logical_and(
            sample_energies >= emin, sample_energies <= emax
        ).sum()

        fig.suptitle(f"Sampling Energy Distributions\nEpoch {epoch}")

        common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25)

        ax1 = fig.add_subplot(111)

        ax1.hist(
            sample_energies,
            **common_params,
            label=f"Samples \n(total count = {len(sample_energies)}, in range = {number_of_samples_in_range})",
            color="red",
        )
        ax1.hist(
            validation_dataset_energies,
            **common_params,
            label=f"Validation Data \n(count = {len(validation_dataset_energies)})",
            color="green",
        )

        ax1.set_xlabel("Energy (eV)")
        ax1.set_ylabel("Density")
        ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6)
        fig.tight_layout()
        return fig

    @staticmethod
    def _plot_distance_histogram(
        sample_distances: np.ndarray, validation_dataset_distances: np.ndarray, epoch: int
    ) -> plt.Figure:
        """Generate a plot of the inter-atomic distances of the samples.

        Args:
            sample_distances: distances computed on the generated samples.
            validation_dataset_distances: reference distances; their maximum sets the
                upper edge of the histogram range.
            epoch: epoch number, shown in the figure title.

        Returns:
            The figure, with the two density histograms overlaid on a single axis.
        """
        fig = plt.figure(figsize=PLEASANT_FIG_SIZE)

        maximum_distance = validation_dataset_distances.max()

        dmin = 0.0
        dmax = maximum_distance + 0.1
        bins = np.linspace(dmin, dmax, 251)

        fig.suptitle(f"Sampling Distances Distribution\nEpoch {epoch}")

        common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25)

        ax1 = fig.add_subplot(111)

        ax1.hist(
            sample_distances,
            **common_params,
            label=f"Samples \n(total count = {len(sample_distances)})",
            color="red",
        )
        ax1.hist(
            validation_dataset_distances,
            **common_params,
            label=f"Validation Data \n(count = {len(validation_dataset_distances)})",
            color="green",
        )

        ax1.set_xlabel(r"Distance ($\AA$)")
        ax1.set_ylabel("Density")
        ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6)
        ax1.set_xlim(left=dmin, right=dmax)
        fig.tight_layout()
        return fig
Loading

0 comments on commit e9f18cc

Please sign in to comment.