Skip to content

Commit

Permalink
align workspace space to multiple of 256 bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
lijinf2 committed Mar 13, 2024
1 parent 96b3f7d commit be042c7
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
21 changes: 15 additions & 6 deletions cpp/src/glm/qn/mg/standardization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,26 @@ struct Standardizer {
Standardizer(const raft::handle_t& handle,
const SimpleSparseMat<T>& X,
int n_samples,
rmm::device_uvector<T>& mean_std_buff)
rmm::device_uvector<T>& mean_std_buff,
size_t vec_size)
{
int D = X.n;
ASSERT(mean_std_buff.size() == 4 * D, "buff size must be four times the dimension");
ASSERT(mean_std_buff.size() == 4 * vec_size, "buff size must be four times the aligned size");

auto stream = handle.get_stream();

mean.reset(mean_std_buff.data(), D);
std.reset(mean_std_buff.data() + D, D);
std_inv.reset(mean_std_buff.data() + 2 * D, D);
scaled_mean.reset(mean_std_buff.data() + 3 * D, D);
T* p_ws = mean_std_buff.data();

mean.reset(p_ws, D);
p_ws += vec_size;

std.reset(p_ws, D);
p_ws += vec_size;

std_inv.reset(p_ws, D);
p_ws += vec_size;

scaled_mean.reset(p_ws, D);

mean_stddev(handle, X, n_samples, mean.data, std.data);
raft::linalg::unaryOp(std_inv.data, std.data, D, inverse_op(), stream);
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/glm/qn_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,12 @@ void qnFitSparse_impl(const raft::handle_t& handle,
{
auto X_simple = SimpleSparseMat<T>(X_values, X_cols, X_row_ids, X_nnz, N, D);

rmm::device_uvector<T> mean_std_buff(4 * D, handle.get_stream());
size_t vec_size = raft::alignTo<size_t>(sizeof(T) * D, ML::GLM::detail::qn_align);
rmm::device_uvector<T> mean_std_buff(4 * vec_size, handle.get_stream());
Standardizer<T>* stder = NULL;

if (standardization) stder = new Standardizer(handle, X_simple, n_samples, mean_std_buff);
if (standardization)
stder = new Standardizer(handle, X_simple, n_samples, mean_std_buff, vec_size);

ML::GLM::opg::qn_fit_x_mg(handle,
pams,
Expand Down
15 changes: 15 additions & 0 deletions python/cuml/tests/dask/test_dask_aaa_bug.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from cuml.internals.safe_imports import gpu_only_import
import pytest
from cuml.dask.common import utils as dask_utils
Expand Down

0 comments on commit be042c7

Please sign in to comment.