Regularizers #114

Merged
27 commits merged on Jan 9, 2025
0907f02
Utility class to compute sigma and its derivative.
rousseab Dec 27, 2024
da4040c
A utility module to only compute sigma(t).
rousseab Dec 27, 2024
e05e81f
A dummy score network where we can compute the analytical derivatives…
rousseab Dec 28, 2024
674d4a6
A score fokker planck regularizer.
rousseab Dec 28, 2024
b83d7a6
A method to compute the loss contribution from FP.
rousseab Dec 28, 2024
d08b382
Pick up regularizer parameters from config.
rousseab Dec 28, 2024
0193a30
Add an option for the exact laplacian calculation.
rousseab Dec 29, 2024
e190018
Embed the relative coordinates ensuring periodicity in the MLP.
rousseab Dec 29, 2024
e99369b
Fixed broken tests because of new config parameter.
rousseab Dec 29, 2024
11d0332
Make regularizers more generic and configurable, so we can define dif…
rousseab Jan 2, 2025
08d6768
init file for regularizer folder.
rousseab Jan 2, 2025
e241386
Better separation of concerns.
rousseab Jan 2, 2025
a42ba3e
Implement an Oracle regularizer.
rousseab Jan 2, 2025
2ca0873
A useful script to create videos of vector fields.
rousseab Jan 3, 2025
ea38b6e
Modify the predictor corrector generator to be able to generate parti…
rousseab Jan 4, 2025
e2fbf49
A new consistency-based regularizer.
rousseab Jan 6, 2025
06e7dc6
Fix issue.
rousseab Jan 6, 2025
9a3c9e2
Fix typo.
rousseab Jan 6, 2025
57e9954
Always pick a large enough 'random time' so that we can make full num…
rousseab Jan 6, 2025
6ece0fb
Optionally pass in an analytical score network to draw trajectories.
rousseab Jan 7, 2025
2e347e5
Some analysis and plotting.
rousseab Jan 7, 2025
05f2d4b
Add permutation equivariance to the MLP model.
rousseab Jan 7, 2025
e4ef12c
Modify the sampling script so that we can pass an ad-hoc score networ…
rousseab Jan 8, 2025
c81d62a
Various scripts to run and analyse preliminary regularization experim…
rousseab Jan 8, 2025
b7677ee
Fix docstring.
rousseab Jan 8, 2025
cc5a2be
Fix MLP error.
rousseab Jan 8, 2025
f07d879
Fix error in comment.
rousseab Jan 8, 2025
185 changes: 185 additions & 0 deletions experiments/analysis/visualize_score_vector_field_2D.py
@@ -0,0 +1,185 @@
import glob
import subprocess
import tempfile
from pathlib import Path

import einops
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from diffusion_for_multi_scale_molecular_dynamics.analysis import \
PLOT_STYLE_PATH
from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import (
AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters)
from diffusion_for_multi_scale_molecular_dynamics.namespace import (
AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL)
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \
VarianceScheduler
from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \
NoiseParameters
from diffusion_for_multi_scale_molecular_dynamics.sample_diffusion import \
get_axl_network
from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \
map_relative_coordinates_to_unit_cell
from experiments.two_atoms_in_one_dimension.utils import \
get_2d_vector_field_figure

plt.style.use(PLOT_STYLE_PATH)


base_path = Path(
    "/Users/brunorousseau/courtois/local_experiments/egnn_small_regularizer_orion/orion_working_dir/"
    "785ed337118e5c748ca70517ff8569ee/last_model"
)

output_path = Path("/Users/brunorousseau/courtois/local_experiments/videos")

output_name = "785ed337118e5c748ca70517ff8569ee"


sigma_min = 0.001
sigma_max = 0.2
sigma_d = 0.01

n1 = 100
n2 = 100
nt = 100

spatial_dimension = 1
x0_1 = 0.25
x0_2 = 0.75

if __name__ == "__main__":
    noise_parameters = NoiseParameters(
        total_time_steps=nt, sigma_min=sigma_min, sigma_max=sigma_max
    )

    score_network_parameters = AnalyticalScoreNetworkParameters(
        number_of_atoms=2,
        num_atom_types=1,
        kmax=5,
        equilibrium_relative_coordinates=[[x0_1], [x0_2]],
        sigma_d=sigma_d,
        spatial_dimension=spatial_dimension,
        use_permutation_invariance=True,
    )

    analytical_score_network = AnalyticalScoreNetwork(score_network_parameters)

    checkpoint_path = glob.glob(str(base_path / "**/last_model*.ckpt"), recursive=True)[
        0
    ]
    checkpoint_name = Path(checkpoint_path).name
    axl_network = get_axl_network(checkpoint_path)

    list_times = torch.linspace(0.0, 1.0, nt)
    list_sigmas = VarianceScheduler(noise_parameters).get_sigma(list_times).numpy()

    x1 = torch.linspace(0, 1, n1)
    x2 = torch.linspace(0, 1, n2)

    X1, X2_ = torch.meshgrid(x1, x2, indexing="xy")
    X2 = torch.flip(X2_, dims=[0])

    relative_coordinates = einops.repeat(
        [X1, X2], "natoms n1 n2 -> (n1 n2) natoms space", space=spatial_dimension
    ).contiguous()
    relative_coordinates = map_relative_coordinates_to_unit_cell(relative_coordinates)

    forces = torch.zeros_like(relative_coordinates)
    batch_size, natoms, _ = relative_coordinates.shape

    atom_types = torch.ones(batch_size, natoms, dtype=torch.int64)

    list_ground_truth_probabilities = []
    list_sigma_normalized_scores = []
    for time, sigma in tqdm(zip(list_times, list_sigmas), "SIGMAS"):
        grid_sigmas = sigma * torch.ones_like(relative_coordinates)
        flat_probabilities, flat_normalized_scores = (
            analytical_score_network.get_probabilities_and_normalized_scores(
                relative_coordinates, grid_sigmas
            )
        )
        probabilities = einops.rearrange(
            flat_probabilities, "(n1 n2) -> n1 n2", n1=n1, n2=n2
        )
        list_ground_truth_probabilities.append(probabilities)

        sigma_t = sigma * torch.ones(batch_size, 1)
        times = time * torch.ones(batch_size, 1)
        unit_cell = torch.ones(batch_size, 1, 1)

        composition = AXL(
            A=atom_types,
            X=relative_coordinates,
            L=torch.zeros_like(relative_coordinates),
        )

        batch = {
            NOISY_AXL_COMPOSITION: composition,
            NOISE: sigma_t,
            TIME: times,
            UNIT_CELL: unit_cell,
            CARTESIAN_FORCES: forces,
        }

        model_predictions = axl_network(batch)
        sigma_normalized_scores = einops.rearrange(
            model_predictions.X.detach(),
            "(n1 n2) natoms space -> n1 n2 natoms space",
            n1=n1,
            n2=n2,
        )

        list_sigma_normalized_scores.append(sigma_normalized_scores)

    sigma_normalized_scores = torch.stack(list_sigma_normalized_scores).squeeze(-1)
    ground_truth_probabilities = torch.stack(list_ground_truth_probabilities)

    # ================================================================================

    s = 2
    with tempfile.TemporaryDirectory() as tmpdirname:

        tmp_dir = Path(tmpdirname)

        for time_idx in tqdm(range(len(list_times)), "VIDEO"):
            sigma_t = list_sigmas[time_idx]
            time = list_times[time_idx].item()

            fig = get_2d_vector_field_figure(
                X1=X1,
                X2=X2,
                probabilities=ground_truth_probabilities[time_idx],
                sigma_normalized_scores=sigma_normalized_scores[time_idx],
                time=time,
                sigma_t=sigma_t,
                sigma_d=sigma_d,
                supsampling_scale=s,
            )

            output_image = tmp_dir / f"vector_field_{time_idx}.png"
            fig.savefig(output_image)
            plt.close(fig)

        output_path.mkdir(parents=True, exist_ok=True)
        output_file_path = output_path / f"vector_field_{output_name}.mp4"

        # ffmpeg -r 10 -start_number 0 -i vector_field_%d.png -vcodec libx264 -pix_fmt yuv420p mlp_vector_field.mp4
        commands = [
            "ffmpeg",
            "-r",
            "10",
            "-start_number",
            "0",
            "-i",
            str(tmp_dir / "vector_field_%d.png"),
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            str(output_file_path),
        ]

        process = subprocess.run(commands, capture_output=True, text=True)
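The `einops.repeat` call in the script above stacks the two meshgrid planes into a `(batch, natoms, space)` tensor by treating the Python list as an implicit leading axis. An equivalent reshaping in plain NumPy, on hypothetical toy sizes, makes the resulting layout explicit:

```python
import numpy as np

n1, n2 = 4, 3  # toy grid sizes (the script uses 100 x 100)
x1 = np.linspace(0, 1, n1)
x2 = np.linspace(0, 1, n2)
X1, X2 = np.meshgrid(x1, x2, indexing="xy")  # each plane has shape (3, 4)

# Equivalent of "natoms n1 n2 -> (n1 n2) natoms space" with space=1:
stacked = np.stack([X1, X2])                 # (natoms, 3, 4) - list becomes axis 0
flat = stacked.reshape(2, -1).T[..., None]   # (n1 * n2, natoms, space)

print(flat.shape)  # (12, 2, 1)
```

Each row of `flat` holds one grid point's coordinates for both atoms, which is the batch layout the score networks consume.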
6 changes: 6 additions & 0 deletions experiments/regularization_toy_problem/__init__.py
@@ -0,0 +1,6 @@
from pathlib import Path

EXPERIMENTS_DIR = Path(__file__).parent / "experiments"

RESULTS_DIR: Path = Path(__file__).parent / "results"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
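Calling `mkdir(parents=True, exist_ok=True)` at import time means any script importing this package can write to `RESULTS_DIR` immediately, and re-importing is harmless. A self-contained check of that idempotence, using a temporary directory rather than the package path:

```python
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    results_dir = Path(tmp) / "results"
    results_dir.mkdir(parents=True, exist_ok=True)  # creates the directory
    results_dir.mkdir(parents=True, exist_ok=True)  # no-op: exist_ok suppresses FileExistsError
    print(results_dir.is_dir())  # True
```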
@@ -0,0 +1,132 @@
#================================================================================
# Configuration file for a diffusion experiment for 2 pseudo-atoms in 1D.
#================================================================================
exp_name: mlp
run_name: consistency_regularizer
max_epoch: 1000
log_every_n_steps: 1
gradient_clipping: 0.0
accumulate_grad_batches: 1

elements: [A]

# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# Data: a fake dataloader will recreate the same example over and over.
data:
data_source: gaussian
random_seed: 42
number_of_atoms: 2
sigma_d: 0.01
equilibrium_relative_coordinates:
- [0.25]
- [0.75]

train_dataset_size: 8_192
valid_dataset_size: 1_024

batch_size: 64 # batch size for everyone
num_workers: 0
max_atom: 2
spatial_dimension: 1

# architecture
spatial_dimension: 1

model:
loss:
coordinates_algorithm: mse
atom_types_ce_weight: 0.0
atom_types_lambda_weight: 0.0
relative_coordinates_lambda_weight: 1.0
lattice_lambda_weight: 0.0
score_network:
architecture: mlp
use_permutation_invariance: True
spatial_dimension: 1
number_of_atoms: 2
num_atom_types: 1
n_hidden_dimensions: 3
hidden_dimensions_size: 64
relative_coordinates_embedding_dimensions_size: 32
noise_embedding_dimensions_size: 16
time_embedding_dimensions_size: 16
atom_type_embedding_dimensions_size: 8
condition_embedding_size: 8
noise:
total_time_steps: 100
sigma_min: 0.001
sigma_max: 0.2

# optimizer and scheduler
optimizer:
name: adamw
learning_rate: 0.001
weight_decay: 5.0e-8



regularizer:
type: consistency
maximum_number_of_steps: 5
number_of_burn_in_epochs: 0
regularizer_lambda_weight: 0.001

noise:
total_time_steps: 100
sigma_min: 0.001
sigma_max: 0.2

sampling:
num_atom_types: 1
number_of_atoms: 2
number_of_samples: 64
spatial_dimension: 1
number_of_corrector_steps: 0
cell_dimensions: [1.0]


scheduler:
name: CosineAnnealingLR
T_max: 1000
eta_min: 0.0

# early stopping
early_stopping:
metric: validation_epoch_loss
mode: min
patience: 1000

model_checkpoint:
monitor: validation_epoch_loss
mode: min

score_viewer:
record_every_n_epochs: 1

score_viewer_parameters:
sigma_min: 0.001
sigma_max: 0.2
number_of_space_steps: 100
starting_relative_coordinates:
- [0.0]
- [1.0]
ending_relative_coordinates:
- [1.0]
- [0.0]
analytical_score_network:
architecture: "analytical"
spatial_dimension: 1
number_of_atoms: 2
num_atom_types: 1
kmax: 5
equilibrium_relative_coordinates:
- [0.25]
- [0.75]
sigma_d: 0.01
use_permutation_invariance: True

logging:
- tensorboard
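For reference, with the `scheduler` block above (`CosineAnnealingLR`, `T_max: 1000`, `eta_min: 0.0`) and `learning_rate: 0.001`, the learning rate follows PyTorch's cosine annealing closed form. A standalone sketch of that schedule — a hand-rolled formula for illustration, not the repository's code:

```python
import math

def cosine_annealing_lr(step: int, base_lr: float = 0.001,
                        eta_min: float = 0.0, t_max: int = 1000) -> float:
    # Closed form of torch.optim.lr_scheduler.CosineAnnealingLR (no restarts):
    # eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

print(cosine_annealing_lr(0))     # 0.001 (starts at the configured learning_rate)
print(cosine_annealing_lr(1000))  # 0.0   (anneals to eta_min at T_max)
```

With `T_max` equal to `max_epoch`, the schedule completes exactly one half-cosine over the run, ending at `eta_min`.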
@@ -0,0 +1,14 @@
#!/bin/bash

export OMP_PATH="/opt/homebrew/opt/libomp/include/"
export PYTORCH_ENABLE_MPS_FALLBACK=1


CONFIG=config.yaml

OUTPUT=./output/run1

SRC=/Users/brunorousseau/PycharmProjects/diffusion_for_multi_scale_molecular_dynamics/src/diffusion_for_multi_scale_molecular_dynamics


python ${SRC}/train_diffusion.py --accelerator "cpu" --config $CONFIG --output $OUTPUT