Skip to content
This repository has been archived by the owner on Feb 5, 2024. It is now read-only.

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD committed Aug 24, 2023
1 parent 877760a commit 6a8cce0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
5 changes: 2 additions & 3 deletions pennylane_lightning_gpu/src/algorithms/ObservablesGPUMPI.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,8 @@ class HamiltonianGPUMPI final : public ObservableGPUMPI<T> {
for (size_t term_idx = 0; term_idx < coeffs_.size(); term_idx++) {
DevTag<int> dt_local(sv.getDataBuffer().getDevTag());
dt_local.refresh();
StateVectorCudaMPI<PrecisionT> tmp(
dt_local, sv.getNumGlobalQubits(), sv.getNumLocalQubits(),
sv.getData());
StateVectorCudaMPI<T> tmp(dt_local, sv.getNumGlobalQubits(),
sv.getNumLocalQubits(), sv.getData());
obs_[term_idx]->applyInPlace(tmp);
scaleAndAddC_CUDA(std::complex<T>{coeffs_[term_idx], 0.0},
tmp.getData(), buffer.getData(), tmp.getLength(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,6 @@ class StateVectorCudaMPI
const index_type *csrOffsets_ptr, const index_type csrOffsets_size,
const index_type *columns_ptr,
const std::complex<Precision> *values_ptr, const index_type numNNZ) {

if (mpi_manager_.getRank() == 0) {
PL_ABORT_IF_NOT(static_cast<size_t>(csrOffsets_size - 1) ==
(size_t{1} << this->getTotalNumQubits()),
Expand Down
11 changes: 7 additions & 4 deletions pennylane_lightning_gpu/src/util/CSRMatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ template <class Precision, class index_type> class CSRMatrix {
* @brief Convert a global CSR (Compressed Sparse Row) format matrix into
* local blocks. This operation should be conducted on the rank 0.
*
* @tparam Precision Floating-point precision type.
* @tparam index_type Integer type used as indices of the sparse matrix.
* @param num_row_blocks Number of local blocks per global row.
* @param num_col_blocks Number of local blocks per global column.
* @param mpi_manager MPIManager object.
* @param num_rows Number of rows of the CSR matrix.
* @param csrOffsets_ptr Pointer to the array of row offsets of the sparse
* matrix. Array of size csrOffsets_size.
* @param columns_ptr Pointer to the array of column indices of the sparse
Expand Down Expand Up @@ -147,16 +148,18 @@ auto splitCSRMatrix(MPIManager &mpi_manager, const size_t &num_rows,
/**
* @brief Scatter a CSR (Compressed Sparse Row) format matrix.
*
* @tparam Precision Floating-point precision type.
* @tparam index_type Integer type used as indices of the sparse matrix.
* @param matrix CSR (Compressed Sparse Row) format matrix.
* @param mpi_manager MPIManager object.
* @param matrix CSR (Compressed Sparse Row) format matrix vector.
* @param local_num_rows Number of rows of local CSR matrix.
* @param root Root rank of the scatter operation.
*/
template <class Precision, class index_type>
auto scatterCSRMatrix(MPIManager &mpi_manager,
std::vector<CSRMatrix<Precision, index_type>> &matrix,
size_t local_num_rows, size_t root)
-> CSRMatrix<Precision, index_type> {
// Bcast num_rows and num_cols
size_t num_col_blocks = mpi_manager.getSize();

std::vector<size_t> nnzs;
Expand Down

0 comments on commit 6a8cce0

Please sign in to comment.