
Commit

Optimize memory use for CPU shared memory and GPU.
fix
Tong Jiang committed Jan 9, 2024
1 parent 8bab365 commit 46e8063
Showing 9 changed files with 60 additions and 27 deletions.
20 changes: 12 additions & 8 deletions ipie/hamiltonians/generic.py
@@ -48,7 +48,7 @@ class GenericRealChol(GenericBase):
     Can be created by passing the one and two electron integrals directly.
     """

-    def __init__(self, h1e, chol, ecore=0.0, verbose=False):
+    def __init__(self, h1e, chol, ecore=0.0, shmem=False, chol_packed=None, verbose=False):
         assert (
             h1e.shape[0] == 2
         )  # assuming each spin component is given. this should be fixed for GHF...?
@@ -61,14 +61,18 @@ def __init__(self, h1e, chol, ecore=0.0, verbose=False):
         self.nfields = self.nchol
         assert self.nbasis**2 == chol.shape[0]

-        self.chol = self.chol.reshape((self.nbasis, self.nbasis, self.nchol))
         self.sym_idx = numpy.triu_indices(self.nbasis)
         self.sym_idx_i = self.sym_idx[0].copy()
         self.sym_idx_j = self.sym_idx[1].copy()
-        cp_shape = (self.nbasis * (self.nbasis + 1) // 2, self.chol.shape[-1])
-        self.chol_packed = numpy.zeros(cp_shape, dtype=self.chol.dtype)
-        pack_cholesky(self.sym_idx[0], self.sym_idx[1], self.chol_packed, self.chol)
-        self.chol = self.chol.reshape((self.nbasis * self.nbasis, self.nchol))
+        if not shmem:
+            self.chol = self.chol.reshape((self.nbasis, self.nbasis, self.nchol))
+            cp_shape = (self.nbasis * (self.nbasis + 1) // 2, self.chol.shape[-1])
+            self.chol_packed = numpy.zeros(cp_shape, dtype=self.chol.dtype)
+            pack_cholesky(self.sym_idx[0], self.sym_idx[1], self.chol_packed, self.chol)
+            self.chol = self.chol.reshape((self.nbasis * self.nbasis, self.nchol))
+        else:
+            self.chol = chol
+            self.chol_packed = chol_packed

         self.chunked = False

@@ -146,11 +150,11 @@ def hijkl(self, i, j, k, l):  # (ik|jl) somehow physicist notation - terrible!!
         return numpy.dot(chol_ik, chol_lj.conj())


-def Generic(h1e, chol, ecore=0.0, verbose=False):
+def Generic(h1e, chol, ecore=0.0, shmem=False, chol_packed=None, verbose=False):
     if chol.dtype == numpy.dtype("complex128"):
         return GenericComplexChol(h1e, chol, ecore, verbose)
     elif chol.dtype == numpy.dtype("float64"):
-        return GenericRealChol(h1e, chol, ecore, verbose)
+        return GenericRealChol(h1e, chol, ecore, shmem, chol_packed, verbose)


def read_integrals(integral_file):
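With shmem=True, the constructor now reuses a chol_packed array that was packed once and placed in node-shared memory, instead of allocating and packing a private copy on every MPI rank. The packed layout keeps only the upper triangle of each symmetric Cholesky vector; a minimal numpy sketch of that packing (illustrative only, not ipie's pack_cholesky, which fills a preallocated array in place):

import numpy

def pack_upper_triangle(chol):
    # chol has shape (nbasis, nbasis, nchol) and is symmetric in its first two indices.
    nbasis = chol.shape[0]
    i, j = numpy.triu_indices(nbasis)
    return chol[i, j, :].copy()  # shape (nbasis * (nbasis + 1) // 2, nchol)

nbasis, nchol = 4, 6
rng = numpy.random.default_rng(7)
L = rng.standard_normal((nbasis, nbasis, nchol))
L = 0.5 * (L + L.transpose(1, 0, 2))  # symmetrize, as real Cholesky vectors are
packed = pack_upper_triangle(L)
assert packed.shape == (nbasis * (nbasis + 1) // 2, nchol)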
10 changes: 9 additions & 1 deletion ipie/hamiltonians/utils.py
@@ -72,6 +72,9 @@ def get_hamiltonian(filename, scomm, verbose=False, pack_chol=True):
         if scomm.rank == 0 and pack_chol:
             pack_cholesky(idx[0], idx[1], chol_packed, chol)
         scomm.Barrier()
+        chol_pack_shmem = get_shared_array(scomm, shape, dtype)
+        if scomm.rank == 0:
+            chol_pack_shmem[:] = chol_packed[:]
     else:
         dtype = chol.dtype
         cp_shape = (nbsf * (nbsf + 1) // 2, nchol)
@@ -84,7 +87,12 @@ def get_hamiltonian(filename, scomm, verbose=False, pack_chol=True):
     if verbose:
         print(f"# Time to pack Cholesky vectors: {time.time() - start:.6f}")

-    ham = Generic(h1e=hcore, chol=chol, ecore=enuc, verbose=verbose)
+    if shmem and pack_chol:
+        ham = Generic(
+            h1e=hcore, chol=chol, ecore=enuc, shmem=True, chol_packed=chol_packed, verbose=verbose
+        )
+    else:
+        ham = Generic(h1e=hcore, chol=chol, ecore=enuc, verbose=verbose)

     return ham

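In the shared-memory path, get_hamiltonian now copies the packed Cholesky vectors into a node-shared array (chol_pack_shmem) from rank 0 only, so each node holds a single copy rather than one per rank. A hedged sketch of how such a node-shared array is commonly built with mpi4py MPI-3 shared windows; ipie's get_shared_array may differ in detail, and scomm must be a node-local communicator:

from mpi4py import MPI
import numpy

def shared_array(scomm, shape, dtype):
    # Rank 0 owns the allocation; the other ranks map the same physical segment.
    itemsize = numpy.dtype(dtype).itemsize
    nbytes = int(numpy.prod(shape)) * itemsize if scomm.rank == 0 else 0
    win = MPI.Win.Allocate_shared(nbytes, itemsize, comm=scomm)
    buf, _ = win.Shared_query(0)
    return numpy.ndarray(shape, dtype=dtype, buffer=buf)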
2 changes: 1 addition & 1 deletion ipie/propagation/operations.py
@@ -103,7 +103,7 @@ def apply_exponential_batch(phi, VHS, exp_nmax):
     xp.copyto(Temp, phi)
     if config.get_option("use_gpu"):
         for n in range(1, exp_nmax + 1):
-            Temp = xp.einsum("wik,wkj->wij", VHS, Temp, optimize=True) / n
+            Temp = xp.matmul(VHS, Temp) / n  # matmul uses much less GPU memory than einsum
             phi += Temp
     else:
         for iw in range(phi.shape[0]):
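The einsum-to-matmul change computes the same walker-batched matrix product; per the commit it uses much less GPU memory, presumably because matmul avoids einsum's intermediate work buffers. A quick numpy check of the equivalence (numpy stands in for xp, which is cupy on the GPU):

import numpy as xp  # stands in for ipie's xp; on the GPU this would be cupy

nwalkers, nbasis = 3, 5
rng = xp.random.default_rng(0)
VHS = rng.standard_normal((nwalkers, nbasis, nbasis)) + 1j * rng.standard_normal((nwalkers, nbasis, nbasis))
Temp = rng.standard_normal((nwalkers, nbasis, nbasis)) + 1j * rng.standard_normal((nwalkers, nbasis, nbasis))

old = xp.einsum("wik,wkj->wij", VHS, Temp, optimize=True)
new = xp.matmul(VHS, Temp)  # same contraction, fewer intermediates
assert xp.allclose(old, new)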
14 changes: 9 additions & 5 deletions ipie/propagation/phaseless_base.py
@@ -29,7 +29,11 @@ def construct_one_body_propagator(hamiltonian: GenericRealChol, mf_shift: xp.ndarray, dt: float):
         Timestep.
     """
     nb = hamiltonian.nbasis
-    shift = 1j * numpy.einsum("mx,x->m", hamiltonian.chol, mf_shift).reshape(nb, nb)
+    if hasattr(mf_shift, "get"):
+        shift = 1j * numpy.einsum("mx,x->m", hamiltonian.chol, mf_shift.get()).reshape(nb, nb)
+    else:
+        shift = 1j * numpy.einsum("mx,x->m", hamiltonian.chol, mf_shift).reshape(nb, nb)
+    shift = xp.array(shift)
     H1 = hamiltonian.h1e_mod - xp.array([shift, shift])
     if hasattr(H1, "get"):
         H1_numpy = H1.get()
@@ -45,9 +49,8 @@
 def construct_one_body_propagator(hamiltonian: GenericComplexChol, mf_shift: xp.ndarray, dt: float):
     nb = hamiltonian.nbasis
     nchol = hamiltonian.nchol
-    shift = xp.zeros((nb, nb), dtype=hamiltonian.chol.dtype)
+    shift = numpy.zeros((nb, nb), dtype=hamiltonian.chol.dtype)
     shift = 1j * numpy.einsum("mx,x->m", hamiltonian.A, mf_shift[:nchol]).reshape(nb, nb)
-
     shift += 1j * numpy.einsum("mx,x->m", hamiltonian.B, mf_shift[nchol:]).reshape(nb, nb)

     H1 = hamiltonian.h1e_mod - numpy.array([shift, shift])
@@ -68,8 +71,9 @@ def construct_mean_field_shift(hamiltonian: GenericRealChol, trial: TrialWavefun
     """
     # hamiltonian.chol [X, M^2]
     Gcharge = (trial.G[0] + trial.G[1]).ravel()
-    tmp_real = xp.dot(hamiltonian.chol.T, Gcharge.real)
-    tmp_imag = xp.dot(hamiltonian.chol.T, Gcharge.imag)
+    # Use numpy to reduce GPU memory use here; otherwise large chol arrays become a problem.
+    tmp_real = numpy.dot(hamiltonian.chol.T, Gcharge.real)
+    tmp_imag = numpy.dot(hamiltonian.chol.T, Gcharge.imag)
     mf_shift = 1.0j * tmp_real - tmp_imag
     return xp.array(mf_shift)

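Both changes in this file rely on the same device test: cupy arrays expose a .get() method that copies to the host, while numpy arrays do not, so hasattr(x, "get") distinguishes the two. The large contraction over hamiltonian.chol is then done on the host with numpy, and only the small (nbasis, nbasis) shift is pushed back to the device with xp.array. A minimal sketch of the idiom, with hypothetical helper names:

import numpy

def as_host(arr):
    # Pull a cupy array back to the host; leave a numpy array untouched.
    return arr.get() if hasattr(arr, "get") else arr

def one_body_shift(chol, mf_shift, nbasis):
    # Contract on the host so the large chol array never needs to live on the device;
    # the (nbasis, nbasis) result is cheap to transfer afterwards.
    return 1j * numpy.einsum("mx,x->m", chol, as_host(mf_shift)).reshape(nbasis, nbasis)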
14 changes: 9 additions & 5 deletions ipie/propagation/phaseless_generic.py
@@ -62,13 +62,14 @@ def construct_VHS(self, hamiltonian: GenericBase, xshifted: xp.ndarray) -> xp.ndarray:
     def construct_VHS(self, hamiltonian: GenericRealChol, xshifted: xp.ndarray) -> xp.ndarray:
         nwalkers = xshifted.shape[-1]

-        VHS_packed = hamiltonian.chol_packed.dot(
-            xshifted.real
-        ) + 1.0j * hamiltonian.chol_packed.dot(xshifted.imag)
+        VHS_packed = hamiltonian.chol_packed.dot(xshifted.real).astype(xshifted.dtype)
+        VHS_packed += 1.0j * hamiltonian.chol_packed.dot(
+            xshifted.imag
+        )  # the in-place update reduces GPU memory use

         # (nb, nb, nw) -> (nw, nb, nb)
-        VHS_packed = (
-            self.isqrt_dt * VHS_packed.T.reshape(nwalkers, hamiltonian.chol_packed.shape[0]).copy()
-        )
+        VHS_packed = self.isqrt_dt * VHS_packed.T.reshape(
+            nwalkers, hamiltonian.chol_packed.shape[0]
+        )

         VHS = xp.zeros(
@@ -83,6 +84,9 @@ def construct_VHS(self, hamiltonian: GenericRealChol, xshifted: xp.ndarray) -> xp.ndarray:
             unpack_VHS_batch_gpu[blockspergrid, threadsperblock](
                 hamiltonian.sym_idx_i, hamiltonian.sym_idx_j, VHS_packed, VHS
             )
+            del VHS_packed
+            xp.cuda.runtime.deviceSynchronize()
+            xp._default_memory_pool.free_all_blocks()
         else:
             unpack_VHS_batch(hamiltonian.sym_idx[0], hamiltonian.sym_idx[1], VHS_packed, VHS)
         return VHS
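After the GPU unpack kernel, the packed buffer is deleted and the CuPy memory pool is asked to release its cached blocks. Deleting the array alone only drops the Python reference; CuPy keeps the block cached in its pool until the pool is told to return unused memory to the driver. A hedged sketch of the same idiom through CuPy's public API (the diff goes through the internal xp._default_memory_pool handle; running this requires a CUDA device):

import cupy

x = cupy.zeros((4096, 4096))  # a large device temporary
del x  # the reference is gone, but the block stays cached in the pool
cupy.cuda.runtime.deviceSynchronize()  # make sure no kernel is still using pooled memory
cupy.get_default_memory_pool().free_all_blocks()  # hand cached blocks back to the CUDA driver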
6 changes: 4 additions & 2 deletions ipie/propagation/tests/test_generic.py
@@ -316,8 +316,10 @@ def test_vhs():
         }
     )
     legacy_data = build_legacy_test_case_handlers(nelec, nmo, num_dets=1, options=qmc, seed=7)
-    xshifted = numpy.random.normal(0.0, 1.0, nwalkers * legacy_data.hamiltonian.nfields).reshape(
-        nwalkers, legacy_data.hamiltonian.nfields
+    xshifted = (
+        numpy.random.normal(0.0, 1.0, nwalkers * legacy_data.hamiltonian.nfields)
+        .reshape(nwalkers, legacy_data.hamiltonian.nfields)
+        .astype(numpy.complex128)
     )
     vhs_serial = []
     for iw in range(nwalkers):
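The tests now cast the sampled auxiliary fields to complex128, presumably because the reworked construct_VHS types its in-place accumulator from xshifted.dtype: with real fields the accumulator would stay float64 and the subsequent += 1.0j * ... update would be refused. A small numpy illustration of that failure mode:

import numpy

acc = numpy.ones(3)  # float64 accumulator
try:
    acc += 1.0j * numpy.ones(3)  # in-place complex update of a real array
except TypeError as err:
    print("refused:", err)  # numpy will not cast complex to float in place

acc = numpy.ones(3).astype(numpy.complex128)
acc += 1.0j * numpy.ones(3)  # fine once the dtype is complex up front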
6 changes: 4 additions & 2 deletions ipie/propagation/tests/test_generic_chunked.py
@@ -103,8 +103,10 @@ def test_generic_propagation_chunked():
     )

     assert numpy.allclose(vfb, vfb_chunked)
-    xshifted = numpy.random.normal(0.0, 1.0, ham.nchol * walker_batch.nwalkers).reshape(
-        walker_batch.nwalkers, ham.nchol
+    xshifted = (
+        numpy.random.normal(0.0, 1.0, ham.nchol * walker_batch.nwalkers)
+        .reshape(walker_batch.nwalkers, ham.nchol)
+        .astype(numpy.complex128)
     )
     VHS_chunked = prop.construct_VHS(
         ham,
12 changes: 10 additions & 2 deletions ipie/propagation/tests/test_generic_complex.py
@@ -98,7 +98,11 @@ def test_vhs_complex():
     )

     ham = test_handler.hamiltonian
-    xshifted = numpy.random.normal(0.0, 1.0, nwalkers * ham.nfields).reshape(ham.nfields, nwalkers)
+    xshifted = (
+        numpy.random.normal(0.0, 1.0, nwalkers * ham.nfields)
+        .reshape(ham.nfields, nwalkers)
+        .astype(numpy.complex128)
+    )

     vhs = test_handler.propagator.construct_VHS(ham, xshifted)

@@ -145,7 +149,11 @@ def test_vhs_complex_vs_real():
     ham = test_handler.hamiltonian
     chol = ham.chol

-    xshifted = numpy.random.normal(0.0, 1.0, nwalkers * ham.nfields).reshape(ham.nfields, nwalkers)
+    xshifted = (
+        numpy.random.normal(0.0, 1.0, nwalkers * ham.nfields)
+        .reshape(ham.nfields, nwalkers)
+        .astype(numpy.complex128)
+    )

     vhs = test_handler.propagator.construct_VHS(ham, xshifted)

3 changes: 2 additions & 1 deletion ipie/utils/backend.py
@@ -93,8 +93,9 @@ def cast_to_device(self, verbose=False):
             expected_bytes = size * 16.0
             expected_gb = expected_bytes / 1024.0**3.0
             print(f"# {self.__class__.__name__}: expected to allocate {expected_gb} GB")
-
         for k, v in self.__dict__.items():
+            if k in ["Ga", "Gb"]:
+                continue  # reduce memory usage: Ga/Gb are not used here, Ghalf is used instead
             if isinstance(v, _np.ndarray):
                 self.__dict__[k] = arraylib.array(v)
             elif isinstance(v, list) and isinstance(v[0], _np.ndarray):
