(fix) Make bias statistics complete for all elements #4496

Open
wants to merge 106 commits into
base: devel
Changes from 18 commits
Commits
32da243
4424
SumGuo-88 Dec 23, 2024
adf2315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2024
4f6f63d
issues4424-2
SumGuo-88 Dec 26, 2024
b9bac38
ll
SumGuo-88 Dec 26, 2024
1db3408
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
543a318
ll
SumGuo-88 Dec 26, 2024
ba72382
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
26f9a17
lll
SumGuo-88 Dec 26, 2024
25a803c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
dc64307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
8f962b5
allchange
SumGuo-88 Jan 2, 2025
b88e7fc
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
f57498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
faeb7c5
test
SumGuo-88 Jan 2, 2025
725f1dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
ca7fc84
stat
SumGuo-88 Jan 2, 2025
394cf04
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
05128d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
4828619
check
SumGuo-88 Jan 2, 2025
37ccce4
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
c9406e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
2224f61
chec坑
SumGuo-88 Jan 2, 2025
ba12c2c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
9fcee84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
fe8579e
check3
SumGuo-88 Jan 2, 2025
f004dff
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
11138ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
a4a97a3
test
SumGuo-88 Jan 3, 2025
88566fe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
10e538d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
203dc4e
ttt
SumGuo-88 Jan 3, 2025
bb9fbe1
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
6a65561
t
SumGuo-88 Jan 3, 2025
603aee9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
1c103c4
d
SumGuo-88 Jan 3, 2025
4173040
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
e3a1c9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
533e95e
ll
SumGuo-88 Jan 3, 2025
38dc18c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
714c197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
6713c1a
last
SumGuo-88 Jan 4, 2025
e42e38d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
1c15cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
33c716d
q
SumGuo-88 Jan 4, 2025
6d38b94
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
6bbced8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
b462c97
ll
SumGuo-88 Jan 4, 2025
0d7154c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
0c7baa0
ll
SumGuo-88 Jan 4, 2025
28d94af
ll
SumGuo-88 Jan 4, 2025
87dcd66
l
SumGuo-88 Jan 4, 2025
5d33060
ll
SumGuo-88 Jan 4, 2025
521e3a6
ll
SumGuo-88 Jan 4, 2025
379d4ad
Merge branch 'devel' into devel
SumGuo-88 Jan 5, 2025
0dabf77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
a23528c
Update stat.py
SumGuo-88 Jan 5, 2025
0a97b54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
49744ed
Update deepmd/pt/utils/stat.py
SumGuo-88 Jan 6, 2025
556a684
Simplify logic and remove "not"
SumGuo-88 Jan 6, 2025
aa2633d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
27999af
check import
SumGuo-88 Jan 6, 2025
83b7f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
817d2ec
Add assert to ensure that the new frame contains the required elements
SumGuo-88 Jan 6, 2025
234e461
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
93a748f
check import
SumGuo-88 Jan 6, 2025
4a38f1d
check import
SumGuo-88 Jan 6, 2025
78b2a10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
6a5d169
check test.py
SumGuo-88 Jan 6, 2025
26205d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
3ccb4b9
check ut
SumGuo-88 Jan 6, 2025
f669ac5
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
87de0e8
check ut
SumGuo-88 Jan 6, 2025
0939ef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
7ec779f
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
24d1386
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
708bc78
check msi defalut value
SumGuo-88 Jan 7, 2025
02f3f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c648e9e
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
d36a24a
check ut cuda
SumGuo-88 Jan 7, 2025
b6a483a
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
050dbaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
b00f8de
check ut
SumGuo-88 Jan 7, 2025
2f37dfe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
47fe45b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
98890c2
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
b9bdee5
make truetype for more sys
SumGuo-88 Jan 9, 2025
a30053f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
cfbc88a
Add skip element check function to Chang bias
SumGuo-88 Jan 9, 2025
73a20b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
0b29b05
make changebias control minframes
SumGuo-88 Jan 10, 2025
85d4da3
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
0400233
check merge
SumGuo-88 Jan 10, 2025
c05ffb1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
139f037
improve ut with all frames
SumGuo-88 Jan 10, 2025
5e826bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
edf1d91
check ut
SumGuo-88 Jan 10, 2025
3887013
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
eb9f068
check
SumGuo-88 Jan 10, 2025
10ef768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c2dc7ef
check skip logic and def name
SumGuo-88 Jan 10, 2025
8763165
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
d5596bf
improve warning readable
SumGuo-88 Jan 10, 2025
9f389ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0c76ad9
check args
SumGuo-88 Jan 10, 2025
58647f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
4ce9cfb
check stat.py
SumGuo-88 Jan 10, 2025
4 changes: 4 additions & 0 deletions deepmd/pt/train/training.py
@@ -142,6 +142,9 @@ def __init__(
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
self.display_in_training = training_params.get("disp_training", True)
self.timing_in_training = training_params.get("time_training", True)
self.min_frames_per_element_forstat = training_params.get(
"min_frames_per_element_forstat", 10
)
self.change_bias_after_training = training_params.get(
"change_bias_after_training", False
)
@@ -226,6 +229,7 @@ def get_sample():
_training_data.systems,
_training_data.dataloaders,
_data_stat_nbatch,
self.min_frames_per_element_forstat,
)
return sampled

32 changes: 30 additions & 2 deletions deepmd/pt/utils/dataset.py
@@ -1,10 +1,16 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


import glob
import os
from collections import (
defaultdict,
)
from typing import (
Optional,
)

import numpy as np
from torch.utils.data import (
Dataset,
)
@@ -13,14 +19,17 @@
DataRequirementItem,
DeepmdData,
)
from deepmd.utils.path import (
DPPath,
)


class DeepmdDataSetForLoader(Dataset):
def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
"""Construct DeePMD-style dataset containing frames cross different systems.
"""Construct DeePMD-style dataset containing frames across different systems.

Args:
- systems: Paths to systems.
- system: Path to the system.
- type_map: Atom types.
"""
self.system = system
@@ -40,6 +49,25 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

    def true_types(self):
        """Identify and count unique element types present in the dataset,
        and count the number of frames each element appears in.
        """
        element_counts = defaultdict(lambda: {"count": 0, "frames": 0})
        set_pattern = os.path.join(self.system, "set.*")
        set_files = sorted(glob.glob(set_pattern))
        for set_file in set_files:
            element_data = self._data_system._load_type_mix(DPPath(set_file))
            unique_elements, counts = np.unique(element_data, return_counts=True)
            for elem, cnt in zip(unique_elements, counts):
                element_counts[elem]["count"] += cnt
            for elem in unique_elements:
                frames_with_elem = np.any(element_data == elem, axis=1)
                row_count = np.sum(frames_with_elem)
                element_counts[elem]["frames"] += row_count
        element_counts = dict(element_counts)
        return element_counts
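
For reference, the per-element bookkeeping done by `true_types` can be reproduced on a small in-memory array. The following is a minimal sketch, not the PR's code: `count_elements` and the sample array are hypothetical, and it assumes the type data has shape (nframes, natoms) as in `real_atom_types.npy`.

```python
import numpy as np


def count_elements(atom_types: np.ndarray) -> dict:
    """Count atoms and frames per element type (hypothetical helper).

    atom_types is assumed to have shape (nframes, natoms).
    """
    stats = {}
    for elem in np.unique(atom_types):
        mask = atom_types == elem
        stats[int(elem)] = {
            # total number of atoms of this type over all frames
            "count": int(mask.sum()),
            # number of frames that contain at least one atom of this type
            "frames": int(mask.any(axis=1).sum()),
        }
    return stats


# Three frames with three atoms each.
example = np.array([[0, 0, 1], [0, 2, 2], [0, 0, 0]])
element_stats = count_elements(example)
```

Note that `"frames"` here is a count of frames, not a frame index, so it would not be safe to pass it to `__getitem__` directly.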

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
162 changes: 131 additions & 31 deletions deepmd/pt/utils/stat.py
@@ -37,7 +37,7 @@
log = logging.getLogger(__name__)


def make_stat_input(datasets, dataloaders, nbatches):
def make_stat_input(datasets, dataloaders, nbatches, min_frames_per_element_forstat):
"""Pack data for statistics.

Args:
@@ -50,38 +50,138 @@ def make_stat_input(datasets, dataloaders, nbatches):
"""
lst = []
log.info(f"Packing data for statistics from {len(datasets)} systems")
for i in range(len(datasets)):
sys_stat = {}
with torch.device("cpu"):
iterator = iter(dataloaders[i])
numb_batches = min(nbatches, len(dataloaders[i]))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
collect_elements = set()
total_element_types = set()
total_element_counts = {}
if datasets[0].mixed_type:
for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
sys_stat = {}
with torch.device("cpu"):
iterator = iter(dataloader)
numb_batches = min(nbatches, len(dataloader))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloader)
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
if "atype" in sys_stat and isinstance(sys_stat["atype"], list):
collect_values = np.unique(torch.cat(sys_stat["atype"]).numpy())
collect_elements.update(collect_values)
for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or (
isinstance(sys_stat[key], list)
and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
):
sys_stat[key] = None
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)
element_counts = dataset.true_types()
for elem, data in element_counts.items():
count = data["count"]
frames = data["frames"]
total_element_types.add(elem)
if elem not in total_element_counts:
total_element_counts[elem] = {
"count": 0,
"frames": 0,
"indices": [],
}
total_element_counts[elem]["count"] += count
if (
len(total_element_counts[elem]["indices"])
< min_frames_per_element_forstat
):
total_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": frames}
)
for elem, data in total_element_counts.items():
count = data["count"]
indices_count = len(data["indices"])
if indices_count < min_frames_per_element_forstat:
log.warning(
f"The number of frames with element {elem} is {indices_count}, which is less than the required minimum of {min_frames_per_element_forstat}"
)
missing_elements = total_element_types - collect_elements
for miss in missing_elements:
sys_indices = total_element_counts[miss].get("indices", [])
for sys_info in sys_indices:
sys_index = sys_info["sys_index"]
frames = sys_info["frames"]
sys = datasets[sys_index]
frame_data = sys.__getitem__(frames)
sys_stat_new = {}
for dd in frame_data:
if dd == "type":
continue
if frame_data[dd] is None:
sys_stat_new[dd] = None
elif isinstance(frame_data[dd], np.ndarray):
if dd not in sys_stat_new:
sys_stat_new[dd] = []
frame_data[dd] = torch.from_numpy(frame_data[dd])
frame_data[dd] = frame_data[dd].unsqueeze(0)
sys_stat_new[dd].append(frame_data[dd])
elif isinstance(frame_data[dd], np.float32):
sys_stat_new[dd] = frame_data[dd]
else:
pass

for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or sys_stat[key][0] is None:
sys_stat[key] = None
elif isinstance(stat_data[dd], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)
for key in sys_stat_new:
if isinstance(sys_stat_new[key], np.float32):
pass
elif sys_stat_new[key] is None or sys_stat_new[key][0] is None:
sys_stat_new[key] = None
elif isinstance(sys_stat_new[key][0], torch.Tensor):
sys_stat_new[key] = torch.cat(sys_stat_new[key], dim=0)
dict_to_device(sys_stat_new)
lst.append(sys_stat_new)
else:
for i in range(len(datasets)):
sys_stat = {}
with torch.device("cpu"):
iterator = iter(dataloaders[i])
numb_batches = min(nbatches, len(dataloaders[i]))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or sys_stat[key][0] is None:
sys_stat[key] = None
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
lst.append(sys_stat)
return lst
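
The core idea of the mixed-type branch is a set difference: element types known to exist in the data, minus element types that actually appeared in the sampled batches. A minimal sketch of that step follows, with plain Python lists standing in for the `atype` tensors; `find_missing_elements` is a hypothetical name and is not part of the PR.

```python
def find_missing_elements(all_types, sampled_batches):
    """Return element types present in the data but absent from the samples.

    all_types: iterable of every element type found in the dataset.
    sampled_batches: iterable of per-batch type lists (stand-ins for `atype`).
    """
    collected = set()
    for batch in sampled_batches:
        collected.update(batch)
    return set(all_types) - collected


all_types = {0, 1, 2, 3}
sampled = [[0, 0, 1], [1, 0, 0]]
missing = find_missing_elements(all_types, sampled)
```

Frames containing the elements in `missing` would then need to be appended to the statistics input so that every element contributes to the bias fit.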


6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
@@ -2826,6 +2826,12 @@ def training_args(
optional=True,
doc=doc_only_pt_supported + doc_gradient_max_norm,
),
Argument(
"min_frames_per_element_forstat",
int,
optional=True,
doc="The minimum number of frames per element used for statistics.",
),
]
variants = [
Variant(
16 changes: 8 additions & 8 deletions deepmd/utils/data.py
@@ -530,14 +530,6 @@ def _load_set(self, set_name: DPPath):
if self.mixed_type:
# nframes x natoms
atom_type_mix = self._load_type_mix(set_name)
if self.enforce_type_map:
try:
atom_type_mix_ = self.type_idx_map[atom_type_mix].astype(np.int32)
except IndexError as e:
raise IndexError(
f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
) from e
atom_type_mix = atom_type_mix_
real_type = atom_type_mix.reshape([nframes, self.natoms])
data["type"] = real_type
natoms = data["type"].shape[1]
@@ -672,6 +664,14 @@ def _load_type(self, sys_path: DPPath):
def _load_type_mix(self, set_name: DPPath):
type_path = set_name / "real_atom_types.npy"
real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms])
if self.enforce_type_map:
try:
atom_type_mix_ = self.type_idx_map[real_type].astype(np.int32)
except IndexError as e:
raise IndexError(
f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
) from e
real_type = atom_type_mix_
return real_type
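
Moving the `enforce_type_map` remapping into `_load_type_mix` means every caller gets remapped types. The remapping itself is NumPy integer-array indexing, which can be sketched in isolation as below. `remap_types` and the sample arrays are hypothetical, and the error message is shortened for illustration.

```python
import numpy as np


def remap_types(real_type: np.ndarray, type_idx_map: np.ndarray) -> np.ndarray:
    """Map raw type indices to type-map indices (hypothetical helper)."""
    try:
        # Integer-array indexing: each entry of real_type selects an entry of the map.
        return type_idx_map[real_type].astype(np.int32)
    except IndexError as e:
        # Raised when a raw type index is beyond the end of the map.
        raise IndexError("some types are not contained in the type map") from e


type_idx_map = np.array([2, 0, 1])
frames = np.array([[0, 1], [2, 2]])
remapped = remap_types(frames, type_idx_map)
```

One caveat of this approach: negative indices wrap around in NumPy rather than raising, so only indices that are too large are caught.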

def _make_idx_map(self, atom_type):
87 changes: 87 additions & 0 deletions source/tests/pt/test_make_stat_input.py
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from collections import (
defaultdict,
)

import torch
from torch.utils.data import (
DataLoader,
)

from deepmd.pt.utils.stat import (
make_stat_input,
)


class TestDataset:
def __init__(self, samples):
self.samples = samples
self.element_to_frames = defaultdict(list)
self.mixed_type = True
for idx, sample in enumerate(samples):
atypes = sample["atype"]
for atype in atypes:
self.element_to_frames[atype].append(idx)

@property
def get_all_atype(self):
return set(self.element_to_frames.keys())

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
sample = self.samples[idx]
return {
"atype": torch.tensor(sample["atype"], dtype=torch.long),
"energy": torch.tensor(sample["energy"], dtype=torch.float32),
}

def true_types(self):
element_counts = defaultdict(lambda: {"count": 0, "frames": 0})
for idx, sample in enumerate(self.samples):
atypes = sample["atype"]
unique_atypes = set(atypes)
for atype in atypes:
element_counts[atype]["count"] += 1
for atype in unique_atypes:
element_counts[atype]["frames"] += 1
return dict(element_counts)


class TestMakeStatInput(unittest.TestCase):
def setUp(self):
self.system = TestDataset(
[
{"atype": [1], "energy": -1.0},
{"atype": [2], "energy": -2.0},
]
)
self.datasets = [self.system]
self.dataloaders = [
DataLoader(self.system, batch_size=1, shuffle=False),
]

def test_make_stat_input(self):
nbatches = 1
lst = make_stat_input(
self.datasets,
self.dataloaders,
nbatches=nbatches,
min_frames_per_element_forstat=1,
)
all_elements = self.system.get_all_atype
unique_elements = {1, 2}
self.assertEqual(unique_elements, all_elements, "make_stat_input missed elements")

expected_true_types = {
1: {"count": 1, "frames": 1},
2: {"count": 1, "frames": 1},
}
actual_true_types = self.system.true_types()
self.assertEqual(expected_true_types, actual_true_types, "true_types is wrong")


if __name__ == "__main__":
unittest.main()
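
The test above compares attributes of the mock dataset and never inspects the list returned by `make_stat_input`. A check on the returned structure could be sketched as follows, assuming each entry is a dict whose `"atype"` value is array-like or `None`. `elements_covered` is a hypothetical helper, and NumPy arrays stand in for the tensors here.

```python
import numpy as np


def elements_covered(stat_input):
    """Collect every element type that appears in a list of per-system stat dicts."""
    covered = set()
    for sys_stat in stat_input:
        atype = sys_stat.get("atype")
        if atype is not None:
            covered.update(np.unique(np.asarray(atype)).tolist())
    return covered


# Stand-in for the list returned by make_stat_input.
fake_stat_input = [
    {"atype": np.array([[1]])},
    {"atype": np.array([[2]])},
    {"atype": None},
]
covered = elements_covered(fake_stat_input)
```

In a real test, `covered` would be compared against the full set of element types in the dataset.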