From 8cb7a7011900f4bb2a8783424fe762dcdb99d51f Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 15:33:13 -0400 Subject: [PATCH 01/20] Cleaning up the unit cell generation. --- .../callbacks/sampling_callback.py | 27 +++++-------------- crystal_diffusion/utils/structure_utils.py | 14 ++++++++++ tests/utils/test_structure_utils.py | 18 +++++++++++++ 3 files changed, 39 insertions(+), 20 deletions(-) create mode 100644 tests/utils/test_structure_utils.py diff --git a/crystal_diffusion/callbacks/sampling_callback.py b/crystal_diffusion/callbacks/sampling_callback.py index 59867403..6c47fe05 100644 --- a/crystal_diffusion/callbacks/sampling_callback.py +++ b/crystal_diffusion/callbacks/sampling_callback.py @@ -2,7 +2,7 @@ import os import tempfile from pathlib import Path -from typing import Any, AnyStr, Dict, List, Optional, Tuple +from typing import Any, AnyStr, Dict, Optional, Tuple import numpy as np import scipy.stats as ss @@ -26,7 +26,8 @@ from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates -from crystal_diffusion.utils.structure_utils import compute_distances_in_batch +from crystal_diffusion.utils.structure_utils import ( + compute_distances_in_batch, get_orthogonal_basis_vectors) logger = logging.getLogger(__name__) @@ -99,20 +100,6 @@ def __init__(self, noise_parameters: NoiseParameters, self._initialize_validation_energies_array() self._initialize_validation_distance_array() - @staticmethod - def _get_orthogonal_unit_cell(batch_size: int, cell_dimensions: List[float]) -> torch.Tensor: - """Get orthogonal unit cell. - - Args: - batch_size: number of required repetitions of the unit cell. - cell_dimensions : list of dimensions that correspond to the sides of the unit cell. - - Returns: - unit_cell: a diagonal matrix with the dimensions along the diagonal. 
- """ - unit_cell = torch.diag(torch.Tensor(cell_dimensions)).unsqueeze(0).repeat(batch_size, 1, 1) - return unit_cell - @staticmethod def compute_kolmogorov_smirnov_distance_and_pvalue(sampling_energies: np.ndarray, reference_energies: np.ndarray) -> Tuple[float, float]: @@ -169,9 +156,9 @@ def _create_generator(self, pl_model: LightningModule) -> PositionGenerator: def _create_unit_cell(self, pl_model) -> torch.Tensor: """Create the batch of unit cells needed by the generative model.""" # TODO we will have to sample unit cell dimensions at some points instead of working with fixed size - unit_cell = (self._get_orthogonal_unit_cell(batch_size=self.sampling_parameters.number_of_samples, - cell_dimensions=self.sampling_parameters.cell_dimensions) - .to(pl_model.device)) + unit_cell = ( + get_orthogonal_basis_vectors(batch_size=self.sampling_parameters.number_of_samples, + cell_dimensions=self.sampling_parameters.cell_dimensions).to(pl_model.device)) return unit_cell @staticmethod @@ -243,7 +230,7 @@ def _compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> """Compute energies from samples.""" batch_size = batch_relative_coordinates.shape[0] cell_dimensions = self.sampling_parameters.cell_dimensions - basis_vectors = self._get_orthogonal_unit_cell(batch_size, cell_dimensions) + basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) batch_cartesian_positions = get_positions_from_coordinates(batch_relative_coordinates, basis_vectors) atom_types = np.ones(self.sampling_parameters.number_of_atoms, dtype=int) diff --git a/crystal_diffusion/utils/structure_utils.py b/crystal_diffusion/utils/structure_utils.py index 23cd3ffb..f68e9644 100644 --- a/crystal_diffusion/utils/structure_utils.py +++ b/crystal_diffusion/utils/structure_utils.py @@ -91,3 +91,17 @@ def compute_distances_in_batch(cartesian_positions: torch.Tensor, unit_cell: tor # Identify neighbors within the radial_cutoff, but avoiding self. 
valid_neighbor_mask = torch.logical_and(zero < distances, distances <= radial_cutoff) return distances[valid_neighbor_mask] + + +def get_orthogonal_basis_vectors(batch_size: int, cell_dimensions: List[float]) -> torch.Tensor: + """Get orthogonal basis vectors. + + Args: + batch_size: number of required repetitions of the basis vectors. + cell_dimensions : list of dimensions that correspond to the sides of the unit cell. + + Returns: + basis_vectors: a diagonal matrix with the dimensions along the diagonal. + """ + basis_vectors = torch.diag(torch.Tensor(cell_dimensions)).unsqueeze(0).repeat(batch_size, 1, 1) + return basis_vectors diff --git a/tests/utils/test_structure_utils.py b/tests/utils/test_structure_utils.py new file mode 100644 index 00000000..4fd9b0d9 --- /dev/null +++ b/tests/utils/test_structure_utils.py @@ -0,0 +1,18 @@ +import torch + +from crystal_diffusion.utils.structure_utils import \ + get_orthogonal_basis_vectors + + +def test_get_orthogonal_basis_vectors(): + + cell_dimensions = [12.34, 8.32, 7.12] + batch_size = 16 + + computed_basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) + + expected_basis_vectors = torch.zeros_like(computed_basis_vectors) + + for d, acell in enumerate(cell_dimensions): + expected_basis_vectors[:, d, d] = acell + torch.testing.assert_allclose(computed_basis_vectors, expected_basis_vectors) From f83e48e64f97c76d6347b7c736f11613246551a0 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 15:35:23 -0400 Subject: [PATCH 02/20] Full service sampling method. 
--- crystal_diffusion/generators/sampling.py | 55 +++++++++++++ tests/generators/test_sampling.py | 99 ++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 crystal_diffusion/generators/sampling.py create mode 100644 tests/generators/test_sampling.py diff --git a/crystal_diffusion/generators/sampling.py b/crystal_diffusion/generators/sampling.py new file mode 100644 index 00000000..13661934 --- /dev/null +++ b/crystal_diffusion/generators/sampling.py @@ -0,0 +1,55 @@ +import logging + +import torch + +from crystal_diffusion.generators.position_generator import ( + PositionGenerator, SamplingParameters) +from crystal_diffusion.namespace import (CARTESIAN_POSITIONS, + RELATIVE_COORDINATES, UNIT_CELL) +from crystal_diffusion.utils.basis_transformations import \ + get_positions_from_coordinates + +logger = logging.getLogger(__name__) + + +def create_batch_of_samples(generator: PositionGenerator, + sampling_parameters: SamplingParameters, + device: torch.device): + """Create batch of samples. + + Utility function to drive the generation of samples. + + Args: + generator : position generator. + sampling_parameters : parameters defining how to sample. + device: device where the generator is located. + + Returns: + sample_batch: drawn samples in the same dictionary format as the training data. 
+ """ + logger.info("Creating a batch of samples") + number_of_samples = sampling_parameters.number_of_samples + cell_dimensions = sampling_parameters.cell_dimensions + basis_vectors = torch.diag(torch.Tensor(cell_dimensions)).unsqueeze(0).repeat(number_of_samples, 1, 1).to(device) + + if sampling_parameters.sample_batchsize is None: + sample_batch_size = number_of_samples + else: + sample_batch_size = sampling_parameters.sample_batchsize + + list_sampled_relative_coordinates = [] + for sampling_batch_indices in torch.split(torch.arange(number_of_samples), sample_batch_size): + basis_vectors_ = basis_vectors[sampling_batch_indices] + sampled_relative_coordinates = generator.sample(len(sampling_batch_indices), + unit_cell=basis_vectors_, + device=device) + list_sampled_relative_coordinates.append(sampled_relative_coordinates) + + relative_coordinates = torch.concat(list_sampled_relative_coordinates) + cartesian_positions = get_positions_from_coordinates(relative_coordinates, basis_vectors) + + batch = {CARTESIAN_POSITIONS: cartesian_positions, + RELATIVE_COORDINATES: relative_coordinates, + UNIT_CELL: basis_vectors} + + return batch diff --git a/tests/generators/test_sampling.py b/tests/generators/test_sampling.py new file mode 100644 index 00000000..9d813170 --- /dev/null +++ b/tests/generators/test_sampling.py @@ -0,0 +1,99 @@ +import einops +import pytest +import torch + +from crystal_diffusion.generators.position_generator import ( + PositionGenerator, SamplingParameters) +from crystal_diffusion.generators.sampling import create_batch_of_samples +from crystal_diffusion.namespace import (CARTESIAN_POSITIONS, + RELATIVE_COORDINATES, UNIT_CELL) +from crystal_diffusion.utils.basis_transformations import \ + get_positions_from_coordinates + + +class DummyGenerator(PositionGenerator): + def __init__(self, relative_coordinates): + self._relative_coordinates = relative_coordinates + self._counter = 0 + + def initialize(self, number_of_samples: int): + pass + + def 
sample( + self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor + ) -> torch.Tensor: + self._counter += number_of_samples + return self._relative_coordinates[self._counter - number_of_samples:self._counter] + + +@pytest.fixture +def device(): + return torch.device("cpu") + + +@pytest.fixture +def number_of_samples(): + return 16 + + +@pytest.fixture +def number_of_atoms(): + return 8 + + +@pytest.fixture +def spatial_dimensions(): + return 3 + + +@pytest.fixture +def relative_coordinates(number_of_samples, number_of_atoms, spatial_dimensions): + return torch.rand(number_of_samples, number_of_atoms, spatial_dimensions) + + +@pytest.fixture +def cell_dimensions(spatial_dimensions): + return list((10 * torch.rand(spatial_dimensions)).numpy()) + + +@pytest.fixture +def generator(relative_coordinates): + return DummyGenerator(relative_coordinates) + + +@pytest.fixture +def sampling_parameters( + spatial_dimensions, number_of_atoms, number_of_samples, cell_dimensions +): + return SamplingParameters( + algorithm="dummy", + spatial_dimension=spatial_dimensions, + number_of_atoms=number_of_atoms, + number_of_samples=number_of_samples, + sample_batchsize=2, + cell_dimensions=cell_dimensions, + ) + + +def test_create_batch_of_samples( + generator, sampling_parameters, device, relative_coordinates, cell_dimensions +): + computed_samples = create_batch_of_samples(generator, sampling_parameters, device) + + batch_size = computed_samples[UNIT_CELL].shape[0] + + expected_basis_vectors = einops.repeat( + torch.diag(torch.tensor(cell_dimensions)), "d1 d2 -> b d1 d2", b=batch_size + ) + + expected_cartesian_coordinates = get_positions_from_coordinates( + relative_coordinates, expected_basis_vectors + ) + + torch.testing.assert_allclose( + computed_samples[RELATIVE_COORDINATES], relative_coordinates + ) + torch.testing.assert_allclose(computed_samples[UNIT_CELL], expected_basis_vectors) + torch.testing.assert_allclose( + computed_samples[CARTESIAN_POSITIONS], 
expected_cartesian_coordinates + ) From 003cccdbb31154813cba9e9a72b1586e7bb6ced5 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 15:37:36 -0400 Subject: [PATCH 03/20] Cleaner sampling. --- crystal_diffusion/generators/sampling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/generators/sampling.py b/crystal_diffusion/generators/sampling.py index 13661934..1c3e357b 100644 --- a/crystal_diffusion/generators/sampling.py +++ b/crystal_diffusion/generators/sampling.py @@ -8,6 +8,8 @@ RELATIVE_COORDINATES, UNIT_CELL) from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates +from crystal_diffusion.utils.structure_utils import \ + get_orthogonal_basis_vectors logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ def create_batch_of_samples(generator: PositionGenerator, logger.info("Creating a batch of samples") number_of_samples = sampling_parameters.number_of_samples cell_dimensions = sampling_parameters.cell_dimensions - basis_vectors = torch.diag(torch.Tensor(cell_dimensions)).unsqueeze(0).repeat(number_of_samples, 1, 1).to(device) + basis_vectors = get_orthogonal_basis_vectors(number_of_samples, cell_dimensions).to(device) if sampling_parameters.sample_batchsize is None: sample_batch_size = number_of_samples From 2466372747d7a7d0c1c06db1c16f174aace04944 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 16:40:44 -0400 Subject: [PATCH 04/20] Make the spatial dimension an input parameter. 
--- crystal_diffusion/utils/neighbors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crystal_diffusion/utils/neighbors.py b/crystal_diffusion/utils/neighbors.py index c304e19c..0bf28f0c 100644 --- a/crystal_diffusion/utils/neighbors.py +++ b/crystal_diffusion/utils/neighbors.py @@ -186,7 +186,8 @@ def get_periodic_adjacency_information(cartesian_positions: torch.Tensor, number_of_edges=number_of_edges) -def _get_relative_coordinates_lattice_vectors(number_of_shells: int = 1) -> torch.Tensor: +def _get_relative_coordinates_lattice_vectors(number_of_shells: int = 1, + spatial_dimension: int = 3) -> torch.Tensor: """Get relative coordinates lattice vectors. Get all the lattice vectors in relative coordinates from -number_of_shells to +number_of_shells, @@ -198,7 +199,6 @@ def _get_relative_coordinates_lattice_vectors(number_of_shells: int = 1) -> torc Returns: list_relative_lattice_vectors : all the lattice vectors in relative coordinates (ie, integers). """ - spatial_dimension = 3 shifts = range(-number_of_shells, number_of_shells + 1) list_relative_lattice_vectors = 1.0 * torch.tensor(list(itertools.product(shifts, repeat=spatial_dimension))) From 1a7645c07ae079be1e74d6db01b748f3ab900d55 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 16:41:00 -0400 Subject: [PATCH 05/20] Make the spatial dimension an input parameter. 
--- crystal_diffusion/utils/structure_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/utils/structure_utils.py b/crystal_diffusion/utils/structure_utils.py index f68e9644..aa375491 100644 --- a/crystal_diffusion/utils/structure_utils.py +++ b/crystal_diffusion/utils/structure_utils.py @@ -53,7 +53,8 @@ def compute_distances_in_batch(cartesian_positions: torch.Tensor, unit_cell: tor zero = torch.tensor(0.0).to(device) # The relative coordinates lattice vectors have dimensions [number of lattice vectors, spatial_dimension] - relative_lattice_vectors = _get_relative_coordinates_lattice_vectors(number_of_shells=1).to(device) + relative_lattice_vectors = _get_relative_coordinates_lattice_vectors(number_of_shells=1, + spatial_dimension=spatial_dimension).to(device) number_of_relative_lattice_vectors = len(relative_lattice_vectors) # Repeat the relative lattice vectors along the batch dimension; the basis vectors could potentially be From 110b620b04c040ed9ce8da0b8cf0f27bfff5f14d Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 18:01:19 -0400 Subject: [PATCH 06/20] Compute the structure ks as a metric. 
--- .../position_diffusion_lightning_model.py | 198 ++++++++++++++---- .../kolmogorov_smirnov_metrics.py | 58 +++++ crystal_diffusion/utils/structure_utils.py | 20 +- ...test_position_diffusion_lightning_model.py | 29 ++- tests/utils/test_structure_utils.py | 59 +++++- 5 files changed, 317 insertions(+), 47 deletions(-) create mode 100644 crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 30b37777..8f3607e6 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -1,10 +1,15 @@ import logging import typing from dataclasses import dataclass +from typing import Optional import pytorch_lightning as pl import torch +from crystal_diffusion.generators.langevin_generator import LangevinGenerator +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sampling import create_batch_of_samples from crystal_diffusion.models.loss import (LossParameters, create_loss_calculator) from crystal_diffusion.models.optimizer import (OptimizerParameters, @@ -15,17 +20,20 @@ ScoreNetworkParameters from crystal_diffusion.models.score_networks.score_network_factory import \ create_score_network -from crystal_diffusion.namespace import (CARTESIAN_FORCES, NOISE, - NOISY_RELATIVE_COORDINATES, +from crystal_diffusion.namespace import (CARTESIAN_FORCES, CARTESIAN_POSITIONS, + NOISE, NOISY_RELATIVE_COORDINATES, RELATIVE_COORDINATES, TIME, UNIT_CELL) from crystal_diffusion.samplers.noisy_relative_coordinates_sampler import \ NoisyRelativeCoordinatesSampler from crystal_diffusion.samplers.variance_sampler import ( ExplodingVarianceSampler, NoiseParameters) +from crystal_diffusion.sampling_metrics.kolmogorov_smirnov_metrics import \ + 
KolmogorovSmirnovMetrics from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score -from crystal_diffusion.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell +from crystal_diffusion.utils.basis_transformations import ( + get_positions_from_coordinates, map_relative_coordinates_to_unit_cell) +from crystal_diffusion.utils.structure_utils import compute_distances_in_batch from crystal_diffusion.utils.tensor_utils import \ broadcast_batch_tensor_to_all_dimensions @@ -35,14 +43,15 @@ @dataclass(kw_only=True) class PositionDiffusionParameters: """Position Diffusion parameters.""" + score_network_parameters: ScoreNetworkParameters loss_parameters: LossParameters optimizer_parameters: OptimizerParameters - scheduler_parameters: typing.Union[SchedulerParameters, None] = None + scheduler_parameters: Optional[SchedulerParameters] = None noise_parameters: NoiseParameters - kmax_target_score: int = ( - 4 # convergence parameter for the Ewald-like sum of the perturbation kernel. - ) + sampling_parameters: Optional[PredictorCorrectorSamplingParameters] = None + # convergence parameter for the Ewald-like sum of the perturbation kernel. + kmax_target_score: int = 4 class PositionDiffusionLightningModel(pl.LightningModule): @@ -59,16 +68,29 @@ def __init__(self, hyper_params: PositionDiffusionParameters): super().__init__() self.hyper_params = hyper_params - self.save_hyperparameters(logger=False) # It is not the responsibility of this class to log its parameters. + self.save_hyperparameters( + logger=False + ) # It is not the responsibility of this class to log its parameters. 
# we will model sigma x score - self.sigma_normalized_score_network = create_score_network(hyper_params.score_network_parameters) + self.sigma_normalized_score_network = create_score_network( + hyper_params.score_network_parameters + ) self.loss_calculator = create_loss_calculator(hyper_params.loss_parameters) self.noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) + if self.hyper_params.sampling_parameters is not None: + self.draw_samples = True + self.max_distance = ( + min(self.hyper_params.sampling_parameters.cell_dimensions) - 0.1 + ) + self.structure_ks_metric = KolmogorovSmirnovMetrics() + else: + self.draw_samples = False + def configure_optimizers(self): """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. @@ -83,8 +105,10 @@ def configure_optimizers(self): output = dict(optimizer=optimizer) if self.hyper_params.scheduler_parameters is not None: - scheduler_dict = load_scheduler_dictionary(scheduler_parameters=self.hyper_params.scheduler_parameters, - optimizer=optimizer) + scheduler_dict = load_scheduler_dictionary( + scheduler_parameters=self.hyper_params.scheduler_parameters, + optimizer=optimizer, + ) output.update(scheduler_dict) return output @@ -100,7 +124,9 @@ def _get_batch_size(batch: torch.Tensor) -> int: batch_size: the size of the batch. """ # The RELATIVE_COORDINATES have dimensions [batch_size, number_of_atoms, spatial_dimension]. - assert RELATIVE_COORDINATES in batch, f"The field '{RELATIVE_COORDINATES}' is missing from the input." + assert ( + RELATIVE_COORDINATES in batch + ), f"The field '{RELATIVE_COORDINATES}' is missing from the input." batch_size = batch[RELATIVE_COORDINATES].shape[0] return batch_size @@ -142,7 +168,9 @@ def _generic_step( loss : the computed loss. """ # The RELATIVE_COORDINATES have dimensions [batch_size, number_of_atoms, spatial_dimension]. 
- assert RELATIVE_COORDINATES in batch, f"The field '{RELATIVE_COORDINATES}' is missing from the input." + assert ( + RELATIVE_COORDINATES in batch + ), f"The field '{RELATIVE_COORDINATES}' is missing from the input." x0 = batch[RELATIVE_COORDINATES] shape = x0.shape assert len(shape) == 3, ( @@ -160,34 +188,48 @@ def _generic_step( batch_values=noise_sample.sigma, final_shape=shape ) - xt = self.noisy_relative_coordinates_sampler.get_noisy_relative_coordinates_sample(x0, sigmas) + xt = self.noisy_relative_coordinates_sampler.get_noisy_relative_coordinates_sample( + x0, sigmas + ) # The target is nabla log p_{t|0} (xt | x0): it is NOT the "score", but rather a "conditional" (on x0) score. - target_normalized_conditional_scores = self._get_target_normalized_score(xt, x0, sigmas) + target_normalized_conditional_scores = self._get_target_normalized_score( + xt, x0, sigmas + ) - unit_cell = torch.diag_embed(batch["box"]) # from (batch, spatial_dim) to (batch, spatial_dim, spatial_dim) + unit_cell = torch.diag_embed( + batch["box"] + ) # from (batch, spatial_dim) to (batch, spatial_dim, spatial_dim) forces = batch[CARTESIAN_FORCES] - augmented_batch = {NOISY_RELATIVE_COORDINATES: xt, - TIME: noise_sample.time.reshape(-1, 1), - NOISE: noise_sample.sigma.reshape(-1, 1), - UNIT_CELL: unit_cell, - CARTESIAN_FORCES: forces} + augmented_batch = { + NOISY_RELATIVE_COORDINATES: xt, + TIME: noise_sample.time.reshape(-1, 1), + NOISE: noise_sample.sigma.reshape(-1, 1), + UNIT_CELL: unit_cell, + CARTESIAN_FORCES: forces, + } use_conditional = None if no_conditional is False else False - predicted_normalized_scores = self.sigma_normalized_score_network(augmented_batch, conditional=use_conditional) + predicted_normalized_scores = self.sigma_normalized_score_network( + augmented_batch, conditional=use_conditional + ) - unreduced_loss = self.loss_calculator.calculate_unreduced_loss(predicted_normalized_scores, - target_normalized_conditional_scores, - sigmas.to(self.device)) + 
unreduced_loss = self.loss_calculator.calculate_unreduced_loss( + predicted_normalized_scores, + target_normalized_conditional_scores, + sigmas.to(self.device), + ) loss = torch.mean(unreduced_loss) - output = dict(loss=loss, - unreduced_loss=unreduced_loss.detach(), - sigmas=sigmas, - predicted_normalized_scores=predicted_normalized_scores.detach(), - target_normalized_conditional_scores=target_normalized_conditional_scores) + output = dict( + loss=loss, + unreduced_loss=unreduced_loss.detach(), + sigmas=sigmas, + predicted_normalized_scores=predicted_normalized_scores.detach(), + target_normalized_conditional_scores=target_normalized_conditional_scores, + ) output[RELATIVE_COORDINATES] = x0 output[NOISY_RELATIVE_COORDINATES] = xt @@ -217,8 +259,9 @@ def _get_target_normalized_score( target normalized score: sigma times target score, ie, sigma times nabla_xt log P_{t|0}(xt| x0). Tensor of dimensions [batch_size, number_of_atoms, spatial_dimension] """ - delta_relative_coordinates = map_relative_coordinates_to_unit_cell(noisy_relative_coordinates - - real_relative_coordinates) + delta_relative_coordinates = map_relative_coordinates_to_unit_cell( + noisy_relative_coordinates - real_relative_coordinates + ) target_normalized_scores = get_sigma_normalized_score( delta_relative_coordinates, sigmas, kmax=self.hyper_params.kmax_target_score ) @@ -235,7 +278,13 @@ def training_step(self, batch, batch_idx): self.log("train_step_loss", loss, on_step=True, on_epoch=False, prog_bar=True) # The 'train_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. 
- self.log("train_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True) + self.log( + "train_epoch_loss", + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) return output def validation_step(self, batch, batch_idx): @@ -245,8 +294,29 @@ def validation_step(self, batch, batch_idx): batch_size = self._get_batch_size(batch) # The 'validation_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. - self.log("validation_epoch_loss", loss, - batch_size=batch_size, on_step=False, on_epoch=True, prog_bar=True) + self.log( + "validation_epoch_loss", + loss, + batch_size=batch_size, + on_step=False, + on_epoch=True, + prog_bar=True, + ) + + if self.draw_samples: + basis_vectors = torch.diag_embed(batch["box"]) + cartesian_positions = get_positions_from_coordinates( + relative_coordinates=batch[RELATIVE_COORDINATES], + basis_vectors=basis_vectors, + ) + + distances = compute_distances_in_batch( + cartesian_positions=cartesian_positions, + unit_cell=basis_vectors, + max_distance=self.max_distance, + ) + self.structure_ks_metric.register_reference_samples(distances) + return output def test_step(self, batch, batch_idx): @@ -255,5 +325,59 @@ def test_step(self, batch, batch_idx): loss = output["loss"] batch_size = self._get_batch_size(batch) # The 'test_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. - self.log("test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True) + self.log( + "test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True + ) return output + + def generate_samples(self): + """Generate a batch of samples.""" + assert ( + self.hyper_params.sampling_parameters is not None + ), "sampling parameters must be provided to create a generator." 
+ logger.info("Creating Langevin Generator for sampling") + + with torch.no_grad(): + generator = LangevinGenerator( + noise_parameters=self.hyper_params.noise_parameters, + sampling_parameters=self.hyper_params.sampling_parameters, + sigma_normalized_score_network=self.sigma_normalized_score_network, + ) + + logger.info("Draw samples") + samples_batch = create_batch_of_samples( + generator=generator, + sampling_parameters=self.hyper_params.sampling_parameters, + device=self.device, + ) + return samples_batch + + def on_validation_epoch_end(self) -> None: + """On validation epoch end.""" + if not self.draw_samples: + return + + samples_batch = self.generate_samples() + sample_distances = compute_distances_in_batch( + cartesian_positions=samples_batch[CARTESIAN_POSITIONS], + unit_cell=samples_batch[UNIT_CELL], + max_distance=self.max_distance, + ) + + self.structure_ks_metric.register_predicted_samples(sample_distances) + + ( + ks_distance, + p_value, + ) = self.structure_ks_metric.compute_kolmogorov_smirnov_distance_and_pvalue() + self.structure_ks_metric.reset() + + self.log( + "validation_ks_distance_structure", + ks_distance, + on_step=False, + on_epoch=True, + ) + self.log( + "validation_ks_p_value_structure", p_value, on_step=False, on_epoch=True + ) diff --git a/crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py b/crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py new file mode 100644 index 00000000..4c0b91e9 --- /dev/null +++ b/crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py @@ -0,0 +1,58 @@ +from typing import Tuple + +import scipy.stats as ss +from torchmetrics import CatMetric + + +class KolmogorovSmirnovMetrics: + """Kolmogorov Smirnov metrics.""" + + def __init__(self): + """Init method.""" + self._reference_samples_metric = CatMetric() + self._predicted_samples_metric = CatMetric() + + def register_reference_samples(self, reference_samples): + """Register reference samples.""" + 
self._reference_samples_metric.update(reference_samples)
+
+    def register_predicted_samples(self, predicted_samples):
+        """Register predicted samples."""
+        self._predicted_samples_metric.update(predicted_samples)
+
+    def reset(self):
+        """Reset the registered reference and predicted samples."""
+        self._reference_samples_metric.reset()
+        self._predicted_samples_metric.reset()
+
+    def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]:
+        """Compute Kolmogorov Smirnov Distance.
+
+        Compute the two sample Kolmogorov–Smirnov test in order to gauge whether the
+        predicted samples were drawn from the same distribution as the reference samples.
+
+        Note:
+            Operates on the samples previously registered through the
+            'register_reference_samples' and 'register_predicted_samples' methods.
+
+        Returns:
+            ks_distance, p_value: the Kolmogorov-Smirnov test statistic (a "distance")
+            and the statistical test's p-value.
+        """
+        reference_samples = self._reference_samples_metric.compute()
+        predicted_samples = self._predicted_samples_metric.compute()
+
+        test_result = ss.ks_2samp(predicted_samples.detach().cpu().numpy(),
+                                  reference_samples.detach().cpu().numpy(),
+                                  alternative='two-sided', method='auto')
+
+        # The "test statistic" of the two-sided KS test is the largest vertical distance between
+        # the empirical CDFs of the two samples. The larger this is, the less likely the two
+        # samples were drawn from the same underlying distribution, hence the idea of 'distance'.
+        ks_distance = test_result.statistic
+
+        # The null hypothesis of the KS test is that both samples are drawn from the same distribution.
+        # Thus, a small p-value (which leads to the rejection of the null hypothesis) indicates that
+        # the samples probably come from different distributions (ie, our samples are bad!).
+ p_value = test_result.pvalue + return ks_distance, p_value diff --git a/crystal_diffusion/utils/structure_utils.py b/crystal_diffusion/utils/structure_utils.py index aa375491..125d6f9c 100644 --- a/crystal_diffusion/utils/structure_utils.py +++ b/crystal_diffusion/utils/structure_utils.py @@ -7,7 +7,7 @@ from crystal_diffusion.utils.neighbors import ( _get_relative_coordinates_lattice_vectors, _get_shifted_positions, - get_positions_from_coordinates) + get_periodic_adjacency_information, get_positions_from_coordinates) def create_structure(basis_vectors: np.ndarray, relative_coordinates: np.ndarray, species: List[str]) -> Structure: @@ -106,3 +106,21 @@ def get_orthogonal_basis_vectors(batch_size: int, cell_dimensions: List[float]) """ basis_vectors = torch.diag(torch.Tensor(cell_dimensions)).unsqueeze(0).repeat(batch_size, 1, 1) return basis_vectors + + +def compute_distances(cartesian_positions: torch.Tensor, basis_vectors: torch.Tensor, max_distance: float): + """Compute distances.""" + adj_info = get_periodic_adjacency_information(cartesian_positions, basis_vectors, radial_cutoff=max_distance) + + # The following are 1D arrays of length equal to the total number of neighbors for all batch elements + # and all atoms. + # bch: which batch does an edge belong to + # src: at which atom does an edge start + # dst: at which atom does an edge end + bch = adj_info.edge_batch_indices + src, dst = adj_info.adjacency_matrix + + cartesian_displacements = cartesian_positions[bch, dst] - cartesian_positions[bch, src] + adj_info.shifts + distances = torch.linalg.norm(cartesian_displacements, dim=-1) + # Identify neighbors within the radial_cutoff, but avoiding self. 
+ return distances[distances > 0.0] diff --git a/tests/models/test_position_diffusion_lightning_model.py b/tests/models/test_position_diffusion_lightning_model.py index 999ef3e8..6905bb26 100644 --- a/tests/models/test_position_diffusion_lightning_model.py +++ b/tests/models/test_position_diffusion_lightning_model.py @@ -3,6 +3,8 @@ from pytorch_lightning import LightningDataModule, Trainer from torch.utils.data import DataLoader, random_split +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters from crystal_diffusion.models.loss import create_loss_parameters from crystal_diffusion.models.optimizer import OptimizerParameters from crystal_diffusion.models.position_diffusion_lightning_model import ( @@ -68,7 +70,7 @@ def number_of_atoms(self): @pytest.fixture() def unit_cell_size(self): - return 10 + return 10.1 @pytest.fixture(params=['adam', 'adamw']) def optimizer_parameters(self, request): @@ -92,9 +94,25 @@ def scheduler_parameters(self, request): def loss_parameters(self, request): return create_loss_parameters(model_dictionary=dict(algorithm=request.param)) + @pytest.fixture() + def number_of_samples(self): + return 12 + + @pytest.fixture() + def cell_dimensions(self, unit_cell_size, spatial_dimension): + return spatial_dimension * [unit_cell_size] + + @pytest.fixture() + def sampling_parameters(self, number_of_atoms, spatial_dimension, number_of_samples, cell_dimensions): + sampling_parameters = PredictorCorrectorSamplingParameters(number_of_atoms=number_of_atoms, + spatial_dimension=spatial_dimension, + number_of_samples=number_of_samples, + cell_dimensions=cell_dimensions) + return sampling_parameters + @pytest.fixture() def hyper_params(self, number_of_atoms, spatial_dimension, - optimizer_parameters, scheduler_parameters, loss_parameters): + optimizer_parameters, scheduler_parameters, loss_parameters, sampling_parameters): score_network_parameters = MLPScoreNetworkParameters( 
number_of_atoms=number_of_atoms, n_hidden_dimensions=3, @@ -110,7 +128,8 @@ def hyper_params(self, number_of_atoms, spatial_dimension, optimizer_parameters=optimizer_parameters, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, - loss_parameters=loss_parameters + loss_parameters=loss_parameters, + sampling_parameters=sampling_parameters ) return hyper_params @@ -213,3 +232,7 @@ def test_smoke_test(self, lightning_model, fake_datamodule, accelerator): trainer = Trainer(fast_dev_run=3, accelerator=accelerator) trainer.fit(lightning_model, fake_datamodule) trainer.test(lightning_model, fake_datamodule) + + def test_generate_sample(self, lightning_model, number_of_samples, number_of_atoms, spatial_dimension): + samples_batch = lightning_model.generate_samples() + assert samples_batch[RELATIVE_COORDINATES].shape == (number_of_samples, number_of_atoms, spatial_dimension) diff --git a/tests/utils/test_structure_utils.py b/tests/utils/test_structure_utils.py index 4fd9b0d9..f8ef694a 100644 --- a/tests/utils/test_structure_utils.py +++ b/tests/utils/test_structure_utils.py @@ -1,18 +1,65 @@ +import pytest import torch -from crystal_diffusion.utils.structure_utils import \ - get_orthogonal_basis_vectors +from crystal_diffusion.utils.basis_transformations import \ + get_positions_from_coordinates +from crystal_diffusion.utils.structure_utils import ( + compute_distances, compute_distances_in_batch, + get_orthogonal_basis_vectors) -def test_get_orthogonal_basis_vectors(): +@pytest.fixture() +def spatial_dimension(): + return 3 - cell_dimensions = [12.34, 8.32, 7.12] - batch_size = 16 - computed_basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) +@pytest.fixture() +def cell_dimensions(spatial_dimension): + return list(7.5 + 2.5 * torch.rand(spatial_dimension).numpy()) + + +@pytest.fixture() +def batch_size(): + return 16 + + +@pytest.fixture() +def number_of_atoms(): + return 12 + + +@pytest.fixture() +def 
relative_coordinates(batch_size, number_of_atoms, spatial_dimension): + return torch.rand(batch_size, number_of_atoms, spatial_dimension) + +def test_get_orthogonal_basis_vectors(batch_size, cell_dimensions): + computed_basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) expected_basis_vectors = torch.zeros_like(computed_basis_vectors) for d, acell in enumerate(cell_dimensions): expected_basis_vectors[:, d, d] = acell torch.testing.assert_allclose(computed_basis_vectors, expected_basis_vectors) + + +def test_compute_distances(batch_size, cell_dimensions, relative_coordinates): + max_distance = min(cell_dimensions) - 0.5 + basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) + + cartesian_positions = get_positions_from_coordinates( + relative_coordinates=relative_coordinates, basis_vectors=basis_vectors + ) + + distances = compute_distances( + cartesian_positions=cartesian_positions, + basis_vectors=basis_vectors, + max_distance=float(max_distance), + ) + + alt_distances = compute_distances_in_batch( + cartesian_positions=cartesian_positions, + unit_cell=basis_vectors, + max_distance=float(max_distance), + ) + + torch.testing.assert_allclose(distances, alt_distances) From 2e9ecacff3b118e4f45086af360c1100a559cfd3 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 18:06:56 -0400 Subject: [PATCH 07/20] using sampling parameters. 
--- crystal_diffusion/models/model_loader.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/models/model_loader.py b/crystal_diffusion/models/model_loader.py index 66433392..5a5ed822 100644 --- a/crystal_diffusion/models/model_loader.py +++ b/crystal_diffusion/models/model_loader.py @@ -2,6 +2,8 @@ import logging from typing import Any, AnyStr, Dict +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters from crystal_diffusion.models.loss import create_loss_parameters from crystal_diffusion.models.optimizer import create_optimizer_parameters from crystal_diffusion.models.position_diffusion_lightning_model import ( @@ -37,15 +39,22 @@ def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLi model_dict = hyper_params['model'] loss_parameters = create_loss_parameters(model_dict) - noise_dict = hyper_params['model']['noise'] + noise_dict = model_dict['noise'] noise_parameters = NoiseParameters(**noise_dict) + if 'sampling' in model_dict: + sampling_dict = model_dict['sampling'] + sampling_parameters = PredictorCorrectorSamplingParameters(**sampling_dict) + else: + sampling_parameters = None + diffusion_params = PositionDiffusionParameters( score_network_parameters=score_network_parameters, loss_parameters=loss_parameters, optimizer_parameters=optimizer_parameters, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, + sampling_parameters=sampling_parameters ) model = PositionDiffusionLightningModel(diffusion_params) From ac527b08bf3920ae13fa2f88ad7e30e9ab5133da Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 18:09:04 -0400 Subject: [PATCH 08/20] using sampling parameters. 
--- .../models/position_diffusion_lightning_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 8f3607e6..5ac388b7 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -83,10 +83,10 @@ def __init__(self, hyper_params: PositionDiffusionParameters): self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) if self.hyper_params.sampling_parameters is not None: + assert self.hyper_params.sampling_parameters.compute_structure_factor, \ + "compute_structure_factor should be True. Config is now inconsistent." self.draw_samples = True - self.max_distance = ( - min(self.hyper_params.sampling_parameters.cell_dimensions) - 0.1 - ) + self.max_distance = self.hyper_params.sampling_parameters.structure_factor_max_distance self.structure_ks_metric = KolmogorovSmirnovMetrics() else: self.draw_samples = False From edb1083580e8324a67270dc3fac3206306c4ecb2 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 18:13:33 -0400 Subject: [PATCH 09/20] Fix test bjork. 
--- tests/utils/test_structure_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/utils/test_structure_utils.py b/tests/utils/test_structure_utils.py index f8ef694a..1f366968 100644 --- a/tests/utils/test_structure_utils.py +++ b/tests/utils/test_structure_utils.py @@ -15,7 +15,10 @@ def spatial_dimension(): @pytest.fixture() def cell_dimensions(spatial_dimension): - return list(7.5 + 2.5 * torch.rand(spatial_dimension).numpy()) + values = [] + for v in list(7.5 + 2.5 * torch.rand(spatial_dimension).numpy()): + values.append(float(v)) + return values @pytest.fixture() From 16c2eaee25a92c9c48c60bbe72ae5fd2490d51a1 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Fri, 20 Sep 2024 18:17:09 -0400 Subject: [PATCH 10/20] Fix test bjork. --- tests/models/test_position_diffusion_lightning_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_position_diffusion_lightning_model.py b/tests/models/test_position_diffusion_lightning_model.py index 6905bb26..257c8b22 100644 --- a/tests/models/test_position_diffusion_lightning_model.py +++ b/tests/models/test_position_diffusion_lightning_model.py @@ -107,6 +107,8 @@ def sampling_parameters(self, number_of_atoms, spatial_dimension, number_of_samp sampling_parameters = PredictorCorrectorSamplingParameters(number_of_atoms=number_of_atoms, spatial_dimension=spatial_dimension, number_of_samples=number_of_samples, + compute_structure_factor=True, + structure_factor_max_distance=min(cell_dimensions), cell_dimensions=cell_dimensions) return sampling_parameters From 320818a301edcc7b0f002f76e6b2fb49663a88f0 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 21 Sep 2024 11:12:51 -0400 Subject: [PATCH 11/20] rename folder for clarity. 
--- crystal_diffusion/models/position_diffusion_lightning_model.py | 2 +- crystal_diffusion/samples_and_metrics/__init__.py | 0 .../kolmogorov_smirnov_metrics.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 crystal_diffusion/samples_and_metrics/__init__.py rename crystal_diffusion/{sampling_metrics => samples_and_metrics}/kolmogorov_smirnov_metrics.py (100%) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 5ac388b7..38b5495f 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -27,7 +27,7 @@ NoisyRelativeCoordinatesSampler from crystal_diffusion.samplers.variance_sampler import ( ExplodingVarianceSampler, NoiseParameters) -from crystal_diffusion.sampling_metrics.kolmogorov_smirnov_metrics import \ +from crystal_diffusion.samples_and_metrics.kolmogorov_smirnov_metrics import \ KolmogorovSmirnovMetrics from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score diff --git a/crystal_diffusion/samples_and_metrics/__init__.py b/crystal_diffusion/samples_and_metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py b/crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py similarity index 100% rename from crystal_diffusion/sampling_metrics/kolmogorov_smirnov_metrics.py rename to crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py From 2f36962cc352a2517d174b9691ed3959868051d6 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 21 Sep 2024 18:34:23 -0400 Subject: [PATCH 12/20] Major Refactor. This refactor lets the lightning module compute the various ks metrics, which used to be computed in a callback. The callback now just records to disck / plots the results. 
The configuration has also been modified to reflect these changes. --- .../generate_sample_energies.py | 2 +- .../perfect_score_loss_analysis.py | 2 +- .../callbacks/analysis_callbacks.py | 8 +- .../callbacks/callback_loader.py | 6 +- .../generators/instantiate_generator.py | 36 +++++ .../generators/load_sampling_parameters.py | 39 +++++ .../generators/position_generator.py | 7 +- ...ader.py => instantiate_diffusion_model.py} | 12 +- .../position_diffusion_lightning_model.py | 135 +++++++++++------ crystal_diffusion/oracle/energies.py | 53 +++++++ .../diffusion_sampling_parameters.py | 51 +++++++ .../kolmogorov_smirnov_metrics.py | 16 +-- .../sampling.py | 0 .../sampling_metrics_parameters.py | 13 ++ crystal_diffusion/train_diffusion.py | 7 +- .../diffusion/config_diffusion_mlp.yaml | 54 ++++--- examples/local/diffusion/run_diffusion.sh | 4 +- .../energy_consistency_analysis.py | 6 +- .../sampling_si_diffusion.py | 3 +- .../repaint_with_sota_score.py | 3 +- .../sota_score_sampling_and_plotting.py | 3 +- tests/callbacks/test_sampling_callback.py | 136 ------------------ ...test_position_diffusion_lightning_model.py | 21 ++- tests/samples_and_metrics/__init__.py | 0 .../test_sampling.py | 3 +- tests/test_train_diffusion.py | 7 +- 26 files changed, 378 insertions(+), 249 deletions(-) create mode 100644 crystal_diffusion/generators/instantiate_generator.py create mode 100644 crystal_diffusion/generators/load_sampling_parameters.py rename crystal_diffusion/models/{model_loader.py => instantiate_diffusion_model.py} (85%) create mode 100644 crystal_diffusion/oracle/energies.py create mode 100644 crystal_diffusion/samples_and_metrics/diffusion_sampling_parameters.py rename crystal_diffusion/{generators => samples_and_metrics}/sampling.py (100%) create mode 100644 crystal_diffusion/samples_and_metrics/sampling_metrics_parameters.py delete mode 100644 tests/callbacks/test_sampling_callback.py create mode 100644 tests/samples_and_metrics/__init__.py rename tests/{generators 
=> samples_and_metrics}/test_sampling.py (96%) diff --git a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py index 1979e460..02340d4d 100644 --- a/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py +++ b/crystal_diffusion/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py @@ -11,7 +11,7 @@ LANGEVIN_EXPLORATION_DIRECTORY from crystal_diffusion.analysis.analytic_score.utils import ( get_exact_samples, get_silicon_supercell) -from crystal_diffusion.callbacks.sampling_callback import logger +from crystal_diffusion.callbacks.sampling_visualization_callback import logger from crystal_diffusion.generators.langevin_generator import LangevinGenerator from crystal_diffusion.generators.predictor_corrector_position_generator import \ PredictorCorrectorSamplingParameters diff --git a/crystal_diffusion/analysis/analytic_score/perfect_score_loss_analysis.py b/crystal_diffusion/analysis/analytic_score/perfect_score_loss_analysis.py index cfcbe484..182ac5c4 100644 --- a/crystal_diffusion/analysis/analytic_score/perfect_score_loss_analysis.py +++ b/crystal_diffusion/analysis/analytic_score/perfect_score_loss_analysis.py @@ -16,7 +16,7 @@ get_exact_samples, get_silicon_supercell) from crystal_diffusion.callbacks.loss_monitoring_callback import \ LossMonitoringCallback -from crystal_diffusion.callbacks.sampling_callback import \ +from crystal_diffusion.callbacks.sampling_visualization_callback import \ PredictorCorrectorDiffusionSamplingCallback from crystal_diffusion.generators.predictor_corrector_position_generator import \ PredictorCorrectorSamplingParameters diff --git a/crystal_diffusion/callbacks/analysis_callbacks.py b/crystal_diffusion/callbacks/analysis_callbacks.py index f0550329..7785ec3d 100644 --- 
a/crystal_diffusion/callbacks/analysis_callbacks.py +++ b/crystal_diffusion/callbacks/analysis_callbacks.py @@ -12,8 +12,8 @@ from crystal_diffusion.analysis import PLOT_STYLE_PATH from crystal_diffusion.analysis.analytic_score.utils import \ get_relative_harmonic_energy -from crystal_diffusion.callbacks.sampling_callback import \ - DiffusionSamplingCallback +from crystal_diffusion.callbacks.sampling_visualization_callback import \ + SamplingVisualizationCallback from crystal_diffusion.generators.position_generator import SamplingParameters from crystal_diffusion.samplers.variance_sampler import NoiseParameters @@ -22,7 +22,7 @@ plt.style.use(PLOT_STYLE_PATH) -class HarmonicEnergyDiffusionSamplingCallback(DiffusionSamplingCallback): +class HarmonicEnergyDiffusionSamplingCallback(SamplingVisualizationCallback): """Callback class to periodically generate samples and log their energies.""" def __init__(self, noise_parameters: NoiseParameters, @@ -54,7 +54,7 @@ def _compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> @staticmethod def _plot_energy_histogram(sample_energies: np.ndarray, validation_dataset_energies: np.array, epoch: int) -> plt.figure: - fig = DiffusionSamplingCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch) + fig = SamplingVisualizationCallback._plot_energy_histogram(sample_energies, validation_dataset_energies, epoch) fig.suptitle(f'Sampling Unitless Harmonic Potential Energy Distributions\nEpoch {epoch}') ax1 = fig.axes[0] diff --git a/crystal_diffusion/callbacks/callback_loader.py b/crystal_diffusion/callbacks/callback_loader.py index 2e1f3868..977e26b5 100644 --- a/crystal_diffusion/callbacks/callback_loader.py +++ b/crystal_diffusion/callbacks/callback_loader.py @@ -5,15 +5,15 @@ from crystal_diffusion.callbacks.loss_monitoring_callback import \ instantiate_loss_monitoring_callback -from crystal_diffusion.callbacks.sampling_callback import \ - instantiate_diffusion_sampling_callback +from 
crystal_diffusion.callbacks.sampling_visualization_callback import \ + instantiate_sampling_visualization_callback from crystal_diffusion.callbacks.standard_callbacks import ( CustomProgressBar, instantiate_early_stopping_callback, instantiate_model_checkpoint_callbacks) OPTIONAL_CALLBACK_DICTIONARY = dict(early_stopping=instantiate_early_stopping_callback, model_checkpoint=instantiate_model_checkpoint_callbacks, - diffusion_sampling=instantiate_diffusion_sampling_callback, + sampling_visualization=instantiate_sampling_visualization_callback, loss_monitoring=instantiate_loss_monitoring_callback) diff --git a/crystal_diffusion/generators/instantiate_generator.py b/crystal_diffusion/generators/instantiate_generator.py new file mode 100644 index 00000000..ac6277bb --- /dev/null +++ b/crystal_diffusion/generators/instantiate_generator.py @@ -0,0 +1,36 @@ +from crystal_diffusion.generators.langevin_generator import LangevinGenerator +from crystal_diffusion.generators.ode_position_generator import \ + ExplodingVarianceODEPositionGenerator +from crystal_diffusion.generators.position_generator import SamplingParameters +from crystal_diffusion.generators.sde_position_generator import \ + ExplodingVarianceSDEPositionGenerator +from crystal_diffusion.models.score_networks import ScoreNetwork +from crystal_diffusion.samplers.variance_sampler import NoiseParameters + + +def instantiate_generator(sampling_parameters: SamplingParameters, + noise_parameters: NoiseParameters, + sigma_normalized_score_network: ScoreNetwork): + """Instantiate generator.""" + assert sampling_parameters.algorithm in ['ode', 'sde', 'predictor_corrector'], \ + "Unknown algorithm. 
Possible choices are 'ode', 'sde' and 'predictor_corrector'" + + match sampling_parameters.algorithm: + case 'predictor_corrector': + generator = LangevinGenerator(sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network) + case 'ode': + generator = ExplodingVarianceODEPositionGenerator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network) + case 'sde': + generator = ExplodingVarianceSDEPositionGenerator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network) + case _: + raise NotImplementedError(f"algorithm '{sampling_parameters.algorithm}' is not implemented") + + return generator diff --git a/crystal_diffusion/generators/load_sampling_parameters.py b/crystal_diffusion/generators/load_sampling_parameters.py new file mode 100644 index 00000000..21ce3c21 --- /dev/null +++ b/crystal_diffusion/generators/load_sampling_parameters.py @@ -0,0 +1,39 @@ +from typing import Any, AnyStr, Dict + +from crystal_diffusion.generators.ode_position_generator import \ + ODESamplingParameters +from crystal_diffusion.generators.position_generator import SamplingParameters +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.generators.sde_position_generator import \ + SDESamplingParameters + + +def load_sampling_parameters(sampling_parameter_dictionary: Dict[AnyStr, Any]) -> SamplingParameters: + """Load sampling parameters. + + Extract the needed information from the configuration dictionary. + + Args: + sampling_parameter_dictionary: dictionary of hyperparameters loaded from a config file + + Returns: + sampling_parameters: the relevant configuration object. 
+ """ + assert 'algorithm' in sampling_parameter_dictionary, "The sampling parameters must select an algorithm." + algorithm = sampling_parameter_dictionary['algorithm'] + + assert algorithm in ['ode', 'sde', 'predictor_corrector'], \ + "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'" + + match algorithm: + case 'predictor_corrector': + sampling_parameters = PredictorCorrectorSamplingParameters(**sampling_parameter_dictionary) + case 'ode': + sampling_parameters = ODESamplingParameters(**sampling_parameter_dictionary) + case 'sde': + sampling_parameters = SDESamplingParameters(**sampling_parameter_dictionary) + case _: + raise NotImplementedError(f"algorithm '{algorithm}' is not implemented") + + return sampling_parameters diff --git a/crystal_diffusion/generators/position_generator.py b/crystal_diffusion/generators/position_generator.py index 9cf04b8a..6ae2dd39 100644 --- a/crystal_diffusion/generators/position_generator.py +++ b/crystal_diffusion/generators/position_generator.py @@ -12,14 +12,11 @@ class SamplingParameters: spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. number_of_atoms: int # the number of atoms that must be generated in a sampled configuration. number_of_samples: int - sample_batchsize: Optional[int] = None # iterate up to number_of_samples with batches of this size + # iterate up to number_of_samples with batches of this size # if None, use number_of_samples as batchsize - sample_every_n_epochs: int = 1 # Sampling is expensive; control frequency - first_sampling_epoch: int = 1 # Epoch at which sampling can begin; no sampling before this epoch. + sample_batchsize: Optional[int] = None cell_dimensions: List[float] # unit cell dimensions; the unit cell is assumed to be an orthogonal box. 
record_samples: bool = False # should the predictor and corrector steps be recorded to a file - compute_structure_factor: bool = False # should the structure factor (distances distribution) be recorded - structure_factor_max_distance: float = 10.0 # cutoff for the structure factor class PositionGenerator(ABC): diff --git a/crystal_diffusion/models/model_loader.py b/crystal_diffusion/models/instantiate_diffusion_model.py similarity index 85% rename from crystal_diffusion/models/model_loader.py rename to crystal_diffusion/models/instantiate_diffusion_model.py index 5a5ed822..db08248e 100644 --- a/crystal_diffusion/models/model_loader.py +++ b/crystal_diffusion/models/instantiate_diffusion_model.py @@ -2,8 +2,6 @@ import logging from typing import Any, AnyStr, Dict -from crystal_diffusion.generators.predictor_corrector_position_generator import \ - PredictorCorrectorSamplingParameters from crystal_diffusion.models.loss import create_loss_parameters from crystal_diffusion.models.optimizer import create_optimizer_parameters from crystal_diffusion.models.position_diffusion_lightning_model import ( @@ -12,6 +10,8 @@ from crystal_diffusion.models.score_networks.score_network_factory import \ create_score_network_parameters from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples_and_metrics.diffusion_sampling_parameters import \ + load_diffusion_sampling_parameters logger = logging.getLogger(__name__) @@ -42,11 +42,7 @@ def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLi noise_dict = model_dict['noise'] noise_parameters = NoiseParameters(**noise_dict) - if 'sampling' in model_dict: - sampling_dict = model_dict['sampling'] - sampling_parameters = PredictorCorrectorSamplingParameters(**sampling_dict) - else: - sampling_parameters = None + diffusion_sampling_parameters = load_diffusion_sampling_parameters(hyper_params) diffusion_params = PositionDiffusionParameters( 
score_network_parameters=score_network_parameters, @@ -54,7 +50,7 @@ def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLi optimizer_parameters=optimizer_parameters, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters + diffusion_sampling_parameters=diffusion_sampling_parameters ) model = PositionDiffusionLightningModel(diffusion_params) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 38b5495f..34364e30 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -6,10 +6,8 @@ import pytorch_lightning as pl import torch -from crystal_diffusion.generators.langevin_generator import LangevinGenerator -from crystal_diffusion.generators.predictor_corrector_position_generator import \ - PredictorCorrectorSamplingParameters -from crystal_diffusion.generators.sampling import create_batch_of_samples +from crystal_diffusion.generators.instantiate_generator import \ + instantiate_generator from crystal_diffusion.models.loss import (LossParameters, create_loss_calculator) from crystal_diffusion.models.optimizer import (OptimizerParameters, @@ -23,12 +21,17 @@ from crystal_diffusion.namespace import (CARTESIAN_FORCES, CARTESIAN_POSITIONS, NOISE, NOISY_RELATIVE_COORDINATES, RELATIVE_COORDINATES, TIME, UNIT_CELL) +from crystal_diffusion.oracle.energies import compute_oracle_energies from crystal_diffusion.samplers.noisy_relative_coordinates_sampler import \ NoisyRelativeCoordinatesSampler from crystal_diffusion.samplers.variance_sampler import ( ExplodingVarianceSampler, NoiseParameters) +from crystal_diffusion.samples_and_metrics.diffusion_sampling_parameters import \ + DiffusionSamplingParameters from crystal_diffusion.samples_and_metrics.kolmogorov_smirnov_metrics import \ KolmogorovSmirnovMetrics 
+from crystal_diffusion.samples_and_metrics.sampling import \ + create_batch_of_samples from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score from crystal_diffusion.utils.basis_transformations import ( @@ -49,9 +52,9 @@ class PositionDiffusionParameters: optimizer_parameters: OptimizerParameters scheduler_parameters: Optional[SchedulerParameters] = None noise_parameters: NoiseParameters - sampling_parameters: Optional[PredictorCorrectorSamplingParameters] = None # convergence parameter for the Ewald-like sum of the perturbation kernel. kmax_target_score: int = 4 + diffusion_sampling_parameters: Optional[DiffusionSamplingParameters] = None class PositionDiffusionLightningModel(pl.LightningModule): @@ -82,14 +85,19 @@ def __init__(self, hyper_params: PositionDiffusionParameters): self.noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) - if self.hyper_params.sampling_parameters is not None: - assert self.hyper_params.sampling_parameters.compute_structure_factor, \ - "compute_structure_factor should be True. Config is now inconsistent." 
- self.draw_samples = True - self.max_distance = self.hyper_params.sampling_parameters.structure_factor_max_distance - self.structure_ks_metric = KolmogorovSmirnovMetrics() - else: - self.draw_samples = False + self.generator = None + self.structure_ks_metric = None + self.energy_ks_metric = None + + self.draw_samples = hyper_params.diffusion_sampling_parameters is not None + if self.draw_samples: + self.metrics_parameters = ( + self.hyper_params.diffusion_sampling_parameters.metrics_parameters + ) + if self.metrics_parameters.compute_structure_factor: + self.structure_ks_metric = KolmogorovSmirnovMetrics() + if self.metrics_parameters.compute_energies: + self.energy_ks_metric = KolmogorovSmirnovMetrics() def configure_optimizers(self): """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. @@ -303,19 +311,28 @@ def validation_step(self, batch, batch_idx): prog_bar=True, ) - if self.draw_samples: + if not self.draw_samples: + return output + + if self.metrics_parameters.compute_energies: + reference_energies = batch["potential_energy"] + self.energy_ks_metric.register_reference_samples(reference_energies.cpu()) + + if self.metrics_parameters.compute_structure_factor: basis_vectors = torch.diag_embed(batch["box"]) cartesian_positions = get_positions_from_coordinates( relative_coordinates=batch[RELATIVE_COORDINATES], basis_vectors=basis_vectors, ) - distances = compute_distances_in_batch( + reference_distances = compute_distances_in_batch( cartesian_positions=cartesian_positions, unit_cell=basis_vectors, - max_distance=self.max_distance, + max_distance=self.metrics_parameters.structure_factor_max_distance, + ) + self.structure_ks_metric.register_reference_samples( + reference_distances.cpu() ) - self.structure_ks_metric.register_reference_samples(distances) return output @@ -333,21 +350,21 @@ def test_step(self, batch, batch_idx): def generate_samples(self): """Generate a batch of samples.""" assert ( - 
self.hyper_params.sampling_parameters is not None + self.hyper_params.diffusion_sampling_parameters is not None ), "sampling parameters must be provided to create a generator." - logger.info("Creating Langevin Generator for sampling") - with torch.no_grad(): - generator = LangevinGenerator( + logger.info("Creating Generator for sampling") + self.generator = instantiate_generator( + sampling_parameters=self.hyper_params.diffusion_sampling_parameters.sampling_parameters, noise_parameters=self.hyper_params.noise_parameters, - sampling_parameters=self.hyper_params.sampling_parameters, sigma_normalized_score_network=self.sigma_normalized_score_network, ) + logger.info(f"Generator type : {type(self.generator)}") logger.info("Draw samples") samples_batch = create_batch_of_samples( - generator=generator, - sampling_parameters=self.hyper_params.sampling_parameters, + generator=self.generator, + sampling_parameters=self.hyper_params.diffusion_sampling_parameters.sampling_parameters, device=self.device, ) return samples_batch @@ -357,27 +374,57 @@ def on_validation_epoch_end(self) -> None: if not self.draw_samples: return + logger.info("Drawing samples at the end of the validation epoch.") samples_batch = self.generate_samples() - sample_distances = compute_distances_in_batch( - cartesian_positions=samples_batch[CARTESIAN_POSITIONS], - unit_cell=samples_batch[UNIT_CELL], - max_distance=self.max_distance, - ) - self.structure_ks_metric.register_predicted_samples(sample_distances) + if self.metrics_parameters.compute_energies: + sample_energies = compute_oracle_energies(samples_batch) + self.energy_ks_metric.register_predicted_samples(sample_energies.cpu()) + + ( + ks_distance, + p_value, + ) = self.energy_ks_metric.compute_kolmogorov_smirnov_distance_and_pvalue() + self.log( + "validation_ks_distance_energy", + ks_distance, + on_step=False, + on_epoch=True, + ) + self.log( + "validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True + ) + + if 
self.metrics_parameters.compute_structure_factor: + sample_distances = compute_distances_in_batch( + cartesian_positions=samples_batch[CARTESIAN_POSITIONS], + unit_cell=samples_batch[UNIT_CELL], + max_distance=self.metrics_parameters.structure_factor_max_distance, + ) + self.structure_ks_metric.register_predicted_samples(sample_distances.cpu()) - ( - ks_distance, - p_value, - ) = self.structure_ks_metric.compute_kolmogorov_smirnov_distance_and_pvalue() - self.structure_ks_metric.reset() + ( + ks_distance, + p_value, + ) = ( + self.structure_ks_metric.compute_kolmogorov_smirnov_distance_and_pvalue() + ) + self.log( + "validation_ks_distance_structure", + ks_distance, + on_step=False, + on_epoch=True, + ) + self.log( + "validation_ks_p_value_structure", p_value, on_step=False, on_epoch=True + ) - self.log( - "validation_ks_distance_structure", - ks_distance, - on_step=False, - on_epoch=True, - ) - self.log( - "validation_ks_p_value_structure", p_value, on_step=False, on_epoch=True - ) + def on_validation_start(self) -> None: + """On validation start.""" + # Clear out any dangling state. + self.generator = None + if self.metrics_parameters.compute_energies: + self.energy_ks_metric.reset() + + if self.metrics_parameters.compute_structure_factor: + self.structure_ks_metric.reset() diff --git a/crystal_diffusion/oracle/energies.py b/crystal_diffusion/oracle/energies.py new file mode 100644 index 00000000..29eccc7d --- /dev/null +++ b/crystal_diffusion/oracle/energies.py @@ -0,0 +1,53 @@ +import logging +import tempfile +from typing import AnyStr, Dict + +import numpy as np +import torch + +from crystal_diffusion.namespace import CARTESIAN_POSITIONS, UNIT_CELL +from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps + +logger = logging.getLogger(__name__) + + +def compute_oracle_energies(samples: Dict[AnyStr, torch.Tensor]) -> torch.Tensor: + """Compute oracle energies. + + Method to call the oracle for samples expressed in a standardized format. 
+ + Args: + samples: a dictionary assumed to contain the fields + - CARTESIAN_POSITIONS + - UNIT_CELL + + Returns: + energies: a numpy array with the computed energies. + """ + assert CARTESIAN_POSITIONS in samples, \ + f"the field '{CARTESIAN_POSITIONS}' must be present in the sample dictionary" + + assert UNIT_CELL in samples, \ + f"the field '{UNIT_CELL}' must be present in the sample dictionary" + + # Dimension [batch_size, space_dimension, space_dimension] + basis_vectors = samples[UNIT_CELL].detach().cpu().numpy() + + # Dimension [batch_size, number_of_atoms, space_dimension] + cartesian_positions = samples[CARTESIAN_POSITIONS].detach().cpu().numpy() + + number_of_atoms = cartesian_positions.shape[1] + atom_types = np.ones(number_of_atoms, dtype=int) + + logger.info("Compute energy from Oracle") + + list_energy = [] + with tempfile.TemporaryDirectory() as tmp_work_dir: + for positions, box in zip(cartesian_positions, basis_vectors): + energy, forces = get_energy_and_forces_from_lammps(positions, + box, + atom_types, + tmp_work_dir=tmp_work_dir) + list_energy.append(energy) + logger.info("Done computing energies from Oracle") + return torch.tensor(list_energy) diff --git a/crystal_diffusion/samples_and_metrics/diffusion_sampling_parameters.py b/crystal_diffusion/samples_and_metrics/diffusion_sampling_parameters.py new file mode 100644 index 00000000..27c41c16 --- /dev/null +++ b/crystal_diffusion/samples_and_metrics/diffusion_sampling_parameters.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Any, AnyStr, Dict, Union + +from crystal_diffusion.generators.load_sampling_parameters import \ + load_sampling_parameters +from crystal_diffusion.generators.position_generator import SamplingParameters +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples_and_metrics.sampling_metrics_parameters import \ + SamplingMetricsParameters + + +@dataclass(kw_only=True) +class 
DiffusionSamplingParameters: + """Diffusion sampling parameters. + + This dataclass holds various configuration objects that define how + samples should be generated and evaluated (ie, metrics) during training. + """ + sampling_parameters: SamplingParameters # Define the algorithm and parameters to draw samples. + noise_parameters: NoiseParameters # Noise for sampling, which can be different from training! + metrics_parameters: SamplingMetricsParameters # what should be done with the generated samples? + + +def load_diffusion_sampling_parameters(hyper_params: Dict[AnyStr, Any]) -> Union[DiffusionSamplingParameters, None]: + """Load diffusion sampling parameters. + + Extract the needed information from the configuration dictionary. + + Args: + hyper_params: dictionary of hyperparameters loaded from a config file + + Returns: + diffusion_sampling_parameters: the relevant configuration object. + """ + if 'diffusion_sampling' not in hyper_params: + return None + + diffusion_sampling_dict = hyper_params['diffusion_sampling'] + + assert 'sampling' in diffusion_sampling_dict, "The sampling parameters must be defined to draw samples." + sampling_parameters = load_sampling_parameters(diffusion_sampling_dict['sampling']) + + assert 'noise' in diffusion_sampling_dict, "The noise parameters must be defined to draw samples." + noise_parameters = NoiseParameters(**diffusion_sampling_dict['noise']) + + assert 'metrics' in diffusion_sampling_dict, "The metrics parameters must be defined to draw samples." 
+ metrics_parameters = SamplingMetricsParameters(**diffusion_sampling_dict['metrics']) + + return DiffusionSamplingParameters(sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + metrics_parameters=metrics_parameters) diff --git a/crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py b/crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py index 4c0b91e9..5141ff50 100644 --- a/crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py +++ b/crystal_diffusion/samples_and_metrics/kolmogorov_smirnov_metrics.py @@ -9,21 +9,21 @@ class KolmogorovSmirnovMetrics: def __init__(self): """Init method.""" - self._reference_samples_metric = CatMetric() - self._predicted_samples_metric = CatMetric() + self.reference_samples_metric = CatMetric() + self.predicted_samples_metric = CatMetric() def register_reference_samples(self, reference_samples): """Register reference samples.""" - self._reference_samples_metric.update(reference_samples) + self.reference_samples_metric.update(reference_samples) def register_predicted_samples(self, predicted_samples): """Register predicted samples.""" - self._predicted_samples_metric.update(predicted_samples) + self.predicted_samples_metric.update(predicted_samples) def reset(self): """reset.""" - self._reference_samples_metric.reset() - self._predicted_samples_metric.reset() + self.reference_samples_metric.reset() + self.predicted_samples_metric.reset() def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]: """Compute Kolmogorov Smirnov Distance. @@ -39,8 +39,8 @@ def compute_kolmogorov_smirnov_distance_and_pvalue(self) -> Tuple[float, float]: ks_distance, p_value: the Kolmogorov-Smirnov test statistic (a "distance") and the statistical test's p-value. 
""" - reference_samples = self._reference_samples_metric.compute() - predicted_samples = self._predicted_samples_metric.compute() + reference_samples = self.reference_samples_metric.compute() + predicted_samples = self.predicted_samples_metric.compute() test_result = ss.ks_2samp(predicted_samples.detach().cpu().numpy(), reference_samples.detach().cpu().numpy(), diff --git a/crystal_diffusion/generators/sampling.py b/crystal_diffusion/samples_and_metrics/sampling.py similarity index 100% rename from crystal_diffusion/generators/sampling.py rename to crystal_diffusion/samples_and_metrics/sampling.py diff --git a/crystal_diffusion/samples_and_metrics/sampling_metrics_parameters.py b/crystal_diffusion/samples_and_metrics/sampling_metrics_parameters.py new file mode 100644 index 00000000..86ad5642 --- /dev/null +++ b/crystal_diffusion/samples_and_metrics/sampling_metrics_parameters.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + + +@dataclass(kw_only=True) +class SamplingMetricsParameters: + """Sampling metrics parameters. + + This dataclass configures what metrics should be computed given that samples have + been generated. 
+ """ + compute_energies: bool = False # should the energies be computed + compute_structure_factor: bool = False # should the structure factor (distances distribution) be recorded + structure_factor_max_distance: float = 10.0 # cutoff for the structure factor diff --git a/crystal_diffusion/train_diffusion.py b/crystal_diffusion/train_diffusion.py index bbd89ad0..8d54cfcb 100644 --- a/crystal_diffusion/train_diffusion.py +++ b/crystal_diffusion/train_diffusion.py @@ -18,7 +18,8 @@ get_optimized_metric_name_and_mode, load_and_backup_hyperparameters, report_to_orion_if_on) -from crystal_diffusion.models.model_loader import load_diffusion_model +from crystal_diffusion.models.instantiate_diffusion_model import \ + load_diffusion_model from crystal_diffusion.utils.hp_utils import check_and_log_hp from crystal_diffusion.utils.logging_utils import (log_exp_details, setup_console_logger) @@ -181,6 +182,6 @@ def train(model, # Uncomment the following in order to use Pycharm's Remote Debugging server, which allows to # launch python commands through a bash script (and through Orion!). VERY useful for debugging. # This requires a professional edition of Pycharm and installing the pydevd_pycharm package with pip. 
- # import pydevd_pycharm - # pydevd_pycharm.settrace('localhost', port=50528, stdoutToServer=True, stderrToServer=True) + # import pydevd_pycharm + # pydevd_pycharm.settrace('localhost', port=56636, stdoutToServer=True, stderrToServer=True) main() diff --git a/examples/config_files/diffusion/config_diffusion_mlp.yaml b/examples/config_files/diffusion/config_diffusion_mlp.yaml index b4c737f6..31826b9d 100644 --- a/examples/config_files/diffusion/config_diffusion_mlp.yaml +++ b/examples/config_files/diffusion/config_diffusion_mlp.yaml @@ -1,7 +1,7 @@ # general exp_name: mlp_example -run_name: run2 -max_epoch: 500 +run_name: run1 +max_epoch: 10 log_every_n_steps: 1 gradient_clipping: 0 accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step @@ -25,15 +25,43 @@ model: architecture: mlp number_of_atoms: 8 n_hidden_dimensions: 2 - embedding_dimension_size: 16 + embedding_dimensions_size: 16 hidden_dimensions_size: 64 conditional_prob: 0.0 conditional_gamma: 2 condition_embedding_size: 64 noise: total_time_steps: 100 - sigma_min: 0.005 # default value - sigma_max: 0.5 # default value' + sigma_min: 0.0001 + sigma_max: 0.25 + + +# Sampling from the generative model +diffusion_sampling: + noise: + total_time_steps: 10 + sigma_min: 0.0001 + sigma_max: 0.1 + sampling: + algorithm: predictor_corrector + spatial_dimension: 3 + number_of_atoms: 8 + number_of_samples: 16 + sample_batchsize: 16 + record_samples: True + cell_dimensions: [5.43, 5.43, 5.43] + metrics: + compute_energies: True + compute_structure_factor: True + structure_factor_max_distance: 5.0 + + +sampling_visualization: + record_every_n_epochs: 1 + first_record_epoch: 0 + record_trajectories: True + record_energies: True + record_structure: True # optimizer and scheduler optimizer: @@ -61,22 +89,6 @@ loss_monitoring: number_of_bins: 50 sample_every_n_epochs: 25 -# Sampling from the generative model -diffusion_sampling: - noise: - total_time_steps: 100 - sigma_min: 0.001 # 
default value - sigma_max: 0.5 # default value - sampling: - algorithm: ode - spatial_dimension: 3 - number_of_atoms: 8 - number_of_samples: 16 - sample_batchsize: 16 - sample_every_n_epochs: 25 - record_samples: True - cell_dimensions: [5.43, 5.43, 5.43] - logging: # - comet - tensorboard diff --git a/examples/local/diffusion/run_diffusion.sh b/examples/local/diffusion/run_diffusion.sh index 47446dd8..ceec8b5f 100755 --- a/examples/local/diffusion/run_diffusion.sh +++ b/examples/local/diffusion/run_diffusion.sh @@ -3,7 +3,7 @@ # This example assumes that the dataset 'si_diffusion_small' is present locally in the DATA folder. # It is also assumed that the user has a Comet account for logging experiments. -CONFIG=../../config_files/diffusion/config_diffusion_egnn.yaml +CONFIG=../../config_files/diffusion/config_diffusion_mlp.yaml DATA_DIR=../../../data/si_diffusion_1x1x1 PROCESSED_DATA=${DATA_DIR}/processed DATA_WORK_DIR=${DATA_DIR}/cache/ @@ -15,4 +15,4 @@ python ../../../crystal_diffusion/train_diffusion.py \ --data $DATA_DIR \ --processed_datadir $PROCESSED_DATA \ --dataset_working_dir $DATA_WORK_DIR \ - --output $OUTPUT + --output $OUTPUT #> log.txt 2>&1 diff --git a/experiment_analysis/dataset_analysis/energy_consistency_analysis.py b/experiment_analysis/dataset_analysis/energy_consistency_analysis.py index 44ff518d..c3ed4afe 100644 --- a/experiment_analysis/dataset_analysis/energy_consistency_analysis.py +++ b/experiment_analysis/dataset_analysis/energy_consistency_analysis.py @@ -15,8 +15,8 @@ from crystal_diffusion import DATA_DIR from crystal_diffusion.analysis import PLOT_STYLE_PATH -from crystal_diffusion.callbacks.sampling_callback import ( - LOGGER_FIGSIZE, DiffusionSamplingCallback) +from crystal_diffusion.callbacks.sampling_visualization_callback import ( + LOGGER_FIGSIZE, SamplingVisualizationCallback) from crystal_diffusion.data.diffusion.data_loader import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from crystal_diffusion.oracle.lammps 
import get_energy_and_forces_from_lammps @@ -82,7 +82,7 @@ list_oracle_energies = np.array(list_oracle_energies) - fig = DiffusionSamplingCallback._plot_energy_histogram(list_oracle_energies, list_dataset_potential_energies) + fig = SamplingVisualizationCallback._plot_energy_histogram(list_oracle_energies, list_dataset_potential_energies) plt.show() fig2 = plt.figure(figsize=LOGGER_FIGSIZE) diff --git a/experiment_analysis/sampling_analysis/sampling_si_diffusion.py b/experiment_analysis/sampling_analysis/sampling_si_diffusion.py index 21997065..dd7d4518 100644 --- a/experiment_analysis/sampling_analysis/sampling_si_diffusion.py +++ b/experiment_analysis/sampling_analysis/sampling_si_diffusion.py @@ -16,7 +16,8 @@ from crystal_diffusion import DATA_DIR, TOP_DIR from crystal_diffusion.generators.langevin_generator import LangevinGenerator -from crystal_diffusion.models.model_loader import load_diffusion_model +from crystal_diffusion.models.instantiate_diffusion_model import \ + load_diffusion_model from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger diff --git a/experiment_analysis/sampling_sota_model/repaint_with_sota_score.py b/experiment_analysis/sampling_sota_model/repaint_with_sota_score.py index aaa8dc16..3eb0c564 100644 --- a/experiment_analysis/sampling_sota_model/repaint_with_sota_score.py +++ b/experiment_analysis/sampling_sota_model/repaint_with_sota_score.py @@ -11,7 +11,8 @@ from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH from crystal_diffusion.generators.constrained_langevin_generator import ( ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) -from crystal_diffusion.models.model_loader import load_diffusion_model +from crystal_diffusion.models.instantiate_diffusion_model import \ + load_diffusion_model from crystal_diffusion.oracle.lammps 
import get_energy_and_forces_from_lammps from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger diff --git a/experiment_analysis/sampling_sota_model/sota_score_sampling_and_plotting.py b/experiment_analysis/sampling_sota_model/sota_score_sampling_and_plotting.py index b4a9dc47..0039ece7 100644 --- a/experiment_analysis/sampling_sota_model/sota_score_sampling_and_plotting.py +++ b/experiment_analysis/sampling_sota_model/sota_score_sampling_and_plotting.py @@ -17,7 +17,8 @@ ExplodingVarianceODEPositionGenerator, ODESamplingParameters) from crystal_diffusion.generators.predictor_corrector_position_generator import \ PredictorCorrectorSamplingParameters -from crystal_diffusion.models.model_loader import load_diffusion_model +from crystal_diffusion.models.instantiate_diffusion_model import \ + load_diffusion_model from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps from crystal_diffusion.samplers.variance_sampler import NoiseParameters from crystal_diffusion.utils.logging_utils import setup_analysis_logger diff --git a/tests/callbacks/test_sampling_callback.py b/tests/callbacks/test_sampling_callback.py deleted file mode 100644 index a80f2096..00000000 --- a/tests/callbacks/test_sampling_callback.py +++ /dev/null @@ -1,136 +0,0 @@ -from unittest.mock import MagicMock - -import numpy as np -import pytest -import torch -from pytorch_lightning import LightningModule - -from crystal_diffusion.callbacks.sampling_callback import \ - DiffusionSamplingCallback -from crystal_diffusion.generators.ode_position_generator import \ - ODESamplingParameters -from crystal_diffusion.generators.predictor_corrector_position_generator import \ - PredictorCorrectorSamplingParameters -from crystal_diffusion.samplers.variance_sampler import NoiseParameters - - -@pytest.mark.parametrize("total_time_steps", [1]) -@pytest.mark.parametrize("time_delta", [0.1]) 
-@pytest.mark.parametrize("sigma_min", [0.15]) -@pytest.mark.parametrize("corrector_step_epsilon", [0.25]) -@pytest.mark.parametrize("number_of_samples", [8]) -@pytest.mark.parametrize("unit_cell_size", [10]) -@pytest.mark.parametrize("lammps_energy", [2]) -@pytest.mark.parametrize("spatial_dimension", [3]) -@pytest.mark.parametrize("number_of_atoms", [4]) -@pytest.mark.parametrize("sample_batchsize", [None, 8, 4]) -@pytest.mark.parametrize("record_samples", [True, False]) -class TestSamplingCallback: - - @pytest.fixture(params=['predictor_corrector', 'ode']) - def algorithm(self, request): - return request.param - - @pytest.fixture() - def number_of_corrector_steps(self, algorithm): - if algorithm == 'predictor_corrector': - return 1 - else: - return 0 - - @pytest.fixture() - def mock_create_generator(self, number_of_atoms, spatial_dimension): - generator = MagicMock() - - def side_effect(n, device, unit_cell): - return torch.rand(n, number_of_atoms, spatial_dimension) - - generator.sample.side_effect = side_effect - return generator - - @pytest.fixture() - def mock_create_create_unit_cell(self, number_of_samples): - unit_cell = np.arange(number_of_samples) # Dummy unit cell - return unit_cell - - @pytest.fixture() - def mock_create_create_unit_cell_torch(self, number_of_samples, spatial_dimension): - unit_cell = torch.diag_embed(torch.rand(number_of_samples, spatial_dimension)) * 3 # Dummy unit cell - return unit_cell - - @pytest.fixture() - def mock_compute_lammps_energies(self, lammps_energy): - return np.ones((1,)) * lammps_energy - - @pytest.fixture() - def noise_parameters(self, total_time_steps, time_delta, sigma_min, corrector_step_epsilon): - noise_parameters = NoiseParameters(total_time_steps=total_time_steps, - time_delta=time_delta, - sigma_min=sigma_min, - corrector_step_epsilon=corrector_step_epsilon) - return noise_parameters - - @pytest.fixture() - def sampling_parameters(self, algorithm, spatial_dimension, number_of_corrector_steps, - 
number_of_atoms, number_of_samples, sample_batchsize, unit_cell_size, record_samples): - if algorithm == 'predictor_corrector': - sampling_parameters = ( - PredictorCorrectorSamplingParameters(spatial_dimension=spatial_dimension, - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - sample_batchsize=sample_batchsize, - cell_dimensions=[unit_cell_size for _ in range(spatial_dimension)], - record_samples=record_samples)) - elif algorithm == 'ode': - sampling_parameters = ( - ODESamplingParameters(spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - sample_batchsize=sample_batchsize, - cell_dimensions=[unit_cell_size for _ in range(spatial_dimension)], - record_samples=record_samples)) - - else: - raise NotImplementedError - - return sampling_parameters - - @pytest.fixture() - def pl_model(self): - return MagicMock(spec=LightningModule) - - def test_sample_and_evaluate_energy(self, mocker, mock_compute_lammps_energies, mock_create_generator, - mock_create_create_unit_cell, noise_parameters, sampling_parameters, - pl_model, sample_batchsize, number_of_samples, tmpdir): - sampling_cb = DiffusionSamplingCallback( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=tmpdir) - mocker.patch.object(sampling_cb, "_create_generator", return_value=mock_create_generator) - mocker.patch.object(sampling_cb, "_create_unit_cell", return_value=mock_create_create_unit_cell) - mocker.patch.object(sampling_cb, "_compute_oracle_energies", return_value=mock_compute_lammps_energies) - - sample_energies, _ = sampling_cb.sample_and_evaluate_energy(pl_model) - assert isinstance(sample_energies, np.ndarray) - # each call of compute lammps energy yields a np.array of size 1 - expected_size = int(number_of_samples / sample_batchsize) if sample_batchsize is not None else 1 - assert sample_energies.shape[0] == 
expected_size - - def test_distances_calculation(self, mocker, mock_compute_lammps_energies, mock_create_generator, - mock_create_create_unit_cell_torch, noise_parameters, sampling_parameters, - pl_model, tmpdir): - sampling_parameters.structure_factor_max_distance = 5.0 - sampling_parameters.compute_structure_factor = True - - sampling_cb = DiffusionSamplingCallback( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=tmpdir) - mocker.patch.object(sampling_cb, "_create_generator", return_value=mock_create_generator) - mocker.patch.object(sampling_cb, "_create_unit_cell", return_value=mock_create_create_unit_cell_torch) - mocker.patch.object(sampling_cb, "_compute_oracle_energies", return_value=mock_compute_lammps_energies) - - _, sample_distances = sampling_cb.sample_and_evaluate_energy(pl_model) - assert isinstance(sample_distances, np.ndarray) - assert all(sample_distances > 0) diff --git a/tests/models/test_position_diffusion_lightning_model.py b/tests/models/test_position_diffusion_lightning_model.py index 257c8b22..0fbacc79 100644 --- a/tests/models/test_position_diffusion_lightning_model.py +++ b/tests/models/test_position_diffusion_lightning_model.py @@ -15,6 +15,10 @@ MLPScoreNetworkParameters from crystal_diffusion.namespace import CARTESIAN_FORCES, RELATIVE_COORDINATES from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples_and_metrics.diffusion_sampling_parameters import \ + DiffusionSamplingParameters +from crystal_diffusion.samples_and_metrics.sampling_metrics_parameters import \ + SamplingMetricsParameters from crystal_diffusion.score.wrapped_gaussian_score import \ get_sigma_normalized_score_brute_force from crystal_diffusion.utils.tensor_utils import \ @@ -107,14 +111,23 @@ def sampling_parameters(self, number_of_atoms, spatial_dimension, number_of_samp sampling_parameters = PredictorCorrectorSamplingParameters(number_of_atoms=number_of_atoms, 
spatial_dimension=spatial_dimension, number_of_samples=number_of_samples, - compute_structure_factor=True, - structure_factor_max_distance=min(cell_dimensions), cell_dimensions=cell_dimensions) return sampling_parameters + @pytest.fixture() + def diffusion_sampling_parameters(self, sampling_parameters): + noise_parameters = NoiseParameters(total_time_steps=5) + metrics_parameters = SamplingMetricsParameters(structure_factor_max_distance=1.) + diffusion_sampling_parameters = DiffusionSamplingParameters( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + metrics_parameters=metrics_parameters) + return diffusion_sampling_parameters + @pytest.fixture() def hyper_params(self, number_of_atoms, spatial_dimension, - optimizer_parameters, scheduler_parameters, loss_parameters, sampling_parameters): + optimizer_parameters, scheduler_parameters, + loss_parameters, sampling_parameters, diffusion_sampling_parameters): score_network_parameters = MLPScoreNetworkParameters( number_of_atoms=number_of_atoms, n_hidden_dimensions=3, @@ -131,7 +144,7 @@ def hyper_params(self, number_of_atoms, spatial_dimension, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, loss_parameters=loss_parameters, - sampling_parameters=sampling_parameters + diffusion_sampling_parameters=diffusion_sampling_parameters ) return hyper_params diff --git a/tests/samples_and_metrics/__init__.py b/tests/samples_and_metrics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/generators/test_sampling.py b/tests/samples_and_metrics/test_sampling.py similarity index 96% rename from tests/generators/test_sampling.py rename to tests/samples_and_metrics/test_sampling.py index 9d813170..e913ded3 100644 --- a/tests/generators/test_sampling.py +++ b/tests/samples_and_metrics/test_sampling.py @@ -4,9 +4,10 @@ from crystal_diffusion.generators.position_generator import ( PositionGenerator, SamplingParameters) -from 
crystal_diffusion.generators.sampling import create_batch_of_samples from crystal_diffusion.namespace import (CARTESIAN_POSITIONS, RELATIVE_COORDINATES, UNIT_CELL) +from crystal_diffusion.samples_and_metrics.sampling import \ + create_batch_of_samples from crystal_diffusion.utils.basis_transformations import \ get_positions_from_coordinates diff --git a/tests/test_train_diffusion.py b/tests/test_train_diffusion.py index 53a700a8..0df62750 100644 --- a/tests/test_train_diffusion.py +++ b/tests/test_train_diffusion.py @@ -102,7 +102,6 @@ def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_nam spatial_dimension=3, number_of_atoms=number_of_atoms, number_of_samples=4, - sample_every_n_epochs=1, record_samples=True, cell_dimensions=[10., 10., 10.]) if sampling_algorithm == 'predictor_corrector': @@ -110,7 +109,11 @@ def get_config(number_of_atoms: int, max_epoch: int, architecture: str, head_nam early_stopping_config = dict(metric='validation_epoch_loss', mode='min', patience=max_epoch) model_checkpoint_config = dict(monitor='validation_epoch_loss', mode='min') - diffusion_sampling_config = dict(noise={'total_time_steps': 10}, sampling=sampling_dict) + diffusion_sampling_config = dict(noise={'total_time_steps': 10}, + sampling=sampling_dict, + metrics={'compute_energies': False, + 'compute_structure_factor': True, + 'structure_factor_max_distance': 5.0}) config = dict(max_epoch=max_epoch, exp_name='smoke_test', From 14af5fa76c06cc7f0bb3c200f7a3c30f49066641 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 21 Sep 2024 18:38:07 -0400 Subject: [PATCH 13/20] Refactor name. 
--- .../callbacks/sampling_callback.py | 401 ------------------ .../sampling_visualization_callback.py | 294 +++++++++++++ 2 files changed, 294 insertions(+), 401 deletions(-) delete mode 100644 crystal_diffusion/callbacks/sampling_callback.py create mode 100644 crystal_diffusion/callbacks/sampling_visualization_callback.py diff --git a/crystal_diffusion/callbacks/sampling_callback.py b/crystal_diffusion/callbacks/sampling_callback.py deleted file mode 100644 index 6c47fe05..00000000 --- a/crystal_diffusion/callbacks/sampling_callback.py +++ /dev/null @@ -1,401 +0,0 @@ -import logging -import os -import tempfile -from pathlib import Path -from typing import Any, AnyStr, Dict, Optional, Tuple - -import numpy as np -import scipy.stats as ss -import torch -from matplotlib import pyplot as plt -from pytorch_lightning import Callback, LightningModule, Trainer - -from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH -from crystal_diffusion.generators.langevin_generator import LangevinGenerator -from crystal_diffusion.generators.ode_position_generator import ( - ExplodingVarianceODEPositionGenerator, ODESamplingParameters) -from crystal_diffusion.generators.position_generator import ( - PositionGenerator, SamplingParameters) -from crystal_diffusion.generators.predictor_corrector_position_generator import \ - PredictorCorrectorSamplingParameters -from crystal_diffusion.generators.sde_position_generator import ( - ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) -from crystal_diffusion.loggers.logger_loader import log_figure -from crystal_diffusion.namespace import CARTESIAN_POSITIONS -from crystal_diffusion.oracle.lammps import get_energy_and_forces_from_lammps -from crystal_diffusion.samplers.variance_sampler import NoiseParameters -from crystal_diffusion.utils.basis_transformations import \ - get_positions_from_coordinates -from crystal_diffusion.utils.structure_utils import ( - compute_distances_in_batch, get_orthogonal_basis_vectors) - 
-logger = logging.getLogger(__name__) - -plt.style.use(PLOT_STYLE_PATH) - - -def instantiate_diffusion_sampling_callback(callback_params: Dict[AnyStr, Any], - output_directory: str, - verbose: bool) -> Dict[str, Callback]: - """Instantiate the Diffusion Sampling callback.""" - noise_parameters = NoiseParameters(**callback_params['noise']) - - sampling_parameter_dictionary = callback_params['sampling'] - assert 'algorithm' in sampling_parameter_dictionary, "The sampling parameters must select an algorithm." - algorithm = sampling_parameter_dictionary['algorithm'] - - assert algorithm in ['ode', 'sde', 'predictor_corrector'], \ - "Unknown algorithm. Possible choices are 'ode', 'sde' and 'predictor_corrector'" - - match algorithm: - case 'predictor_corrector': - sampling_parameters = PredictorCorrectorSamplingParameters(**sampling_parameter_dictionary) - diffusion_sampling_callback = ( - PredictorCorrectorDiffusionSamplingCallback(noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=output_directory) - ) - case 'ode': - sampling_parameters = ODESamplingParameters(**sampling_parameter_dictionary) - diffusion_sampling_callback = ( - ODEDiffusionSamplingCallback(noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=output_directory) - ) - case 'sde': - sampling_parameters = SDESamplingParameters(**sampling_parameter_dictionary) - diffusion_sampling_callback = ( - SDEDiffusionSamplingCallback(noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=output_directory) - ) - case _: - raise NotImplementedError("algorithm is not implemented") - - return dict(diffusion_sampling=diffusion_sampling_callback) - - -class DiffusionSamplingCallback(Callback): - """Callback class to periodically generate samples and log their energies.""" - - def __init__(self, noise_parameters: NoiseParameters, - sampling_parameters: SamplingParameters, - output_directory: 
str - ): - """Init method.""" - self.noise_parameters = noise_parameters - self.sampling_parameters = sampling_parameters - self.output_directory = output_directory - - self.energy_sample_output_directory = os.path.join(output_directory, 'energy_samples') - Path(self.energy_sample_output_directory).mkdir(parents=True, exist_ok=True) - - if self.sampling_parameters.record_samples: - self.position_sample_output_directory = os.path.join(output_directory, 'diffusion_position_samples') - Path(self.position_sample_output_directory).mkdir(parents=True, exist_ok=True) - - self.compute_structure_factor = sampling_parameters.compute_structure_factor - self.structure_factor_max_distance = sampling_parameters.structure_factor_max_distance - - self._initialize_validation_energies_array() - self._initialize_validation_distance_array() - - @staticmethod - def compute_kolmogorov_smirnov_distance_and_pvalue(sampling_energies: np.ndarray, - reference_energies: np.ndarray) -> Tuple[float, float]: - """Compute Kolmogorov Smirnov Distance. - - Compute the two sample Kolmogorov–Smirnov test in order to gauge whether the - sample_energies sample was drawn from the same distribution as the reference_energies. - - Args: - sampling_energies : a sample of energies drawn from the diffusion model. - reference_energies :a sample of energies drawn from the reference distribution. - - Returns: - ks_distance, p_value: the Kolmogorov-Smirnov test statistic (a "distance") - and the statistical test's p-value. - """ - test_result = ss.ks_2samp(sampling_energies, reference_energies, alternative='two-sided', method='auto') - - # The "test statistic" of the two-sided KS test is the largest vertical distance between - # the empirical CDFs of the two samples. The larger this is, the less likely the two - # samples were drawn from the same underlying distribution, hence the idea of 'distance'. 
- ks_distance = test_result.statistic - - # The null hypothesis of the KS test is that both samples are drawn from the same distribution. - # Thus, a small p-value (which leads to the rejection of the null hypothesis) indicates that - # the samples probably come from different distributions (ie, our samples are bad!). - p_value = test_result.pvalue - return ks_distance, p_value - - def _compute_results_at_this_epoch(self, current_epoch: int) -> bool: - """Check if results should be computed at this epoch.""" - # Do not produce results at epoch 0; it would be meaningless. - if (current_epoch % self.sampling_parameters.sample_every_n_epochs == 0 - and current_epoch >= self.sampling_parameters.first_sampling_epoch): - return True - else: - return False - - def _initialize_validation_energies_array(self): - """Initialize the validation energies array to an empty array.""" - # The validation energies will be extracted at epochs where it is needed. Although this - # data does not change, we will avoid having this in memory at all times. 
- self.validation_energies = np.array([]) - - def _initialize_validation_distance_array(self): - """Initialize the distances array to an empty array.""" - # this is similar to the energy array - self.validation_distances = np.array([]) - - def _create_generator(self, pl_model: LightningModule) -> PositionGenerator: - """Draw a sample from the generative model.""" - raise NotImplementedError("This method must be implemented in a child class") - - def _create_unit_cell(self, pl_model) -> torch.Tensor: - """Create the batch of unit cells needed by the generative model.""" - # TODO we will have to sample unit cell dimensions at some points instead of working with fixed size - unit_cell = ( - get_orthogonal_basis_vectors(batch_size=self.sampling_parameters.number_of_samples, - cell_dimensions=self.sampling_parameters.cell_dimensions).to(pl_model.device)) - return unit_cell - - @staticmethod - def _plot_energy_histogram(sample_energies: np.ndarray, validation_dataset_energies: np.array, - epoch: int) -> plt.figure: - """Generate a plot of the energy samples.""" - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - - minimum_energy = validation_dataset_energies.min() - maximum_energy = validation_dataset_energies.max() - energy_range = maximum_energy - minimum_energy - - emin = minimum_energy - 0.2 * energy_range - emax = maximum_energy + 0.2 * energy_range - bins = np.linspace(emin, emax, 101) - - number_of_samples_in_range = np.logical_and(sample_energies >= emin, sample_energies <= emax).sum() - - fig.suptitle(f'Sampling Energy Distributions\nEpoch {epoch}') - - common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25) - - ax1 = fig.add_subplot(111) - - ax1.hist(sample_energies, **common_params, - label=f'Samples \n(total count = {len(sample_energies)}, in range = {number_of_samples_in_range})', - color='red') - ax1.hist(validation_dataset_energies, **common_params, - label=f'Validation Data \n(count = {len(validation_dataset_energies)})', 
color='green') - - ax1.set_xlabel('Energy (eV)') - ax1.set_ylabel('Density') - ax1.legend(loc='upper right', fancybox=True, shadow=True, ncol=1, fontsize=6) - fig.tight_layout() - return fig - - @staticmethod - def _plot_distance_histogram(sample_distances: np.ndarray, validation_dataset_distances: np.array, - epoch: int) -> plt.figure: - """Generate a plot of the inter-atomic distances of the samples.""" - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - - maximum_distance = validation_dataset_distances.max() - - dmin = 0.0 - dmax = maximum_distance + 0.1 - bins = np.linspace(dmin, dmax, 101) - - fig.suptitle(f'Sampling Distances Distribution\nEpoch {epoch}') - - common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25) - - ax1 = fig.add_subplot(111) - - ax1.hist(sample_distances, **common_params, - label=f'Samples \n(total count = {len(sample_distances)})', - color='red') - ax1.hist(validation_dataset_distances, **common_params, - label=f'Validation Data \n(count = {len(validation_dataset_distances)})', color='green') - - ax1.set_xlabel(r'Distance ($\AA$)') - ax1.set_ylabel('Density') - ax1.legend(loc='upper right', fancybox=True, shadow=True, ncol=1, fontsize=6) - ax1.set_xlim(left=dmin, right=dmax) - fig.tight_layout() - return fig - - def _compute_oracle_energies(self, batch_relative_coordinates: torch.Tensor) -> np.ndarray: - """Compute energies from samples.""" - batch_size = batch_relative_coordinates.shape[0] - cell_dimensions = self.sampling_parameters.cell_dimensions - basis_vectors = get_orthogonal_basis_vectors(batch_size, cell_dimensions) - batch_cartesian_positions = get_positions_from_coordinates(batch_relative_coordinates, basis_vectors) - - atom_types = np.ones(self.sampling_parameters.number_of_atoms, dtype=int) - - list_energy = [] - - logger.info("Compute energy from Oracle") - - with tempfile.TemporaryDirectory() as tmp_work_dir: - for positions, box in zip(batch_cartesian_positions.numpy(), basis_vectors.numpy()): - 
energy, forces = get_energy_and_forces_from_lammps(positions, - box, - atom_types, - tmp_work_dir=tmp_work_dir) - list_energy.append(energy) - - return np.array(list_energy) - - def sample_and_evaluate_energy(self, pl_model: LightningModule, current_epoch: int = 0 - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: - """Create samples and estimate their energy with an oracle (LAMMPS). - - Args: - pl_model: pytorch-lightning model - current_epoch (optional): current epoch to save files. Defaults to 0. - - Returns: - array with energy of each sample from LAMMPS - """ - generator = self._create_generator(pl_model) - unit_cell = self._create_unit_cell(pl_model) - - logger.info("Draw samples") - - if self.sampling_parameters.sample_batchsize is None: - self.sampling_parameters.sample_batchsize = self.sampling_parameters.number_of_samples - - sample_energies = [] - sample_distances = [] - - for n in range(0, self.sampling_parameters.number_of_samples, self.sampling_parameters.sample_batchsize): - unit_cell_ = unit_cell[n:min(n + self.sampling_parameters.sample_batchsize, - self.sampling_parameters.number_of_samples)] - samples = generator.sample(min(self.sampling_parameters.number_of_samples - n, - self.sampling_parameters.sample_batchsize), - device=pl_model.device, - unit_cell=unit_cell_) - if self.sampling_parameters.record_samples: - sample_output_path = os.path.join(self.position_sample_output_directory, - f"diffusion_position_sample_epoch={current_epoch}" - + f"_steps={n}.pt") - # write trajectories to disk and reset to save memory - generator.sample_trajectory_recorder.write_to_pickle(sample_output_path) - generator.sample_trajectory_recorder.reset() - if self.compute_structure_factor: - batch_cartesian_positions = get_positions_from_coordinates(samples.detach(), unit_cell_) - sample_distances += [ - compute_distances_in_batch(batch_cartesian_positions, - unit_cell_, - self.structure_factor_max_distance - ).cpu().numpy() - ] - batch_relative_coordinates = 
samples.detach().cpu() - sample_energies += [self._compute_oracle_energies(batch_relative_coordinates)] - - sample_energies = np.concatenate(sample_energies) - if self.compute_structure_factor: - sample_distances = np.concatenate(sample_distances) - else: - sample_distances = None - - return sample_energies, sample_distances - - def on_validation_batch_start(self, trainer: Trainer, - pl_module: LightningModule, batch: Any, batch_idx: int) -> None: - """On validation batch start, accumulate the validation dataset energies for further processing.""" - if not self._compute_results_at_this_epoch(trainer.current_epoch): - return - self.validation_energies = np.append(self.validation_energies, batch['potential_energy'].cpu().numpy()) - - if self.compute_structure_factor: - unit_cell = torch.diag_embed(batch['box']) - batch_distances = compute_distances_in_batch(batch[CARTESIAN_POSITIONS], unit_cell, - self.structure_factor_max_distance) - self.validation_distances = np.append(self.validation_distances, batch_distances.cpu().numpy()) - - def on_validation_epoch_end(self, trainer: Trainer, pl_model: LightningModule) -> None: - """On validation epoch end.""" - if not self._compute_results_at_this_epoch(trainer.current_epoch): - return - - # generate samples and evaluate their energy with an oracle - sample_energies, sample_distances = self.sample_and_evaluate_energy(pl_model, trainer.current_epoch) - - energy_output_path = os.path.join(self.energy_sample_output_directory, - f"energies_sample_epoch={trainer.current_epoch}.pt") - torch.save(torch.from_numpy(sample_energies), energy_output_path) - - fig = self._plot_energy_histogram(sample_energies, self.validation_energies, trainer.current_epoch) - ks_distance, p_value = self.compute_kolmogorov_smirnov_distance_and_pvalue(sample_energies, - self.validation_energies) - - pl_model.log("validation_epoch_energy_ks_distance", ks_distance, on_step=False, on_epoch=True) - pl_model.log("validation_epoch_energy_ks_p_value", p_value, 
on_step=False, on_epoch=True) - - for pl_logger in trainer.loggers: - log_figure(figure=fig, global_step=trainer.global_step, pl_logger=pl_logger) - - self._initialize_validation_energies_array() - - if self.compute_structure_factor: - distance_output_path = os.path.join(self.energy_sample_output_directory, - f"distances_sample_epoch={trainer.current_epoch}.pt") - torch.save(torch.from_numpy(sample_distances), distance_output_path) - fig = self._plot_distance_histogram(sample_distances, self.validation_distances, trainer.current_epoch) - ks_distance, p_value = self.compute_kolmogorov_smirnov_distance_and_pvalue(sample_distances, - self.validation_distances) - pl_model.log("validation_epoch_distances_ks_distance", ks_distance, on_step=False, on_epoch=True) - pl_model.log("validation_epoch_distances_ks_p_value", p_value, on_step=False, on_epoch=True) - - for pl_logger in trainer.loggers: - log_figure(figure=fig, global_step=trainer.global_step, pl_logger=pl_logger, name="distances") - - self._initialize_validation_distance_array() - - -class PredictorCorrectorDiffusionSamplingCallback(DiffusionSamplingCallback): - """Callback class to periodically generate samples and log their energies.""" - - def _create_generator(self, pl_model: LightningModule) -> LangevinGenerator: - """Draw a sample from the generative model.""" - logger.info("Creating sampler") - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - generator = LangevinGenerator(noise_parameters=self.noise_parameters, - sampling_parameters=self.sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network) - - return generator - - -class ODEDiffusionSamplingCallback(DiffusionSamplingCallback): - """Callback class to periodically generate samples and log their energies.""" - - def _create_generator(self, pl_model: LightningModule) -> ExplodingVarianceODEPositionGenerator: - """Draw a sample from the generative model.""" - logger.info("Creating sampler") - 
sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - generator = ExplodingVarianceODEPositionGenerator(noise_parameters=self.noise_parameters, - sampling_parameters=self.sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network) - - return generator - - -class SDEDiffusionSamplingCallback(DiffusionSamplingCallback): - """Callback class to periodically generate samples and log their energies.""" - - def _create_generator(self, pl_model: LightningModule) -> ExplodingVarianceODEPositionGenerator: - """Draw a sample from the generative model.""" - logger.info("Creating sampler") - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - generator = ExplodingVarianceSDEPositionGenerator(noise_parameters=self.noise_parameters, - sampling_parameters=self.sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network) - return generator diff --git a/crystal_diffusion/callbacks/sampling_visualization_callback.py b/crystal_diffusion/callbacks/sampling_visualization_callback.py new file mode 100644 index 00000000..3a005982 --- /dev/null +++ b/crystal_diffusion/callbacks/sampling_visualization_callback.py @@ -0,0 +1,294 @@ +import logging +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, AnyStr, Dict + +import numpy as np +import torch +from matplotlib import pyplot as plt +from pytorch_lightning import Callback, LightningModule, Trainer + +from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH +from crystal_diffusion.loggers.logger_loader import log_figure + +logger = logging.getLogger(__name__) + +plt.style.use(PLOT_STYLE_PATH) + + +@dataclass(kw_only=True) +class SamplingVisualizationParameters: + """Parameters to decide what to plot and write to disk.""" + record_every_n_epochs: int = 1 + first_record_epoch: int = 1 + record_trajectories: bool = True + record_energies: bool = True + record_structure: bool = True + + 
+def instantiate_sampling_visualization_callback( + callback_params: Dict[AnyStr, Any], output_directory: str, verbose: bool +) -> Dict[str, Callback]: + """Instantiate the Diffusion Sampling callback.""" + sampling_visualization_parameters = SamplingVisualizationParameters( + **callback_params + ) + + callback = SamplingVisualizationCallback( + sampling_visualization_parameters, output_directory + ) + + return dict(sampling_visualization=callback) + + +class SamplingVisualizationCallback(Callback): + """Callback class to periodically generate samples and log their energies.""" + + def __init__( + self, + sampling_visualization_parameters: SamplingVisualizationParameters, + output_directory: str, + ): + """Init method.""" + self.parameters = sampling_visualization_parameters + self.output_directory = output_directory + + if self.parameters.record_energies: + self.sample_energies_output_directory = os.path.join( + output_directory, "energy_samples" + ) + Path(self.sample_energies_output_directory).mkdir( + parents=True, exist_ok=True + ) + + if self.parameters.record_structure: + self.sample_distances_output_directory = os.path.join( + output_directory, "distance_samples" + ) + Path(self.sample_distances_output_directory).mkdir( + parents=True, exist_ok=True + ) + + if self.parameters.record_trajectories: + self.sample_trajectories_output_directory = os.path.join( + output_directory, "trajectory_samples" + ) + Path(self.sample_trajectories_output_directory).mkdir( + parents=True, exist_ok=True + ) + + def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None: + """On validation end.""" + if not self._compute_results_at_this_epoch(trainer.current_epoch): + return + + if self.parameters.record_energies: + assert ( + pl_model.energy_ks_metric is not None + ), "The energy_ks_metric is absent. Energy calculation must be requested in order to be visualized!" 
+ reference_energies = ( + pl_model.energy_ks_metric.reference_samples_metric.compute() + ) + sample_energies = ( + pl_model.energy_ks_metric.predicted_samples_metric.compute() + ) + energy_output_path = os.path.join( + self.sample_energies_output_directory, + f"energies_sample_epoch={trainer.current_epoch}.pt", + ) + torch.save(sample_energies, energy_output_path) + + sample_energies = sample_energies.cpu().numpy() + reference_energies = reference_energies.cpu().numpy() + + fig1 = self._plot_energy_histogram( + sample_energies, reference_energies, trainer.current_epoch + ) + fig2 = self._plot_energy_quantiles( + sample_energies, reference_energies, trainer.current_epoch + ) + + for pl_logger in trainer.loggers: + log_figure( + figure=fig1, + global_step=trainer.global_step, + dataset="validation", + pl_logger=pl_logger, + name="energy_distribution", + ) + log_figure( + figure=fig2, + global_step=trainer.global_step, + dataset="validation", + pl_logger=pl_logger, + name="energy_quantiles", + ) + + if self.parameters.record_structure: + assert pl_model.structure_ks_metric is not None, ( + "The structure_ks_metric is absent. Structure factor calculation " + "must be requested in order to be visualized!" 
+ ) + + reference_distances = ( + pl_model.structure_ks_metric.reference_samples_metric.compute() + ) + sample_distances = ( + pl_model.structure_ks_metric.predicted_samples_metric.compute() + ) + + distance_output_path = os.path.join( + self.sample_distances_output_directory, + f"distances_sample_epoch={trainer.current_epoch}.pt", + ) + + torch.save(sample_distances, distance_output_path) + fig = self._plot_distance_histogram( + sample_distances.numpy(), + reference_distances.numpy(), + trainer.current_epoch, + ) + + for pl_logger in trainer.loggers: + log_figure( + figure=fig, + global_step=trainer.global_step, + dataset="validation", + pl_logger=pl_logger, + name="distances", + ) + + if self.parameters.record_trajectories: + assert ( + pl_model.generator is not None + ), "Cannot record trajectories if a generator has not be created." + + pickle_output_path = os.path.join( + self.sample_trajectories_output_directory, + f"trajectories_sample_epoch={trainer.current_epoch}.pt", + ) + pl_model.generator.sample_trajectory_recorder.write_to_pickle( + pickle_output_path + ) + + def _compute_results_at_this_epoch(self, current_epoch: int) -> bool: + """Check if results should be computed at this epoch.""" + if ( + current_epoch % self.parameters.record_every_n_epochs == 0 + and current_epoch >= self.parameters.first_record_epoch + ): + return True + else: + return False + + @staticmethod + def _plot_energy_quantiles( + sample_energies: np.ndarray, validation_dataset_energies: np.array, epoch: int + ) -> plt.figure: + """Generate a plot of the energy quantiles.""" + list_q = np.linspace(0, 1, 101) + sample_quantiles = np.quantile(sample_energies, list_q) + dataset_quantiles = np.quantile(validation_dataset_energies, list_q) + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + fig.suptitle(f"Sampling Energy Quantiles\nEpoch {epoch}") + ax = fig.add_subplot(111) + + label = f"Samples \n(total count = {len(sample_energies)})" + ax.plot(100 * list_q, sample_quantiles, "-", lw=5, 
color="red", label=label) + + label = f"Validation Data \n(count = {len(validation_dataset_energies)})" + ax.plot( + 100 * list_q, dataset_quantiles, "--", lw=10, color="green", label=label + ) + ax.set_xlabel("Quantile (%)") + ax.set_ylabel("Energy (eV)") + ax.set_xlim(-0.1, 100.1) + ax.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + fig.tight_layout() + + return fig + + @staticmethod + def _plot_energy_histogram( + sample_energies: np.ndarray, validation_dataset_energies: np.array, epoch: int + ) -> plt.figure: + """Generate a plot of the energy samples.""" + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + + minimum_energy = validation_dataset_energies.min() + maximum_energy = validation_dataset_energies.max() + energy_range = maximum_energy - minimum_energy + + emin = minimum_energy - 0.2 * energy_range + emax = maximum_energy + 0.2 * energy_range + bins = np.linspace(emin, emax, 101) + + number_of_samples_in_range = np.logical_and( + sample_energies >= emin, sample_energies <= emax + ).sum() + + fig.suptitle(f"Sampling Energy Distributions\nEpoch {epoch}") + + common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25) + + ax1 = fig.add_subplot(111) + + ax1.hist( + sample_energies, + **common_params, + label=f"Samples \n(total count = {len(sample_energies)}, in range = {number_of_samples_in_range})", + color="red", + ) + ax1.hist( + validation_dataset_energies, + **common_params, + label=f"Validation Data \n(count = {len(validation_dataset_energies)})", + color="green", + ) + + ax1.set_xlabel("Energy (eV)") + ax1.set_ylabel("Density") + ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + fig.tight_layout() + return fig + + @staticmethod + def _plot_distance_histogram( + sample_distances: np.ndarray, validation_dataset_distances: np.array, epoch: int + ) -> plt.figure: + """Generate a plot of the inter-atomic distances of the samples.""" + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + + 
maximum_distance = validation_dataset_distances.max() + + dmin = 0.0 + dmax = maximum_distance + 0.1 + bins = np.linspace(dmin, dmax, 101) + + fig.suptitle(f"Sampling Distances Distribution\nEpoch {epoch}") + + common_params = dict(density=True, bins=bins, histtype="stepfilled", alpha=0.25) + + ax1 = fig.add_subplot(111) + + ax1.hist( + sample_distances, + **common_params, + label=f"Samples \n(total count = {len(sample_distances)})", + color="red", + ) + ax1.hist( + validation_dataset_distances, + **common_params, + label=f"Validation Data \n(count = {len(validation_dataset_distances)})", + color="green", + ) + + ax1.set_xlabel(r"Distance ($\AA$)") + ax1.set_ylabel("Density") + ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=6) + ax1.set_xlim(left=dmin, right=dmax) + fig.tight_layout() + return fig From d134e5d9d23f7abf434983e1f3794e1ee169dbe5 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sat, 21 Sep 2024 21:49:55 -0400 Subject: [PATCH 14/20] Fix misconfig. 
--- crystal_diffusion/models/position_diffusion_lightning_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 34364e30..491f1ca4 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -356,7 +356,7 @@ def generate_samples(self): logger.info("Creating Generator for sampling") self.generator = instantiate_generator( sampling_parameters=self.hyper_params.diffusion_sampling_parameters.sampling_parameters, - noise_parameters=self.hyper_params.noise_parameters, + noise_parameters=self.hyper_params.diffusion_sampling_parameters.noise_parameters, sigma_normalized_score_network=self.sigma_normalized_score_network, ) logger.info(f"Generator type : {type(self.generator)}") From 7ec6985ea5678266d1d73921fcca24cf2f29857e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sun, 22 Sep 2024 09:32:52 -0400 Subject: [PATCH 15/20] Close mpl figures. 
--- crystal_diffusion/callbacks/sampling_visualization_callback.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crystal_diffusion/callbacks/sampling_visualization_callback.py b/crystal_diffusion/callbacks/sampling_visualization_callback.py index 3a005982..7f12f537 100644 --- a/crystal_diffusion/callbacks/sampling_visualization_callback.py +++ b/crystal_diffusion/callbacks/sampling_visualization_callback.py @@ -124,6 +124,8 @@ def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None pl_logger=pl_logger, name="energy_quantiles", ) + plt.close(fig1) + plt.close(fig2) if self.parameters.record_structure: assert pl_model.structure_ks_metric is not None, ( @@ -158,6 +160,7 @@ def on_validation_end(self, trainer: Trainer, pl_model: LightningModule) -> None pl_logger=pl_logger, name="distances", ) + plt.close(fig) if self.parameters.record_trajectories: assert ( From 253878077faea4b2ecac4828673610d6ac64b98a Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Sun, 22 Sep 2024 09:40:23 -0400 Subject: [PATCH 16/20] A bit more logging. 
--- .../position_diffusion_lightning_model.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 491f1ca4..d9fbfa50 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -277,6 +277,7 @@ def _get_target_normalized_score( def training_step(self, batch, batch_idx): """Runs a prediction step for training, returning the loss.""" + logger.info(f" - Starting training step with batch index {batch_idx}") output = self._generic_step(batch, batch_idx) loss = output["loss"] @@ -293,10 +294,12 @@ def training_step(self, batch, batch_idx): on_step=False, on_epoch=True, ) + logger.info(f" Done training step with batch index {batch_idx}") return output def validation_step(self, batch, batch_idx): """Runs a prediction step for validation, logging the loss.""" + logger.info(f" - Starting validation step with batch index {batch_idx}") output = self._generic_step(batch, batch_idx, no_conditional=True) loss = output["loss"] batch_size = self._get_batch_size(batch) @@ -315,10 +318,12 @@ def validation_step(self, batch, batch_idx): return output if self.metrics_parameters.compute_energies: + logger.info(" * registering reference energies") reference_energies = batch["potential_energy"] self.energy_ks_metric.register_reference_samples(reference_energies.cpu()) if self.metrics_parameters.compute_structure_factor: + logger.info(" * registering reference distances") basis_vectors = torch.diag_embed(batch["box"]) cartesian_positions = get_positions_from_coordinates( relative_coordinates=batch[RELATIVE_COORDINATES], @@ -334,6 +339,7 @@ def validation_step(self, batch, batch_idx): reference_distances.cpu() ) + logger.info(f" Done validation step with batch index {batch_idx}") return output def test_step(self, batch, batch_idx): @@ 
-361,12 +367,13 @@ def generate_samples(self): ) logger.info(f"Generator type : {type(self.generator)}") - logger.info("Draw samples") + logger.info(" * Drawing samples") samples_batch = create_batch_of_samples( generator=self.generator, sampling_parameters=self.hyper_params.diffusion_sampling_parameters.sampling_parameters, device=self.device, ) + logger.info(" Done drawing samples") return samples_batch def on_validation_epoch_end(self) -> None: @@ -378,7 +385,9 @@ def on_validation_epoch_end(self) -> None: samples_batch = self.generate_samples() if self.metrics_parameters.compute_energies: + logger.info(" * Computing sample energies") sample_energies = compute_oracle_energies(samples_batch) + logger.info(" * Registering sample energies") self.energy_ks_metric.register_predicted_samples(sample_energies.cpu()) ( @@ -394,13 +403,17 @@ def on_validation_epoch_end(self) -> None: self.log( "validation_ks_p_value_energy", p_value, on_step=False, on_epoch=True ) + logger.info(" * Done logging sample energies") if self.metrics_parameters.compute_structure_factor: + logger.info(" * Computing sample distances") sample_distances = compute_distances_in_batch( cartesian_positions=samples_batch[CARTESIAN_POSITIONS], unit_cell=samples_batch[UNIT_CELL], max_distance=self.metrics_parameters.structure_factor_max_distance, ) + + logger.info(" * Registering sample distances") self.structure_ks_metric.register_predicted_samples(sample_distances.cpu()) ( @@ -418,9 +431,22 @@ def on_validation_epoch_end(self) -> None: self.log( "validation_ks_p_value_structure", p_value, on_step=False, on_epoch=True ) + logger.info(" * Done logging sample distances") def on_validation_start(self) -> None: """On validation start.""" + logger.info("Clearing generator and metrics on validation start.") + # Clear out any dangling state. 
+ self.generator = None + if self.metrics_parameters.compute_energies: + self.energy_ks_metric.reset() + + if self.metrics_parameters.compute_structure_factor: + self.structure_ks_metric.reset() + + def on_train_start(self) -> None: + """On train start.""" + logger.info("Clearing generator and metrics on train start.") # Clear out any dangling state. self.generator = None if self.metrics_parameters.compute_energies: From a6ba0937d7997649785fa11d18015e27be9b242c Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 23 Sep 2024 07:54:06 -0400 Subject: [PATCH 17/20] More distance bins. --- crystal_diffusion/callbacks/sampling_visualization_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crystal_diffusion/callbacks/sampling_visualization_callback.py b/crystal_diffusion/callbacks/sampling_visualization_callback.py index 7f12f537..ed7a0472 100644 --- a/crystal_diffusion/callbacks/sampling_visualization_callback.py +++ b/crystal_diffusion/callbacks/sampling_visualization_callback.py @@ -268,7 +268,7 @@ def _plot_distance_histogram( dmin = 0.0 dmax = maximum_distance + 0.1 - bins = np.linspace(dmin, dmax, 101) + bins = np.linspace(dmin, dmax, 251) fig.suptitle(f"Sampling Distances Distribution\nEpoch {epoch}") From 731a3bd06448023d9fbfecdc373955054bd1f1b4 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 23 Sep 2024 09:13:04 -0400 Subject: [PATCH 18/20] Example script to draw samples. --- examples/drawing_samples/draw_samples.py | 97 ++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 examples/drawing_samples/draw_samples.py diff --git a/examples/drawing_samples/draw_samples.py b/examples/drawing_samples/draw_samples.py new file mode 100644 index 00000000..d06b5567 --- /dev/null +++ b/examples/drawing_samples/draw_samples.py @@ -0,0 +1,97 @@ +"""Draw Samples. + +This script draws samples from a checkpoint. + +THIS SCRIPT IS AN EXAMPLE. IT SHOULD BE MODIFIED DEPENDING ON USER PREFERENCES. 
+""" +import logging +from pathlib import Path + +import numpy as np +import torch + +from crystal_diffusion.generators.instantiate_generator import \ + instantiate_generator +from crystal_diffusion.generators.predictor_corrector_position_generator import \ + PredictorCorrectorSamplingParameters +from crystal_diffusion.models.position_diffusion_lightning_model import \ + PositionDiffusionLightningModel +from crystal_diffusion.oracle.energies import compute_oracle_energies +from crystal_diffusion.samplers.variance_sampler import NoiseParameters +from crystal_diffusion.samples_and_metrics.sampling import \ + create_batch_of_samples +from crystal_diffusion.utils.logging_utils import setup_analysis_logger + +logger = logging.getLogger(__name__) +setup_analysis_logger() + +checkpoint_path = ("/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4/" + "output/best_model/best_model-epoch=024-step=019550.ckpt") +samples_dir = Path( + "/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4_samples/samples" +) +samples_dir.mkdir(exist_ok=True) + +device = torch.device("cuda") + + +spatial_dimension = 3 +number_of_atoms = 64 +atom_types = np.ones(number_of_atoms, dtype=int) + +acell = 10.86 +box = np.diag([acell, acell, acell]) + +number_of_samples = 128 +total_time_steps = 1000 +number_of_corrector_steps = 1 + +noise_parameters = NoiseParameters( + total_time_steps=total_time_steps, + corrector_step_epsilon=2e-7, + sigma_min=0.0001, + sigma_max=0.2, +) + +sampling_parameters = PredictorCorrectorSamplingParameters( + number_of_corrector_steps=number_of_corrector_steps, + spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + number_of_samples=number_of_samples, + cell_dimensions=[acell, acell, acell], + record_samples=True, +) + + +if __name__ == "__main__": + logger.info("Loading checkpoint...") + pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model.eval() + + sigma_normalized_score_network = 
pl_model.sigma_normalized_score_network + + logger.info("Instantiate generator...") + position_generator = instantiate_generator( + sampling_parameters=sampling_parameters, + noise_parameters=noise_parameters, + sigma_normalized_score_network=sigma_normalized_score_network, + ) + + logger.info("Drawing samples...") + with torch.no_grad(): + samples_batch = create_batch_of_samples( + generator=position_generator, + sampling_parameters=sampling_parameters, + device=device, + ) + + sample_output_path = str(samples_dir / "diffusion_samples.pt") + position_generator.sample_trajectory_recorder.write_to_pickle(sample_output_path) + logger.info("Done Generating Samples") + + logger.info("Compute energy from Oracle") + sample_energies = compute_oracle_energies(samples_batch) + + energy_output_path = str(samples_dir / "diffusion_energies.pt") + with open(energy_output_path, "wb") as fd: + torch.save(sample_energies, fd) From 561132e292d0cef05c3dbb879b7ecdf0491bc669 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Sep 2024 07:36:14 -0400 Subject: [PATCH 19/20] Removing excessive logging. 
--- .../models/position_diffusion_lightning_model.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index d9fbfa50..0765ca5b 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -277,7 +277,6 @@ def _get_target_normalized_score( def training_step(self, batch, batch_idx): """Runs a prediction step for training, returning the loss.""" - logger.info(f" - Starting training step with batch index {batch_idx}") output = self._generic_step(batch, batch_idx) loss = output["loss"] @@ -318,12 +317,10 @@ def validation_step(self, batch, batch_idx): return output if self.metrics_parameters.compute_energies: - logger.info(" * registering reference energies") reference_energies = batch["potential_energy"] self.energy_ks_metric.register_reference_samples(reference_energies.cpu()) if self.metrics_parameters.compute_structure_factor: - logger.info(" * registering reference distances") basis_vectors = torch.diag_embed(batch["box"]) cartesian_positions = get_positions_from_coordinates( relative_coordinates=batch[RELATIVE_COORDINATES], @@ -339,7 +336,6 @@ def validation_step(self, batch, batch_idx): reference_distances.cpu() ) - logger.info(f" Done validation step with batch index {batch_idx}") return output def test_step(self, batch, batch_idx): From 9510afa853f97e1f2ac854bcda63fd14c981e440 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Thu, 26 Sep 2024 07:39:53 -0400 Subject: [PATCH 20/20] Fix commit bjork. 
--- crystal_diffusion/models/position_diffusion_lightning_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/crystal_diffusion/models/position_diffusion_lightning_model.py b/crystal_diffusion/models/position_diffusion_lightning_model.py index 0765ca5b..0b2d741b 100644 --- a/crystal_diffusion/models/position_diffusion_lightning_model.py +++ b/crystal_diffusion/models/position_diffusion_lightning_model.py @@ -343,6 +343,7 @@ def test_step(self, batch, batch_idx): output = self._generic_step(batch, batch_idx) loss = output["loss"] batch_size = self._get_batch_size(batch) + # The 'test_epoch_loss' is aggregated (batch_size weighted average) and logged once per epoch. self.log( "test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True