Skip to content

Commit

Permalink
removed inplace training on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Daria Tikhonovich committed Oct 10, 2023
1 parent 20544f5 commit cc10147
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions rectools/models/implicit_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def fit_als_with_features_together(
ui_csr = model.alpha * ui_csr

if isinstance(model, GPUAlternatingLeastSquares): # pragma: no cover
_fit_combined_factors_on_gpu_inplace(
user_factors, item_factors = _fit_combined_factors_on_gpu(
model,
ui_csr,
user_factors,
Expand Down Expand Up @@ -388,15 +388,15 @@ def _fit_combined_factors_on_cpu_inplace(
item_factors[:, n_factors - n_item_explicit_factors :] = item_explicit_factors


def _fit_combined_factors_on_gpu_inplace(
def _fit_combined_factors_on_gpu(
model: GPUAlternatingLeastSquares,
ui_csr: sparse.csr_matrix,
user_factors: np.ndarray,
item_factors: np.ndarray,
n_user_explicit_factors: int,
n_item_explicit_factors: int,
verbose: int,
) -> None:
) -> tp.Tuple[implicit.gpu.Matrix, implicit.gpu.Matrix]:
n_factors = user_factors.shape[1]
user_explicit_factors = user_factors[:, :n_user_explicit_factors].copy()
item_explicit_factors = item_factors[:, n_factors - n_item_explicit_factors :].copy()
Expand All @@ -407,9 +407,6 @@ def _fit_combined_factors_on_gpu_inplace(
X = implicit.gpu.Matrix(user_factors)
Y = implicit.gpu.Matrix(item_factors)

user_factors = X
item_factors = Y

# invalidate cached norms and squared factors
model._item_norms = model._user_norms = None # pylint: disable=protected-access
model._item_norms_host = model._user_norms_host = None # pylint: disable=protected-access
Expand All @@ -433,3 +430,5 @@ def _fit_combined_factors_on_gpu_inplace(
item_factors_np = Y.to_numpy()
item_factors_np[:, n_factors - n_item_explicit_factors :] = item_explicit_factors
Y = implicit.gpu.Matrix(item_factors_np)

return X, Y

0 comments on commit cc10147

Please sign in to comment.