Skip to content

Commit

Permalink
update .npy files loading procedure in DefaultDataset (deepmodeling#18
Browse files Browse the repository at this point in the history
)
  • Loading branch information
SharpLonde authored Jan 23, 2024
1 parent a9e3685 commit 3e07ce2
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions dptb/data/dataset/_default_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,13 @@ def __init__(self,
assert os.path.exists(os.path.join(self.root, "kpoints.npy"))
kpoints = np.load(os.path.join(self.root, "kpoints.npy"))
if kpoints.ndim == 2:
# same kpoints, then copy it to all frames.
if kpoints.shape[0] == self.info["bandinfo"]["nkpoints"]:
kpoints = np.expand_dims(kpoints, axis=0)
self.data["kpoints"] = np.broadcast_to(kpoints, (self.info["nframes"],
self.info["bandinfo"]["nkpoints"], 3))
else:
raise ValueError("kpoints in `.npy` file not equal to nkpoints in bandinfo. ")
elif atomic_numbers.shape[0] == self.info["nframes"]:
# only one frame or same kpoints, then copy it to all frames.
# shape: (nkpoints, 3)
kpoints = np.expand_dims(kpoints, axis=0)
self.data["kpoints"] = np.broadcast_to(kpoints, (self.info["nframes"],
kpoints.shape[1], 3))
if kpoints.shape[0] == self.info["nframes"]:
# array of kpoints, (nframes, nkpoints, 3)
self.data["kpoints"] = kpoints
else:
raise ValueError("Wrong kpoint dimensions.")
Expand All @@ -107,12 +106,8 @@ def __init__(self,
if eigenvalues.ndim == 2:
eigenvalues = np.expand_dims(eigenvalues, axis=0)
assert eigenvalues.shape[0] == self.info["nframes"]
assert eigenvalues.shape[1] == self.info["bandinfo"]["nkpoints"]
assert eigenvalues.shape[2] == self.info["bandinfo"]["nbands"]
self.data["eigenvalues"] = eigenvalues
#self.data["eigenvalues"] = eigenvalues.reshape(self.info["nframes"],
# self.info["bandinfo"]["nkpoints"],
# self.info["bandinfo"]["nbands"])
assert eigenvalues.shape[1] == self.data["kpoints"].shape[1]
self.data["eigenvalues"] = eigenvalues
if os.path.exists(os.path.join(self.root, "hamiltonians.h5")) and get_Hamiltonian==True:
self.data["hamiltonian_blocks"] = h5py.File(os.path.join(self.root, "hamiltonians.h5"), "r")
if os.path.exists(os.path.join(self.root, "overlaps.h5")):
Expand Down

0 comments on commit 3e07ce2

Please sign in to comment.