Skip to content

Commit

Permalink
Merge pull request #221 from amcadmus/master
Browse files Browse the repository at this point in the history
Merge bug fixings to master
  • Loading branch information
wanghan-iapcm authored Nov 17, 2021
2 parents 1d1084d + 6569780 commit f242cbe
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 14 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/test_import.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: test Python import

on:
- push
- pull_request

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.9'
architecture: 'x64'
- run: python -m pip install .
- run: python -c 'import dpdata'

13 changes: 8 additions & 5 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ def _cond_load_data(fname) :
tmp = np.load(fname)
return tmp

def _load_set(folder) :
cells = np.load(os.path.join(folder, 'box.npy'))
def _load_set(folder, nopbc: bool) :
coords = np.load(os.path.join(folder, 'coord.npy'))
if nopbc:
cells = np.zeros((coords.shape[0], 3,3))
else:
cells = np.load(os.path.join(folder, 'box.npy'))
eners = _cond_load_data(os.path.join(folder, 'energy.npy'))
forces = _cond_load_data(os.path.join(folder, 'force.npy'))
virs = _cond_load_data(os.path.join(folder, 'virial.npy'))
Expand All @@ -22,14 +25,16 @@ def to_system_data(folder,
# data is empty
data = load_type(folder, type_map = type_map)
data['orig'] = np.zeros([3])
if os.path.isfile(os.path.join(folder, "nopbc")):
data['nopbc'] = True
sets = sorted(glob.glob(os.path.join(folder, 'set.*')))
all_cells = []
all_coords = []
all_eners = []
all_forces = []
all_virs = []
for ii in sets :
cells, coords, eners, forces, virs = _load_set(ii)
cells, coords, eners, forces, virs = _load_set(ii, data.get('nopbc', False))
nframes = np.reshape(cells, [-1,3,3]).shape[0]
all_cells.append(np.reshape(cells, [nframes,3,3]))
all_coords.append(np.reshape(coords, [nframes,-1,3]))
Expand All @@ -50,8 +55,6 @@ def to_system_data(folder,
data['forces'] = np.concatenate(all_forces, axis = 0)
if len(all_virs) > 0:
data['virials'] = np.concatenate(all_virs, axis = 0)
if os.path.isfile(os.path.join(folder, "nopbc")):
data['nopbc'] = True
return data


Expand Down
11 changes: 7 additions & 4 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ def to_system_data(folder, type_map = None, labels = True) :
if os.path.isdir(folder) :
data = load_type(folder, type_map = type_map)
data['orig'] = np.zeros([3])
data['cells'] = np.loadtxt(os.path.join(folder, 'box.raw'))
data['coords'] = np.loadtxt(os.path.join(folder, 'coord.raw'))
data['cells'] = np.reshape(data['cells'], [-1, 3, 3])
nframes = data['cells'].shape[0]
data['coords'] = np.loadtxt(os.path.join(folder, 'coord.raw'), ndmin=2)
nframes = data['coords'].shape[0]
if os.path.isfile(os.path.join(folder, "nopbc")):
data['nopbc'] = True
data['cells'] = np.zeros((nframes, 3,3))
else:
data['cells'] = np.loadtxt(os.path.join(folder, 'box.raw'), ndmin=2)
data['cells'] = np.reshape(data['cells'], [nframes, 3, 3])
data['coords'] = np.reshape(data['coords'], [nframes, -1, 3])
if labels :
Expand Down
5 changes: 4 additions & 1 deletion dpdata/pymatgen/molecule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
from pymatgen.core import Molecule
try:
from pymatgen.core import Molecule
except ImportError:
pass
from collections import Counter
import dpdata

Expand Down
9 changes: 5 additions & 4 deletions tests/comp_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def test_nframs(self):
def test_cell(self):
self.assertEqual(self.system_1.get_nframes(),
self.system_2.get_nframes())
np.testing.assert_almost_equal(self.system_1.data['cells'],
self.system_2.data['cells'],
decimal = self.places,
err_msg = 'cell failed')
if not self.system_1.nopbc and not self.system_2.nopbc:
np.testing.assert_almost_equal(self.system_1.data['cells'],
self.system_2.data['cells'],
decimal = self.places,
err_msg = 'cell failed')

def test_coord(self):
self.assertEqual(self.system_1.get_nframes(),
Expand Down

0 comments on commit f242cbe

Please sign in to comment.