Skip to content

Commit

Permalink
fused contr multi
Browse files Browse the repository at this point in the history
  • Loading branch information
hczhai committed Aug 9, 2024
1 parent eeb263e commit a86cc30
Show file tree
Hide file tree
Showing 11 changed files with 516 additions and 76 deletions.
32 changes: 30 additions & 2 deletions pyblock2/driver/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4349,12 +4349,17 @@ def dmrg(
spectra_with_multiplicity=False,
store_seq_data=False,
lowmem_noise=False,
midmem_noise=False,
sweep_start=0,
forward=None,
kernel=None,
metric_mpo=None,
stacked_mpo=None,
context_ket=None,
delayed_contraction=True,
cached_contraction=True,
fused_contraction_multiplication=False,
fused_contraction_rotation=False,
):
"""
Perform the ground state and/or excited state Density Matrix
Expand Down Expand Up @@ -4456,6 +4461,8 @@ def dmrg(
Only useful for developers. Default is False.
lowmem_noise : bool
If True, the noise step will cost less memory. Default is False.
midmem_noise : bool
If True, the noise step will cost medium memory. Default is False.
sweep_start : int
The starting sweep index in ``bond_dims``, ``noises``, and ``thrds``. Default is 0.
This may be useful in restarting, when one wants to skip the sweep parameters
Expand All @@ -4473,6 +4480,16 @@ def dmrg(
The block2 MPO object stacked with the mpo. Default is None.
context_ket : None or MPS
The block2 MPS object for the symmetry constraint. Default is None (no constraint).
delayed_contraction : bool
If True, delayed contraction (blocking) is used for saving time. Default is True.
cached_contraction : bool
If True, cached contraction (blocking) is used for saving time. Default is True.
fused_contraction_multiplication : bool
If True, fused operation of contraction and multiplication is used for saving memory.
Default is False.
fused_contraction_rotation : bool
If True, fused operation of contraction and rotation is used for saving memory.
Default is False.
Returns:
energy : float|complex or list[float|complex]
Expand All @@ -4494,13 +4511,20 @@ def dmrg(
else:
bra = ket
me = bw.bs.MovingEnvironment(mpo, bra, ket, "DMRG")
me.delayed_contraction = bw.b.OpNamesSet.normal_ops()
me.delayed_contraction = bw.b.OpNamesSet.normal_ops() if delayed_contraction else bw.b.OpNamesSet()
if stacked_mpo is not None:
if self.mpi is not None:
raise NotImplementedError()
me.stacked_mpo = stacked_mpo
me.delayed_contraction = bw.b.OpNamesSet()
me.cached_contraction = True
me.cached_contraction = cached_contraction
me.fused_contraction_multiplication = fused_contraction_multiplication
me.fused_contraction_rotation = fused_contraction_rotation
if fused_contraction_multiplication:
assert fused_contraction_rotation
assert not cached_contraction
assert not delayed_contraction
assert bw.b.Global.threading.seq_type != bw.b.SeqTypes.Tasked
dmrg = bw.bs.DMRG(me, bw.b.VectorUBond(bond_dims), bw.VectorFP(noises))
metric_me = None
if metric_mpo is not None:
Expand Down Expand Up @@ -4574,7 +4598,11 @@ def dmrg(
noise_type = "ReducedPerturbativeCollected"
dmrg.noise_type = getattr(bw.b.NoiseTypes, noise_type)
if lowmem_noise:
assert not midmem_noise
dmrg.noise_type = dmrg.noise_type | bw.b.NoiseTypes.LowMem
if midmem_noise:
assert not lowmem_noise
dmrg.noise_type = dmrg.noise_type | bw.b.NoiseTypes.MidMem
if decomp_type is not None:
dmrg.decomp_type = getattr(bw.b.DecompositionTypes, decomp_type)
dmrg.davidson_conv_thrds = bw.VectorFP(thrds)
Expand Down
24 changes: 21 additions & 3 deletions src/core/allocator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ template <typename FL> struct DataFrame {
shared_ptr<FPCodec<FL>> fp_codec =
nullptr; //!< Floating-point compression codec. If nullptr,
//!< floating-point compression will not be used.
// isize and dsize are in Bytes
/** Constructor.
* @param isize Max size (in bytes) of all integer stacks.
* @param dsize Max size (in bytes) of all double stacks.
Expand Down Expand Up @@ -640,7 +639,7 @@ template <typename FL> struct DataFrame {
r += dallocs[i]->used * sizeof(FL) + iallocs[i]->used * 4;
return r;
}
/** Update prak used memory statistics. */
/** Update peak used memory statistics. */
void update_peak_used_memory() const {
for (int i = 0; i < n_frames; i++) {
peak_used_memory[i + 0 * n_frames] =
Expand All @@ -650,7 +649,7 @@ template <typename FL> struct DataFrame {
max(peak_used_memory[i + 1 * n_frames], iallocs[i]->used * 4);
}
}
/** Reset prak used memory statistics to zero. */
/** Reset peak used memory statistics to zero. */
void reset_peak_used_memory() const {
memset(peak_used_memory.data(), 0,
sizeof(size_t) * peak_used_memory.size());
Expand Down Expand Up @@ -697,4 +696,23 @@ inline auto check_signal_() -> void (*&)() {
return check_signal;
}

/** Polymorphic wrapper around a callback.
 * The base class performs no action in compute(); derived classes
 * override compute() to supply the actual callback body.
 */
struct CallbackKernel {
    /** Default constructor. */
    CallbackKernel() = default;
    /** Virtual destructor, so derived kernels can be destroyed
     * through a pointer to this base class. */
    virtual ~CallbackKernel() = default;
    /** Run the callback. The base implementation is a no-op.
     * @param name Stage name.
     * @param iprint Verbosity.
     */
    virtual void compute(const string &name, int iprint) const {}
};

/** Function pointer for callback. */
inline shared_ptr<CallbackKernel> &callback_() {
static shared_ptr<CallbackKernel> callback = make_shared<CallbackKernel>();
return callback;
}

} // namespace block2
1 change: 1 addition & 0 deletions src/core/operator_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ enum struct NoiseTypes : uint16_t {
Unscaled = 16,
Collected = 32,
LowMem = 64,
MidMem = 128,
ReducedPerturbative = 4 | 8,
PerturbativeUnscaled = 4 | 16,
ReducedPerturbativeUnscaled = 4 | 8 | 16,
Expand Down
12 changes: 12 additions & 0 deletions src/core/operator_tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ struct DelayedOperatorTensor : OperatorTensor<S, FL> {
shared_ptr<Symbolic<S>> stacked_mat = nullptr;
// SparseMatrix representation of symbols from left and right block
shared_ptr<OperatorTensor<S, FL>> lopt, ropt;
// For fused_contract_multiply
unordered_map<shared_ptr<OpExpr<S>>, pair<size_t, FL>> exprs;
DelayedOperatorTensor() : OperatorTensor<S, FL>() {}
OperatorTensorTypes get_type() const override {
return OperatorTensorTypes::Delayed;
Expand All @@ -227,6 +229,16 @@ struct DelayedOperatorTensor : OperatorTensor<S, FL> {
for (auto it = lopt->ops.cbegin(); it != lopt->ops.cend(); it++)
if (it->second->data != nullptr)
r += it->second->total_memory;
if (ropt->get_type() == OperatorTensorTypes::Delayed &&
dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(ropt)
->exprs.size() != 0)
r += dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(ropt)
->get_total_memory();
if (lopt->get_type() == OperatorTensorTypes::Delayed &&
dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(lopt)
->exprs.size() != 0)
r += dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(lopt)
->get_total_memory();
return r;
}
void reallocate(bool clean) override {
Expand Down
71 changes: 65 additions & 6 deletions src/core/parallel_tensor_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,8 +1114,8 @@ struct ParallelTensorFunctions : TensorFunctions<S, FL> {
shared_ptr<DelayedOperatorTensor<S, FL>>
delayed_contract(const shared_ptr<OperatorTensor<S, FL>> &a,
const shared_ptr<OperatorTensor<S, FL>> &b,
const shared_ptr<OpExpr<S>> &op,
OpNamesSet delayed) const override {
const shared_ptr<OpExpr<S>> &op, OpNamesSet delayed,
bool substitute) const override {
shared_ptr<DelayedOperatorTensor<S, FL>> dopt =
make_shared<DelayedOperatorTensor<S, FL>>();
dopt->lopt = a, dopt->ropt = b;
Expand All @@ -1124,11 +1124,11 @@ struct ParallelTensorFunctions : TensorFunctions<S, FL> {
shared_ptr<Symbolic<S>> exprs = a->lmat * b->rmat;
assert(exprs->data.size() == 1);
bool use_orig = !(rule->get_parallel_type() & ParallelTypes::NewScheme);
if (a->get_type() == OperatorTensorTypes::Delayed)
if (a->get_type() == OperatorTensorTypes::Delayed && substitute)
dopt->mat = substitute_delayed_exprs(
exprs, dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(a),
true, delayed, use_orig);
else if (b->get_type() == OperatorTensorTypes::Delayed)
else if (b->get_type() == OperatorTensorTypes::Delayed && substitute)
dopt->mat = substitute_delayed_exprs(
exprs, dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(b),
false, delayed, use_orig);
Expand All @@ -1148,17 +1148,18 @@ struct ParallelTensorFunctions : TensorFunctions<S, FL> {
const shared_ptr<OperatorTensor<S, FL>> &b,
const shared_ptr<Symbolic<S>> &ops,
const shared_ptr<Symbolic<S>> &exprs, OpNamesSet delayed,
bool substitute,
const shared_ptr<Symbolic<S>> &xexprs = nullptr) const override {
shared_ptr<DelayedOperatorTensor<S, FL>> dopt =
make_shared<DelayedOperatorTensor<S, FL>>();
dopt->lopt = a, dopt->ropt = b;
dopt->dops = ops->data;
bool use_orig = !(rule->get_parallel_type() & ParallelTypes::NewScheme);
if (a->get_type() == OperatorTensorTypes::Delayed)
if (a->get_type() == OperatorTensorTypes::Delayed && substitute)
dopt->mat = substitute_delayed_exprs(
exprs, dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(a),
true, delayed, use_orig);
else if (b->get_type() == OperatorTensorTypes::Delayed)
else if (b->get_type() == OperatorTensorTypes::Delayed && substitute)
dopt->mat = substitute_delayed_exprs(
exprs, dynamic_pointer_cast<DelayedOperatorTensor<S, FL>>(b),
false, delayed, use_orig);
Expand All @@ -1180,6 +1181,64 @@ struct ParallelTensorFunctions : TensorFunctions<S, FL> {
rule->owner(dopt->dops[i]), dleft);
}
return dopt;
}
// c = a x b (dot) (delayed for 3-operator operations)
// Instead of contracting, builds a DelayedOperatorTensor that keeps a and b
// as its left/right parts together with the symbolic expressions, and
// stores it in c.
//   a      : left operand; when nullptr the call is forwarded to
//            left_contract and nothing is delayed
//   b      : dot operand
//   c      : on entry supplies ops, lmat and rmat for the result;
//            on exit is replaced by the delayed tensor
//   cexprs : expressions for the result; when nullptr, a->lmat * b->lmat
//            is used, otherwise a copy of cexprs
//   cnames : optional operator names; when given, each name (with its
//            factor removed by abs_value) is recorded in dopt->exprs with
//            its index and 1 / factor, and the corresponding ExprRef in
//            dopt->mat is replaced by the expression it refers to
void delayed_left_contract(
    const shared_ptr<OperatorTensor<S, FL>> &a,
    const shared_ptr<OperatorTensor<S, FL>> &b,
    shared_ptr<OperatorTensor<S, FL>> &c,
    const shared_ptr<Symbolic<S>> &cexprs = nullptr,
    const shared_ptr<Symbolic<S>> &cnames = nullptr) const override {
    if (a == nullptr)
        return left_contract(a, b, c, cexprs);
    shared_ptr<DelayedOperatorTensor<S, FL>> dopt =
        make_shared<DelayedOperatorTensor<S, FL>>();
    dopt->mat = cexprs == nullptr ? a->lmat * b->lmat : cexprs->copy();
    dopt->lopt = a, dopt->ropt = b;
    dopt->ops = c->ops;
    dopt->lmat = c->lmat, dopt->rmat = c->rmat;
    if (cnames != nullptr) {
        assert(rule->get_parallel_type() & ParallelTypes::NewScheme);
        for (size_t i = 0; i < cnames->data.size(); i++) {
            shared_ptr<OpElement<S, FL>> cop =
                dynamic_pointer_cast<OpElement<S, FL>>(cnames->data[i]);
            // dynamic_pointer_cast yields nullptr on a type mismatch;
            // catch that here rather than dereferencing it below
            assert(cop != nullptr);
            shared_ptr<OpExpr<S>> op = abs_value(cnames->data[i]);
            dopt->exprs[op] =
                make_pair(i, static_cast<FP>(1.0) / cop->factor);
            assert(dopt->mat->data[i]->get_type() == OpTypes::ExprRef);
            // unwrap the reference so mat holds the expression itself
            dopt->mat->data[i] =
                dynamic_pointer_cast<OpExprRef<S>>(dopt->mat->data[i])->op;
        }
    }
    c = dopt;
}
// c = b (dot) x a (delayed for 3-operator operations)
// Instead of contracting, builds a DelayedOperatorTensor that keeps b and a
// as its left/right parts together with the symbolic expressions, and
// stores it in c.
//   a      : right operand; when nullptr the call is forwarded to
//            right_contract and nothing is delayed
//   b      : dot operand
//   c      : on entry supplies ops, lmat and rmat for the result;
//            on exit is replaced by the delayed tensor
//   cexprs : expressions for the result; when nullptr, b->rmat * a->rmat
//            is used, otherwise a copy of cexprs
//   cnames : optional operator names; when given, each name (with its
//            factor removed by abs_value) is recorded in dopt->exprs with
//            its index and 1 / factor, and the corresponding ExprRef in
//            dopt->mat is replaced by the expression it refers to
void delayed_right_contract(
    const shared_ptr<OperatorTensor<S, FL>> &a,
    const shared_ptr<OperatorTensor<S, FL>> &b,
    shared_ptr<OperatorTensor<S, FL>> &c,
    const shared_ptr<Symbolic<S>> &cexprs = nullptr,
    const shared_ptr<Symbolic<S>> &cnames = nullptr) const override {
    if (a == nullptr)
        return right_contract(a, b, c, cexprs);
    shared_ptr<DelayedOperatorTensor<S, FL>> dopt =
        make_shared<DelayedOperatorTensor<S, FL>>();
    dopt->mat = cexprs == nullptr ? b->rmat * a->rmat : cexprs->copy();
    dopt->lopt = b, dopt->ropt = a;
    dopt->ops = c->ops;
    dopt->lmat = c->lmat, dopt->rmat = c->rmat;
    if (cnames != nullptr) {
        assert(rule->get_parallel_type() & ParallelTypes::NewScheme);
        for (size_t i = 0; i < cnames->data.size(); i++) {
            shared_ptr<OpElement<S, FL>> cop =
                dynamic_pointer_cast<OpElement<S, FL>>(cnames->data[i]);
            // dynamic_pointer_cast yields nullptr on a type mismatch;
            // catch that here rather than dereferencing it below
            assert(cop != nullptr);
            shared_ptr<OpExpr<S>> op = abs_value(cnames->data[i]);
            dopt->exprs[op] =
                make_pair(i, static_cast<FP>(1.0) / cop->factor);
            assert(dopt->mat->data[i]->get_type() == OpTypes::ExprRef);
            // unwrap the reference so mat holds the expression itself
            dopt->mat->data[i] =
                dynamic_pointer_cast<OpExprRef<S>>(dopt->mat->data[i])->op;
        }
    }
    c = dopt;
}
// c = a x b (dot)
void left_contract(const shared_ptr<OperatorTensor<S, FL>> &a,
Expand Down
Loading

0 comments on commit a86cc30

Please sign in to comment.