Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kernel PCA: C++ Algorithm Implementation #5987

Open
wants to merge 3 commits into
base: branch-24.10
Choose a base branch
from

Conversation

tomasjoh
Copy link

@tomasjoh tomasjoh commented Jul 27, 2024

Description

Worked with @garrisonhess to add C++ code for KernelPCA. This implementation of Kernel PCA supports three functions with float/double matrix input:

  • kpcaFit()
  • kpcaTransform()
    • Called after fit(X). Is used to support transform(Y). X and Y might be different, so we have to compute the centered gram matrix between training data X and test data Y. See sklearn reference.
  • kpcaTransformWithFitData()
    • Used after kpcaFit() in fit_transform(X). We don't need to calculate the kernel matrix for X, since it's the same input data used in kpcaFit(). See sklearn reference.

Feature request: #1317

Tests were performed on an EC2 g4dn.xlarge instance with CUDA 12.2.

Click here to see environment details
 **git***
 commit ade61faff6ac261028bc9b8bbca8b7e67be00d16 (HEAD -> fea-kernel-pca, fork/fea-kernel-pca)
 Author: Tomas Johannesson <[email protected]>
 Date:   Tue Jul 23 22:01:19 2024 -0500

 syntax fix
 **git submodules***

 ***OS Information***
 DISTRIB_ID=Ubuntu
 DISTRIB_RELEASE=22.04
 DISTRIB_CODENAME=jammy
 DISTRIB_DESCRIPTION="Ubuntu 22.04.4 LTS"
 PRETTY_NAME="Ubuntu 22.04.4 LTS"
 NAME="Ubuntu"
 VERSION_ID="22.04"
 VERSION="22.04.4 LTS (Jammy Jellyfish)"
 VERSION_CODENAME=jammy
 ID=ubuntu
 ID_LIKE=debian
 HOME_URL="https://www.ubuntu.com/"
 SUPPORT_URL="https://help.ubuntu.com/"
 BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
 PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
 UBUNTU_CODENAME=jammy
 Linux ip-172-31-36-86 6.5.0-1020-aws #20~22.04.1-Ubuntu SMP Wed May  1 16:10:50 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

 ***GPU Information***
 Thu Jul 25 01:28:58 2024
 +---------------------------------------------------------------------------------------+
 | NVIDIA-SMI 535.183.01             Driver Version: 535.183.01   CUDA Version: 12.2     |
 |-----------------------------------------+----------------------+----------------------+
 | GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
 | Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
 |                                         |                      |               MIG M. |
 |=========================================+======================+======================|
 |   0  Tesla T4                       On  | 00000000:00:1E.0 Off |                    0 |
 | N/A   32C    P8              14W /  70W |      2MiB / 15360MiB |      0%      Default |
 |                                         |                      |                  N/A |
 +-----------------------------------------+----------------------+----------------------+

 +---------------------------------------------------------------------------------------+
 | Processes:                                                                            |
 |  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
 |        ID   ID                                                             Usage      |
 |=======================================================================================|
 |  No running processes found                                                           |
 +---------------------------------------------------------------------------------------+

 ***CPU***
 Architecture:                       x86_64
 CPU op-mode(s):                     32-bit, 64-bit
 Address sizes:                      46 bits physical, 48 bits virtual
 Byte Order:                         Little Endian
 CPU(s):                             4
 On-line CPU(s) list:                0-3
 Vendor ID:                          GenuineIntel
 Model name:                         Intel(R) Xeon(R) Platinum 8259CL CPU @ 2.50GHz
 CPU family:                         6
 Model:                              85
 Thread(s) per core:                 2
 Core(s) per socket:                 2
 Socket(s):                          1
 Stepping:                           7
 BogoMIPS:                           4999.99
 Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke avx512_vnni
 Hypervisor vendor:                  KVM
 Virtualization type:                full
 L1d cache:                          64 KiB (2 instances)
 L1i cache:                          64 KiB (2 instances)
 L2 cache:                           2 MiB (2 instances)
 L3 cache:                           35.8 MiB (1 instance)
 NUMA node(s):                       1
 NUMA node0 CPU(s):                  0-3
 Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
 Vulnerability Itlb multihit:        KVM: Mitigation: VMX unsupported
 Vulnerability L1tf:                 Mitigation; PTE Inversion
 Vulnerability Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
 Vulnerability Meltdown:             Mitigation; PTI
 Vulnerability Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
 Vulnerability Retbleed:             Vulnerable
 Vulnerability Spec rstack overflow: Not affected
 Vulnerability Spec store bypass:    Vulnerable
 Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
 Vulnerability Spectre v2:           Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
 Vulnerability Srbds:                Not affected
 Vulnerability Tsx async abort:      Not affected

 ***CMake***
 /home/ubuntu/miniconda3/envs/cuml_dev/bin/cmake
 cmake version 3.29.6

 CMake suite maintained and supported by Kitware (kitware.com/cmake).

 ***g++***
 /home/ubuntu/miniconda3/envs/cuml_dev/bin/g++
 g++ (conda-forge gcc 11.4.0-12) 11.4.0
 Copyright (C) 2021 Free Software Foundation, Inc.
 This is free software; see the source for copying conditions.  There is NO
 warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.


 ***nvcc***
 /home/ubuntu/miniconda3/envs/cuml_dev/bin/nvcc
 nvcc: NVIDIA (R) Cuda compiler driver
 Copyright (c) 2005-2023 NVIDIA Corporation
 Built on Tue_Aug_15_22:02:13_PDT_2023
 Cuda compilation tools, release 12.2, V12.2.140
 Build cuda_12.2.r12.2/compiler.33191640_0

 ***Python***
 /home/ubuntu/miniconda3/envs/cuml_dev/bin/python
 Python 3.11.9

 ***Environment Variables***
 PATH                            : /home/ubuntu/miniconda3/envs/cuml_dev/bin:/home/ubuntu/miniconda3/condabin:/opt/amazon/openmpi/bin:/opt/amazon/efa/bin:/usr/local/cuda-12.1/bin:/usr/local/cuda-12.1/include:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin
 LD_LIBRARY_PATH                 : /opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/opt/aws-ofi-nccl/lib:/usr/local/cuda-12.1/lib:/usr/local/cuda-12.1/lib64:/usr/local/cuda-12.1:/usr/local/cuda-12.1/targets/x86_64-linux/lib/:/usr/local/cuda-12.1/extras/CUPTI/lib64:/usr/local/lib:/usr/lib
 NUMBAPRO_NVVM                   :
 NUMBAPRO_LIBDEVICE              :
 CONDA_PREFIX                    : /home/ubuntu/miniconda3/envs/cuml_dev
 PYTHON_PATH                     :

 ***conda packages***
 /home/ubuntu/miniconda3/condabin/conda
 # packages in environment at /home/ubuntu/miniconda3/envs/cuml_dev:
 #
 # Name                    Version                   Build  Channel
 _libgcc_mutex             0.1                 conda_forge    conda-forge
 _openmp_mutex             4.5                       2_gnu    conda-forge
 _sysroot_linux-64_curr_repodata_hack 3                   h69a702a_14    conda-forge
 accessible-pygments       0.0.5              pyhd8ed1ab_0    conda-forge
 alabaster                 0.7.16             pyhd8ed1ab_0    conda-forge
 anyio                     4.4.0              pyhd8ed1ab_0    conda-forge
 argon2-cffi               23.1.0             pyhd8ed1ab_0    conda-forge
 argon2-cffi-bindings      21.2.0          py311h459d7ec_4    conda-forge
 arrow                     1.3.0              pyhd8ed1ab_0    conda-forge
 asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
 async-lru                 2.0.4              pyhd8ed1ab_0    conda-forge
 atk-1.0                   2.38.0               h04ea711_2    conda-forge
 attrs                     23.2.0             pyh71513ae_0    conda-forge
 aws-c-auth                0.7.22               h9137712_5    conda-forge
 aws-c-cal                 0.6.15               h88a6e22_0    conda-forge
 aws-c-common              0.9.19               h4ab18f5_0    conda-forge
 aws-c-compression         0.2.18               h83b837d_6    conda-forge
 aws-c-event-stream        0.4.2               h0cbf018_13    conda-forge
 aws-c-http                0.8.2                h360477d_2    conda-forge
 aws-c-io                  0.14.9               h2d549f9_2    conda-forge
 aws-c-mqtt                0.10.4               hf85b563_6    conda-forge
 aws-c-s3                  0.5.10               h679ed35_3    conda-forge
 aws-c-sdkutils            0.1.16               h83b837d_2    conda-forge
 aws-checksums             0.1.18               h83b837d_6    conda-forge
 aws-crt-cpp               0.26.12              h8bc9c4d_0    conda-forge
 aws-sdk-cpp               1.11.329             hf74b5d1_5    conda-forge
 azure-core-cpp            1.12.0               h830ed8b_0    conda-forge
 azure-identity-cpp        1.8.0                hdb0d106_1    conda-forge
 azure-storage-blobs-cpp   12.11.0              ha67cba7_1    conda-forge
 azure-storage-common-cpp  12.6.0               he3f277c_1    conda-forge
 azure-storage-files-datalake-cpp 12.10.0              h29b5301_1    conda-forge
 babel                     2.14.0             pyhd8ed1ab_0    conda-forge
 backports.zoneinfo        0.2.1           py311h38be061_8    conda-forge
 beautifulsoup4            4.12.3             pyha770c72_0    conda-forge
 binutils                  2.40                 h4852527_7    conda-forge
 binutils_impl_linux-64    2.40                 ha1999f0_7    conda-forge
 binutils_linux-64         2.40                 hb3c18ed_4    conda-forge
 bleach                    6.1.0              pyhd8ed1ab_0    conda-forge
 bokeh                     3.4.1              pyhd8ed1ab_0    conda-forge
 brotli                    1.1.0                hd590300_1    conda-forge
 brotli-bin                1.1.0                hd590300_1    conda-forge
 brotli-python             1.1.0           py311hb755f60_1    conda-forge
 bzip2                     1.0.8                h5eee18b_6
 c-ares                    1.28.1               hd590300_0    conda-forge
 c-compiler                1.5.2                h0b41bf4_0    conda-forge
 ca-certificates           2024.6.2             hbcca054_0    conda-forge
 cached-property           1.5.2                hd8ed1ab_1    conda-forge
 cached_property           1.5.2              pyha770c72_1    conda-forge
 cachetools                5.3.3              pyhd8ed1ab_0    conda-forge
 cairo                     1.18.0               hbb29018_2    conda-forge
 certifi                   2024.6.2           pyhd8ed1ab_0    conda-forge
 cffi                      1.16.0          py311hb3a22ac_0    conda-forge
 charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
 click                     8.1.7           unix_pyh707e725_0    conda-forge
 cloudpickle               3.0.0              pyhd8ed1ab_0    conda-forge
 cmake                     3.29.6               hcafd917_0    conda-forge
 colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
 comm                      0.2.2              pyhd8ed1ab_0    conda-forge
 commonmark                0.9.1                      py_0    conda-forge
 contourpy                 1.2.1           py311h9547e67_0    conda-forge
 coverage                  7.5.4           py311h331c9d8_0    conda-forge
 cuda-cccl_linux-64        12.2.140             ha770c72_0    conda-forge
 cuda-crt-dev_linux-64     12.2.140             ha770c72_1    conda-forge
 cuda-crt-tools            12.2.140             ha770c72_1    conda-forge
 cuda-cudart               12.2.140             hd3aeb46_0    conda-forge
 cuda-cudart-dev           12.2.140             hd3aeb46_0    conda-forge
 cuda-cudart-dev_linux-64  12.2.140             h59595ed_0    conda-forge
 cuda-cudart-static        12.2.140             hd3aeb46_0    conda-forge
 cuda-cudart-static_linux-64 12.2.140             h59595ed_0    conda-forge
 cuda-cudart_linux-64      12.2.140             h59595ed_0    conda-forge
 cuda-driver-dev_linux-64  12.2.140             h59595ed_0    conda-forge
 cuda-nvcc                 12.2.140             hcdd1206_0    conda-forge
 cuda-nvcc-dev_linux-64    12.2.140             ha770c72_1    conda-forge
 cuda-nvcc-impl            12.2.140             hd3aeb46_1    conda-forge
 cuda-nvcc-tools           12.2.140             hd3aeb46_1    conda-forge
 cuda-nvcc_linux-64        12.2.140             h8a487aa_0    conda-forge
 cuda-nvrtc                12.2.140             hd3aeb46_0    conda-forge
 cuda-nvvm-dev_linux-64    12.2.140             ha770c72_1    conda-forge
 cuda-nvvm-impl            12.2.140             h59595ed_1    conda-forge
 cuda-nvvm-tools           12.2.140             h59595ed_1    conda-forge
 cuda-profiler-api         12.2.140             ha770c72_0    conda-forge
 cuda-python               12.5.0          py311h817de4b_0    conda-forge
 cuda-version              12.2                 he2b69de_3    conda-forge
 cudf                      24.08.00a189    cuda12_py311_240623_gf536e30172_189    rapidsai-nightly
 cuml                      24.8.0                   pypi_0    pypi
 cupy                      13.2.0          py311he5a987b_0    conda-forge
 cupy-core                 13.2.0          py311h3bdf873_0    conda-forge
 cxx-compiler              1.5.2                hf52228f_0    conda-forge
 cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
 cython                    3.0.10          py311hb755f60_0    conda-forge
 cytoolz                   0.12.3          py311h459d7ec_0    conda-forge
 dask                      2024.5.1           pyhd8ed1ab_0    conda-forge
 dask-core                 2024.5.1           pyhd8ed1ab_0    conda-forge
 dask-cuda                 24.08.00a6      py311_240623_g098109a_6    rapidsai-nightly
 dask-cudf                 24.08.00a189    cuda12_py311_240623_gf536e30172_189    rapidsai-nightly
 dask-expr                 1.1.1              pyhd8ed1ab_1    conda-forge
 dask-glm                  0.3.0                    pypi_0    pypi
 dask-ml                   2024.4.4           pyhd8ed1ab_0    conda-forge
 debugpy                   1.8.1           py311hb755f60_0    conda-forge
 decopatch                 1.4.10             pyhd8ed1ab_0    conda-forge
 decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
 defusedxml                0.7.1              pyhd8ed1ab_0    conda-forge
 distributed               2024.5.1           pyhd8ed1ab_0    conda-forge
 distributed-ucxx          0.39.00a        py3.11_240623_g1e6d80c_3    rapidsai-nightly
 dlpack                    0.8                  h59595ed_3    conda-forge
 docutils                  0.19            py311h38be061_1    conda-forge
 doxygen                   1.9.1                hb166930_1    conda-forge
 entrypoints               0.4                pyhd8ed1ab_0    conda-forge
 exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
 execnet                   2.1.1              pyhd8ed1ab_0    conda-forge
 executing                 2.0.1              pyhd8ed1ab_0    conda-forge
 expat                     2.6.2                h59595ed_0    conda-forge
 fastrlock                 0.8.2           py311hb755f60_2    conda-forge
 fmt                       10.2.1               h00ab1b0_0    conda-forge
 font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
 font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
 font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
 font-ttf-ubuntu           0.83                 h77eed37_2    conda-forge
 fontconfig                2.14.2               h14ed4e7_0    conda-forge
 fonts-conda-ecosystem     1                             0    conda-forge
 fonts-conda-forge         1                             0    conda-forge
 fonttools                 4.53.0          py311h331c9d8_0    conda-forge
 fqdn                      1.5.1              pyhd8ed1ab_0    conda-forge
 freetype                  2.12.1               h267a509_2    conda-forge
 fribidi                   1.0.10               h36c2ea0_0    conda-forge
 fsspec                    2024.6.0           pyhff2d567_0    conda-forge
 future                    1.0.0              pyhd8ed1ab_0    conda-forge
 gcc                       11.4.0              h602e360_12    conda-forge
 gcc_impl_linux-64         11.4.0              h00c12a0_12    conda-forge
 gcc_linux-64              11.4.0               ha077dfb_4    conda-forge
 gdk-pixbuf                2.42.12              hb9ae30d_0    conda-forge
 gflags                    2.2.2             he1b5a44_1004    conda-forge
 giflib                    5.2.2                hd590300_0    conda-forge
 glog                      0.7.1                hbabe93e_0    conda-forge
 graphite2                 1.3.13            h59595ed_1003    conda-forge
 graphviz                  11.0.0               hc68bbd7_0    conda-forge
 gtk2                      2.24.33              h280cfa0_4    conda-forge
 gts                       0.7.6                h977cf35_4    conda-forge
 gxx                       11.4.0              h602e360_12    conda-forge
 gxx_impl_linux-64         11.4.0              h634f3ee_12    conda-forge
 gxx_linux-64              11.4.0               h35bfe5d_4    conda-forge
 h11                       0.14.0             pyhd8ed1ab_0    conda-forge
 h2                        4.1.0              pyhd8ed1ab_0    conda-forge
 harfbuzz                  8.5.0                hfac3d4d_0    conda-forge
 hdbscan                   0.8.30          py311h1f0f07a_0    conda-forge
 hpack                     4.0.0              pyh9f0ad1d_0    conda-forge
 httpcore                  1.0.5              pyhd8ed1ab_0    conda-forge
 httpx                     0.27.0             pyhd8ed1ab_0    conda-forge
 hyperframe                6.0.1              pyhd8ed1ab_0    conda-forge
 hypothesis                6.103.2            pyha770c72_0    conda-forge
 icu                       73.2                 h59595ed_0    conda-forge
 idna                      3.7                pyhd8ed1ab_0    conda-forge
 imagesize                 1.4.1              pyhd8ed1ab_0    conda-forge
 importlib-metadata        7.2.0              pyha770c72_0    conda-forge
 importlib-resources       6.4.0              pyhd8ed1ab_0    conda-forge
 importlib_metadata        7.2.0                hd8ed1ab_0    conda-forge
 importlib_resources       6.4.0              pyhd8ed1ab_0    conda-forge
 iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
 ipykernel                 6.29.4             pyh3099207_0    conda-forge
 ipython                   8.25.0             pyh707e725_0    conda-forge
 isoduration               20.11.0            pyhd8ed1ab_0    conda-forge
 jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
 jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
 joblib                    1.4.2              pyhd8ed1ab_0    conda-forge
 json5                     0.9.25             pyhd8ed1ab_0    conda-forge
 jsonpointer               3.0.0           py311h38be061_0    conda-forge
 jsonschema                4.22.0             pyhd8ed1ab_0    conda-forge
 jsonschema-specifications 2023.12.1          pyhd8ed1ab_0    conda-forge
 jsonschema-with-format-nongpl 4.22.0             pyhd8ed1ab_0    conda-forge
 jupyter-lsp               2.2.5              pyhd8ed1ab_0    conda-forge
 jupyter_client            8.6.2              pyhd8ed1ab_0    conda-forge
 jupyter_core              5.7.2           py311h38be061_0    conda-forge
 jupyter_events            0.10.0             pyhd8ed1ab_0    conda-forge
 jupyter_server            2.14.1             pyhd8ed1ab_0    conda-forge
 jupyter_server_terminals  0.5.3              pyhd8ed1ab_0    conda-forge
 jupyterlab                4.2.3              pyhd8ed1ab_0    conda-forge
 jupyterlab_pygments       0.3.0              pyhd8ed1ab_1    conda-forge
 jupyterlab_server         2.27.2             pyhd8ed1ab_0    conda-forge
 kernel-headers_linux-64   3.10.0              h4a8ded7_14    conda-forge
 keyutils                  1.6.1                h166bdaf_0    conda-forge
 kiwisolver                1.4.5           py311h9547e67_1    conda-forge
 krb5                      1.21.2               h659d440_0    conda-forge
 lcms2                     2.16                 hb7c19ff_0    conda-forge
 ld_impl_linux-64          2.40                 hf3520f5_7    conda-forge
 lerc                      4.0.0                h27087fc_0    conda-forge
 libabseil                 20240116.2      cxx17_h59595ed_0    conda-forge
 libarrow                  16.1.0          h4a673ee_10_cpu    conda-forge
 libarrow-acero            16.1.0          hac33072_10_cpu    conda-forge
 libarrow-dataset          16.1.0          hac33072_10_cpu    conda-forge
 libarrow-substrait        16.1.0          h7e0c224_10_cpu    conda-forge
 libblas                   3.9.0           22_linux64_openblas    conda-forge
 libbrotlicommon           1.1.0                hd590300_1    conda-forge
 libbrotlidec              1.1.0                hd590300_1    conda-forge
 libbrotlienc              1.1.0                hd590300_1    conda-forge
 libcblas                  3.9.0           22_linux64_openblas    conda-forge
 libcrc32c                 1.1.2                h9c3ff4c_0    conda-forge
 libcublas                 12.2.5.6             hd3aeb46_0    conda-forge
 libcublas-dev             12.2.5.6             hd3aeb46_0    conda-forge
 libcudf                   24.08.00a189    cuda12_240623_gf536e30172_189    rapidsai-nightly
 libcufft                  11.0.8.103           hd3aeb46_0    conda-forge
 libcufft-dev              11.0.8.103           hd3aeb46_0    conda-forge
 libcufile                 1.7.2.10             hd3aeb46_0    conda-forge
 libcufile-dev             1.7.2.10             hd3aeb46_0    conda-forge
 libcumlprims              24.08.00a       cuda12_240623_g10e088a_1    rapidsai-nightly
 libcurand                 10.3.3.141           hd3aeb46_0    conda-forge
 libcurand-dev             10.3.3.141           hd3aeb46_0    conda-forge
 libcurl                   8.8.0                hca28451_0    conda-forge
 libcusolver               11.5.2.141           hd3aeb46_0    conda-forge
 libcusolver-dev           11.5.2.141           hd3aeb46_0    conda-forge
 libcusparse               12.1.2.141           hd3aeb46_0    conda-forge
 libcusparse-dev           12.1.2.141           hd3aeb46_0    conda-forge
 libdeflate                1.20                 hd590300_0    conda-forge
 libedit                   3.1.20191231         he28a2e2_2    conda-forge
 libev                     4.33                 hd590300_2    conda-forge
 libevent                  2.1.12               hf998b51_1    conda-forge
 libexpat                  2.6.2                h59595ed_0    conda-forge
 libffi                    3.4.4                h6a678d5_1
 libgcc-devel_linux-64     11.4.0             h8f596e0_112    conda-forge
 libgcc-ng                 13.2.0              h77fa898_12    conda-forge
 libgd                     2.3.3                h119a65a_9    conda-forge
 libgfortran-ng            13.2.0              h69a702a_12    conda-forge
 libgfortran5              13.2.0              h3d2ce59_12    conda-forge
 libglib                   2.80.2               h8a4344b_1    conda-forge
 libgomp                   13.2.0              h77fa898_12    conda-forge
 libgoogle-cloud           2.25.0               h2736e30_0    conda-forge
 libgoogle-cloud-storage   2.25.0               h3d9a0c8_0    conda-forge
 libgrpc                   1.62.2               h15f2491_0    conda-forge
 libhwloc                  2.10.0          default_h5622ce7_1001    conda-forge
 libiconv                  1.17                 hd590300_2    conda-forge
 libjpeg-turbo             3.0.0                hd590300_1    conda-forge
 libkvikio                 24.08.00a       cuda12_240623_g3cc6678_10    rapidsai-nightly
 liblapack                 3.9.0           22_linux64_openblas    conda-forge
 libllvm14                 14.0.6               hcd5def8_4    conda-forge
 libnghttp2                1.58.0               h47da74e_1    conda-forge
 libnl                     3.9.0                hd590300_0    conda-forge
 libnsl                    2.0.1                hd590300_0    conda-forge
 libnvjitlink              12.2.140             hd3aeb46_0    conda-forge
 libopenblas               0.3.27          pthreads_h413a1c8_0    conda-forge
 libparquet                16.1.0          h6a7eafb_10_cpu    conda-forge
 libpng                    1.6.43               h2797004_0    conda-forge
 libprotobuf               4.25.3               h08a7969_0    conda-forge
 libraft                   24.08.00a33     cuda12_240623_gb86a5f90_33    rapidsai-nightly
 libraft-headers           24.08.00a33     cuda12_240623_gb86a5f90_33    rapidsai-nightly
 libraft-headers-only      24.08.00a33     cuda12_240623_gb86a5f90_33    rapidsai-nightly
 libre2-11                 2023.09.01           h5a48ba9_2    conda-forge
 librmm                    24.08.00a17     cuda12_240623_gf2d07976_17    rapidsai-nightly
 librsvg                   2.58.1               hadf69e7_0    conda-forge
 libsanitizer              11.4.0              h5763a12_12    conda-forge
 libsodium                 1.0.18               h36c2ea0_1    conda-forge
 libsqlite                 3.46.0               hde9e2c9_0    conda-forge
 libssh2                   1.11.0               h0841786_0    conda-forge
 libstdcxx-devel_linux-64  11.4.0             h8f596e0_112    conda-forge
 libstdcxx-ng              13.2.0              hc0a3c3a_12    conda-forge
 libthrift                 0.19.0               hb90f79a_1    conda-forge
 libtiff                   4.6.0                h1dd3fc0_3    conda-forge
 libucxx                   0.39.00a        cuda12_240623_g1e6d80c_3    rapidsai-nightly
 libutf8proc               2.8.0                h166bdaf_0    conda-forge
 libuuid                   2.38.1               h0b41bf4_0    conda-forge
 libuv                     1.48.0               hd590300_0    conda-forge
 libwebp                   1.4.0                h2c329e2_0    conda-forge
 libwebp-base              1.4.0                hd590300_0    conda-forge
 libxcb                    1.16                 hd590300_0    conda-forge
 libxcrypt                 4.4.36               hd590300_1    conda-forge
 libxml2                   2.12.7               hc051c1a_1    conda-forge
 libzlib                   1.3.1                h4ab18f5_1    conda-forge
 llvmlite                  0.43.0          py311hbde99c3_0    conda-forge
 locket                    1.0.0              pyhd8ed1ab_0    conda-forge
 lz4                       4.3.3           py311h38e4bf4_0    conda-forge
 lz4-c                     1.9.4                hcb278e6_0    conda-forge
 makefun                   1.15.2             pyhd8ed1ab_0    conda-forge
 markdown                  3.6                pyhd8ed1ab_0    conda-forge
 markdown-it-py            3.0.0              pyhd8ed1ab_0    conda-forge
 markupsafe                2.1.5           py311h459d7ec_0    conda-forge
 matplotlib-base           3.8.4           py311ha4ca890_2    conda-forge
 matplotlib-inline         0.1.7              pyhd8ed1ab_0    conda-forge
 mdurl                     0.1.2              pyhd8ed1ab_0    conda-forge
 mistune                   3.0.2              pyhd8ed1ab_0    conda-forge
 msgpack-python            1.0.8           py311h52f7536_0    conda-forge
 multipledispatch          0.6.0                      py_0    conda-forge
 munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
 nbclient                  0.10.0             pyhd8ed1ab_0    conda-forge
 nbconvert                 7.16.4               hd8ed1ab_1    conda-forge
 nbconvert-core            7.16.4             pyhd8ed1ab_1    conda-forge
 nbconvert-pandoc          7.16.4               hd8ed1ab_1    conda-forge
 nbformat                  5.10.4             pyhd8ed1ab_0    conda-forge
 nbsphinx                  0.9.4              pyhd8ed1ab_0    conda-forge
 nccl                      2.22.3.1             hbc370b7_0    conda-forge
 ncurses                   6.5                  h59595ed_0    conda-forge
 nest-asyncio              1.6.0              pyhd8ed1ab_0    conda-forge
 ninja                     1.12.1               h297d8ca_0    conda-forge
 nltk                      3.8.1              pyhd8ed1ab_0    conda-forge
 notebook                  7.2.1              pyhd8ed1ab_0    conda-forge
 notebook-shim             0.2.4              pyhd8ed1ab_0    conda-forge
 numba                     0.60.0          py311h4bc866e_0    conda-forge
 numpy                     1.26.4          py311h64a7726_0    conda-forge
 numpydoc                  1.6.0              pyhd8ed1ab_0    conda-forge
 nvcomp                    3.0.6                h10b603f_0    conda-forge
 nvtx                      0.2.10          py311h459d7ec_0    conda-forge
 openjpeg                  2.5.2                h488ebb8_0    conda-forge
 openssl                   3.3.1                h4ab18f5_1    conda-forge
 orc                       2.0.1                h17fec99_1    conda-forge
 overrides                 7.7.0              pyhd8ed1ab_0    conda-forge
 packaging                 24.1               pyhd8ed1ab_0    conda-forge
 pandas                    2.2.2           py311h14de704_1    conda-forge
 pandoc                    3.2                  ha770c72_0    conda-forge
 pandocfilters             1.5.0              pyhd8ed1ab_0    conda-forge
 pango                     1.54.0               h84a9a3c_0    conda-forge
 parso                     0.8.4              pyhd8ed1ab_0    conda-forge
 partd                     1.4.2              pyhd8ed1ab_0    conda-forge
 pathspec                  0.12.1             pyhd8ed1ab_0    conda-forge
 patsy                     0.5.6              pyhd8ed1ab_0    conda-forge
 pcre2                     10.44                h0f59acf_0    conda-forge
 pexpect                   4.9.0              pyhd8ed1ab_0    conda-forge
 pickleshare               0.7.5                   py_1003    conda-forge
 pillow                    10.3.0          py311h82a398c_1    conda-forge
 pip                       24.0               pyhd8ed1ab_0    conda-forge
 pixman                    0.43.2               h59595ed_0    conda-forge
 pkgutil-resolve-name      1.3.10             pyhd8ed1ab_1    conda-forge
 platformdirs              4.2.2              pyhd8ed1ab_0    conda-forge
 pluggy                    1.5.0              pyhd8ed1ab_0    conda-forge
 prometheus_client         0.20.0             pyhd8ed1ab_0    conda-forge
 prompt-toolkit            3.0.47             pyha770c72_0    conda-forge
 psutil                    5.9.8           py311h459d7ec_0    conda-forge
 pthread-stubs             0.4               h36c2ea0_1001    conda-forge
 ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
 pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
 py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
 pyarrow                   16.1.0          py311hbd00459_3    conda-forge
 pyarrow-core              16.1.0          py311h8c3dac4_3_cpu    conda-forge
 pyarrow-hotfix            0.6                pyhd8ed1ab_0    conda-forge
 pycparser                 2.22               pyhd8ed1ab_0    conda-forge
 pydata-sphinx-theme       0.15.3             pyhd8ed1ab_0    conda-forge
 pygments                  2.18.0             pyhd8ed1ab_0    conda-forge
 pylibraft                 24.08.00a33     cuda12_py311_240623_gb86a5f90_33    rapidsai-nightly
 pynndescent               0.5.8              pyh1a96a4e_0    conda-forge
 pynvjitlink               0.2.4           py311hd269673_0    rapidsai
 pynvml                    11.4.1             pyhd8ed1ab_0    conda-forge
 pyparsing                 3.1.2              pyhd8ed1ab_0    conda-forge
 pysocks                   1.7.1              pyha2e5f31_6    conda-forge
 pytest                    7.4.4              pyhd8ed1ab_0    conda-forge
 pytest-benchmark          4.0.0              pyhd8ed1ab_0    conda-forge
 pytest-cases              3.8.5              pyhd8ed1ab_0    conda-forge
 pytest-cov                5.0.0              pyhd8ed1ab_0    conda-forge
 pytest-xdist              3.6.1              pyhd8ed1ab_0    conda-forge
 python                    3.11.9          hb806964_0_cpython    conda-forge
 python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
 python-fastjsonschema     2.20.0             pyhd8ed1ab_0    conda-forge
 python-json-logger        2.0.7              pyhd8ed1ab_0    conda-forge
 python-tzdata             2024.1             pyhd8ed1ab_0    conda-forge
 python_abi                3.11                    4_cp311    conda-forge
 pytz                      2024.1             pyhd8ed1ab_0    conda-forge
 pyyaml                    6.0.1           py311h459d7ec_1    conda-forge
 pyzmq                     26.0.3          py311h08a0b41_0    conda-forge
 raft-dask                 24.08.00a33     cuda12_py311_240623_gb86a5f90_33    rapidsai-nightly
 rapids-build-backend      0.3.1                      py_0    rapidsai-nightly
 rapids-dask-dependency    24.08.00a4                 py_0    rapidsai-nightly
 rapids-dependency-file-generator 1.13.11                    py_0    rapidsai
 rdma-core                 51.1                 he02047a_0    conda-forge
 re2                       2023.09.01           h7f4b329_2    conda-forge
 readline                  8.2                  h5eee18b_0
 recommonmark              0.7.1              pyhd8ed1ab_0    conda-forge
 referencing               0.35.1             pyhd8ed1ab_0    conda-forge
 regex                     2024.5.15       py311h331c9d8_0    conda-forge
 requests                  2.32.3             pyhd8ed1ab_0    conda-forge
 rfc3339-validator         0.1.4              pyhd8ed1ab_0    conda-forge
 rfc3986-validator         0.1.1              pyh9f0ad1d_0    conda-forge
 rhash                     1.4.4                hd590300_0    conda-forge
 rich                      13.7.1             pyhd8ed1ab_0    conda-forge
 rmm                       24.08.00a17     cuda12_py311_240623_gf2d07976_17    rapidsai-nightly
 rpds-py                   0.18.1          py311h5ecf98a_0    conda-forge
 s2n                       1.4.16               he19d79f_0    conda-forge
 scikit-build-core         0.9.6              pyh4af843d_0    conda-forge
 scikit-learn              1.5.0           py311he08f58d_1    conda-forge
 scipy                     1.13.1          py311h517d4fd_0    conda-forge
 seaborn                   0.13.2               hd8ed1ab_2    conda-forge
 seaborn-base              0.13.2             pyhd8ed1ab_2    conda-forge
 send2trash                1.8.3              pyh0d859eb_0    conda-forge
 setuptools                69.5.1          py311h06a4308_0
 six                       1.16.0             pyh6c4a22f_0    conda-forge
 snappy                    1.2.0                hdb0a2a9_1    conda-forge
 sniffio                   1.3.1              pyhd8ed1ab_0    conda-forge
 snowballstemmer           2.2.0              pyhd8ed1ab_0    conda-forge
 sortedcontainers          2.4.0              pyhd8ed1ab_0    conda-forge
 soupsieve                 2.5                pyhd8ed1ab_1    conda-forge
 sparse                    0.15.4             pyhd8ed1ab_0    conda-forge
 spdlog                    1.12.0               hd2e6256_2    conda-forge
 sphinx                    5.3.0              pyhd8ed1ab_0    conda-forge
 sphinx-copybutton         0.5.2              pyhd8ed1ab_0    conda-forge
 sphinx-markdown-tables    0.0.17             pyh6c4a22f_0    conda-forge
 sphinxcontrib-applehelp   1.0.8              pyhd8ed1ab_0    conda-forge
 sphinxcontrib-devhelp     1.0.6              pyhd8ed1ab_0    conda-forge
 sphinxcontrib-htmlhelp    2.0.5              pyhd8ed1ab_0    conda-forge
 sphinxcontrib-jsmath      1.0.1              pyhd8ed1ab_0    conda-forge
 sphinxcontrib-qthelp      1.0.7              pyhd8ed1ab_0    conda-forge
 sphinxcontrib-serializinghtml 1.1.10             pyhd8ed1ab_0    conda-forge
 sqlite                    3.46.0               h6d4b2fc_0    conda-forge
 stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
 statsmodels               0.14.2          py311h18e1886_0    conda-forge
 sysroot_linux-64          2.17                h4a8ded7_14    conda-forge
 tabulate                  0.9.0              pyhd8ed1ab_1    conda-forge
 tbb                       2021.12.0            h297d8ca_1    conda-forge
 tblib                     3.0.0              pyhd8ed1ab_0    conda-forge
 terminado                 0.18.1             pyh0d859eb_0    conda-forge
 threadpoolctl             3.5.0              pyhc1e730c_0    conda-forge
 tinycss2                  1.3.0              pyhd8ed1ab_0    conda-forge
 tk                        8.6.13          noxft_h4845f30_101    conda-forge
 toml                      0.10.2             pyhd8ed1ab_0    conda-forge
 tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
 tomlkit                   0.12.5             pyha770c72_0    conda-forge
 toolz                     0.12.1             pyhd8ed1ab_0    conda-forge
 tornado                   6.4.1           py311h331c9d8_0    conda-forge
 tqdm                      4.66.4             pyhd8ed1ab_0    conda-forge
 traitlets                 5.14.3             pyhd8ed1ab_0    conda-forge
 treelite                  4.2.1           py311he8f9275_0    conda-forge
 types-python-dateutil     2.9.0.20240316     pyhd8ed1ab_0    conda-forge
 typing-extensions         4.12.2               hd8ed1ab_0    conda-forge
 typing_extensions         4.12.2             pyha770c72_0    conda-forge
 typing_utils              0.1.0              pyhd8ed1ab_0    conda-forge
 tzdata                    2024a                h04d1e81_0
 ucx                       1.15.0               hda83522_8    conda-forge
 ucx-py                    0.39.00a3       py311_240623_g42c03ef_3    rapidsai-nightly
 ucxx                      0.39.00a        cuda12_py3.11_240623_g1e6d80c_3    rapidsai-nightly
 umap-learn                0.5.3           py311h38be061_1    conda-forge
 uri-template              1.3.0              pyhd8ed1ab_0    conda-forge
 urllib3                   2.2.2              pyhd8ed1ab_0    conda-forge
 wcwidth                   0.2.13             pyhd8ed1ab_0    conda-forge
 webcolors                 24.6.0             pyhd8ed1ab_0    conda-forge
 webencodings              0.5.1              pyhd8ed1ab_2    conda-forge
 websocket-client          1.8.0              pyhd8ed1ab_0    conda-forge
 wheel                     0.43.0          py311h06a4308_0
 xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
 xorg-libice               1.1.1                hd590300_0    conda-forge
 xorg-libsm                1.2.4                h7391055_0    conda-forge
 xorg-libx11               1.8.9                hb711507_1    conda-forge
 xorg-libxau               1.0.11               hd590300_0    conda-forge
 xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
 xorg-libxext              1.3.4                h0b41bf4_2    conda-forge
 xorg-libxrender           0.9.11               hd590300_0    conda-forge
 xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
 xorg-xextproto            7.3.0             h0b41bf4_1003    conda-forge
 xorg-xproto               7.0.31            h7f98852_1007    conda-forge
 xyzservices               2024.6.0           pyhd8ed1ab_0    conda-forge
 xz                        5.4.6                h5eee18b_1
 yaml                      0.2.5                h7f98852_2    conda-forge
 zeromq                    4.3.5                h75354e8_4    conda-forge
 zict                      3.0.0              pyhd8ed1ab_0    conda-forge
 zipp                      3.19.2             pyhd8ed1ab_0    conda-forge
 zlib                      1.3.1                h4ab18f5_1    conda-forge
 zstd                      1.5.6                ha6fb4c9_0    conda-forge

Notes for Reviewers

Cython wrapper PR contains more varied unit and manual tests: #5988

C++ unit test reference

Sklearn script to generate C++ test reference data
import numpy as np
from sklearn.decomposition import KernelPCA, PCA

Input data

n_rows = 3
n_cols = 2
n_components = 2
data = np.array([
[1.0, 4.0],
[2.0, 2.0],
[5.0, 1.0]
])

Kernel Parameters

class KernelParams:
def init(self, kernel, degree=3, gamma=1.0, coef0=1):
self.kernel = kernel
self.degree = degree
self.gamma = gamma
self.coef0 = coef0

Define kernel parameters

lin_kern = KernelParams(kernel='linear', degree=0, gamma=0, coef0=0)
poly_kern = KernelParams(kernel='poly', degree=3, gamma=1.0/2.0, coef0=1)
rbf_kern = KernelParams(kernel='rbf', degree=0, gamma=1.0/2.0, coef0=0)

Function to compute and print kernel matrices

def compute_and_print_kernel_matrices(data, kernel_params, n_components):
kpca = KernelPCA(n_components=n_components,
kernel=kernel_params.kernel,
degree=kernel_params.degree,
gamma=kernel_params.gamma,
coef0=kernel_params.coef0)

# Fit and transform data
transformed_data = kpca.fit_transform(data)

# Print results
print(f"transformed_data for {kernel_params.kernel} kernel:\n{transformed_data}")
print(f"eigenvalues_ for {kernel_params.kernel} kernel:\n{kpca.eigenvalues_}")
print(f"eigenvectors_ for {kernel_params.kernel} kernel:\n{kpca.eigenvectors_}")

Compute and print kernel matrices for linear, polynomial, and RBF kernels

kernels = [lin_kern, poly_kern, rbf_kern]
for kernel_params in kernels:
compute_and_print_kernel_matrices(data, kernel_params, n_components)

Output from script and C++ expected values are below. Note that C++ matrices are in column order while Python w. numpy uses row order.
Sklearn linear kernel:

transformed_data for linear kernel:
[[-2.32318647 -0.39794495]
 [-0.35170213  0.65716145]
 [ 2.6748886  -0.25921649]]
eigenvalues_ for linear kernel:
[12.67591879  0.65741454]
eigenvectors_ for linear kernel:
[[-0.65252078 -0.49079864]
 [-0.0987837   0.81049889]
 [ 0.75130448 -0.31970025]]

C++ test reference:

std::vector<float> lin_trans_data_ref_h = {-2.32318647,-0.35170213, 2.6748886, -0.39794495, 0.65716145,-0.25921649};
std::vector<float> lin_eigenvalues_ref_h = {12.6759, 0.6574};
std::vector<float> lin_eigenvectors_ref_h = {-0.6525, -0.0987, 0.7513, -0.4907, 0.8105, -0.3197};

Sklearn poly kernel:

transformed_data for poly kernel:
[[-22.97601835  -8.84385051]
 [-10.85540799  11.24260772]
 [ 33.83142634  -2.39875721]]
eigenvalues_ for poly kernel:
[1790.30271016  210.36395651]
eigenvectors_ for poly kernel:
[[-0.54301464 -0.6097555 ]
 [-0.25655644  0.77514222]
 [ 0.79957107 -0.16538672]]

C++ test reference:

std::vector<float> poly_trans_data_ref_h = {-22.9760, -10.8554, 33.8314, -8.8438, 11.2426, -2.3987};
std::vector<float> poly_eigenvalues_ref_h = {1790.3207, 210.3639};
std::vector<float> poly_eigenvectors_ref_h = {-0.5430, -0.2565, 0.7995, -0.6097, 0.7751, -0.1653};

Sklearn rbf kernel:

transformed_data for rbf kernel:
[[-0.43907684 -0.66248897]
 [-0.38619599  0.69140675]
 [ 0.82527283 -0.02891778]]
eigenvalues_ for rbf kernel:
[1.02301105 0.91777117]
eigenvectors_ for rbf kernel:
[[-0.43411058 -0.69153067]
 [-0.38182784  0.72171613]
 [ 0.81593842 -0.03018545]]

C++ test reference:

std::vector<float> rbf_trans_data_ref_h = {-0.4391, -0.3862, 0.8253, -0.6624, 0.6914, -0.0289};
std::vector<float> rbf_eigenvalues_ref_h = {1.0230, 0.9177};
std::vector<float> rbf_eigenvectors_ref_h = {-0.4341, -0.3818, 0.8159, -0.6915, 0.7217, -0.0301};

Definition of Done Criteria Checklist

C++ Checklist

Design

  • Existing prims are used wherever possible
  • Array inputs and outputs to algorithms are accepted on device
  • New prims created wherever there is potential for reuse across different algorithms or prims
  • User-facing API is stateless and follows the plain-old data (POD) design paradigm
  • Public API contains a C-Wrapper around the stateless API
  • (optional) Public API contains an Scikit-learn-like stateful wrapper around the stateless API

Testing

  • Prims: GTests with different inputs
  • Algorithms: End-to-end GTests with different inputs and different datasets

Documentation

  • Complete and comprehensive Doxygen strings explaining the public API, restrictions, and gotchas. Any array parameters should also note whether the underlying memory is host or device.
  • Array inputs/outputs should also mention their expected size/dimension.
  • If there are references to the underlying algorithm, they must be cited too.

Unit test results

C++ Test Results
(cuml_dev) ubuntu@ip-172-31-36-86:~/cuml$ ./cpp/build/test/SG_KPCA_TEST
Running main() from /home/ubuntu/cuml/cpp/build/_deps/gtest-src/googletest/src/gtest_main.cc
[==========] Running 36 tests from 6 test suites.
[----------] Global test environment set-up.
[----------] 6 tests from KPcaTests/KPcaTestEigenvaluesF
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/0
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/0 (430 ms)
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/1
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/1 (2 ms)
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/2
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/2 (2 ms)
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/3
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/3 (2 ms)
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/4
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/4 (2 ms)
[ RUN      ] KPcaTests/KPcaTestEigenvaluesF.Result/5
[       OK ] KPcaTests/KPcaTestEigenvaluesF.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestEigenvaluesF (442 ms total)

[----------] 6 tests from KPcaTests/KPcaTestEigenvectorsF
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/0
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/0 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/1
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/1 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/2
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/2 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/3
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/3 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/4
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/4 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsF.Result/5
[ OK ] KPcaTests/KPcaTestEigenvectorsF.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestEigenvectorsF (13 ms total)

[----------] 6 tests from KPcaTests/KPcaTestTransDataF
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/0
[ OK ] KPcaTests/KPcaTestTransDataF.Result/0 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/1
[ OK ] KPcaTests/KPcaTestTransDataF.Result/1 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/2
[ OK ] KPcaTests/KPcaTestTransDataF.Result/2 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/3
[ OK ] KPcaTests/KPcaTestTransDataF.Result/3 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/4
[ OK ] KPcaTests/KPcaTestTransDataF.Result/4 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataF.Result/5
[ OK ] KPcaTests/KPcaTestTransDataF.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestTransDataF (13 ms total)

[----------] 6 tests from KPcaTests/KPcaTestEigenvaluesD
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/0
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/0 (9 ms)
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/1
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/1 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/2
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/2 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/3
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/3 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/4
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/4 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvaluesD.Result/5
[ OK ] KPcaTests/KPcaTestEigenvaluesD.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestEigenvaluesD (23 ms total)

[----------] 6 tests from KPcaTests/KPcaTestEigenvectorsD
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/0
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/0 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/1
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/1 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/2
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/2 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/3
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/3 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/4
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/4 (2 ms)
[ RUN ] KPcaTests/KPcaTestEigenvectorsD.Result/5
[ OK ] KPcaTests/KPcaTestEigenvectorsD.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestEigenvectorsD (16 ms total)

[----------] 6 tests from KPcaTests/KPcaTestTransDataD
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/0
[ OK ] KPcaTests/KPcaTestTransDataD.Result/0 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/1
[ OK ] KPcaTests/KPcaTestTransDataD.Result/1 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/2
[ OK ] KPcaTests/KPcaTestTransDataD.Result/2 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/3
[ OK ] KPcaTests/KPcaTestTransDataD.Result/3 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/4
[ OK ] KPcaTests/KPcaTestTransDataD.Result/4 (2 ms)
[ RUN ] KPcaTests/KPcaTestTransDataD.Result/5
[ OK ] KPcaTests/KPcaTestTransDataD.Result/5 (2 ms)
[----------] 6 tests from KPcaTests/KPcaTestTransDataD (16 ms total)

[----------] Global test environment tear-down
[==========] 36 tests from 6 test suites ran. (525 ms total)
[ PASSED ] 36 tests.

@tomasjoh tomasjoh requested review from a team as code owners July 27, 2024 18:07
Copy link

copy-pr-bot bot commented Jul 27, 2024

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Comment on lines +53 to +54
// All the positive eigenvalues that are too small (with a value smaller than the maximum
// eigenvalue multiplied by for double precision 1e-12 (2e-7 for float)) are set to zero.
Copy link
Author

@tomasjoh tomasjoh Jul 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When removing zero eigen values, we treat values below a certain threshold as zero. See reference from sklearn: https://github.com/scikit-learn/scikit-learn/blob/b72e81af473c079ae95314efbca86557a836defa/sklearn/utils/validation.py#L1901-L1903

@tomasjoh tomasjoh changed the title Kernel PCA Algorithm C++ Kernel PCA: C++ Algorithm Implementation Jul 27, 2024
@tomasjoh tomasjoh mentioned this pull request Jul 27, 2024
22 tasks
@dantegd
Copy link
Member

dantegd commented Jul 27, 2024

/ok to test

@dantegd dantegd added feature request New feature or request non-breaking Non-breaking change labels Jul 27, 2024
@dantegd
Copy link
Member

dantegd commented Jul 28, 2024

/ok to test

public:
MLCommon::Matrix::KernelParams kernel;
size_t n_training_samples = 0;
bool copy = true; // TODO unused
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left these TODO's in here for future compatability and since paramsPCATemplate does the same for copy.

@tomasjoh tomasjoh changed the base branch from branch-24.08 to branch-24.10 August 7, 2024 17:06
@dantegd
Copy link
Member

dantegd commented Aug 13, 2024

/ok to test

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. I've given it a look over and I think it's shaping up well. Most of the things are minor, but important to maintenance and quality. Looking forward to having this feature.

Small question- what data scales are you targeting? The unfortunate thing about the naive implementation of kernel algorithms is that they require n^2 space, whereas there are some ways we can make them scale by computing more on the fly (and even sparsifying the kernel gramm by using nearest neighbors methods). Our Kernel APIs also allow for caching and tiling so if we can use an iterative solver, we can scale to much larger datasets without having to resort to multiple GPUs.

CUML_KERNEL void subtractMeanKernel(
T* mat, const T* row_means, const T* col_means, T overall_mean, int n_rows, int n_cols)
{
const int row = blockIdx.x * blockDim.x + threadIdx.x;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have primitives to do this in raft and we prefer using them wherever possible to centralize impls.

const value_t* sqrt_vals,
size_t n_training_samples,
size_t n_components)
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please utilize primitives in RAFT to do this instead of writing raw kernels.


#include <raft/core/handle.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/linalg/detail/cublas_wrappers.hpp>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know there are still a few places in cuML that are using these wrappers but these are in a detail namespace for a reason. We should only be using public API functions from RAFT (and exposing addiitonal ones if needed).

// Step 2: Compute overall mean
value_t overall_mean;
thrust::device_ptr<value_t> d_kernel_mat_ptr(kernel_mat.data());
value_t sum = thrust::reduce(thrust_policy,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use corresponding primitives in RAFT to do this. In general we centralize computations like this in RAFT so that we can optimize and update them when needed (since we have control over RAFT APIs / implementations but not Thrust).

rmm::device_uvector<value_t> fitted_kernel_mat(prms.n_training_samples * prms.n_training_samples,
stream);
auto thrust_policy = rmm::exec_policy(stream);
thrust::fill(thrust_policy, kernel_mat.begin(), kernel_mat.end(), 0.0f);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use raft::matrix::fill for this.

// Step 3: Compute overall mean
value_t overall_mean;
thrust::device_ptr<value_t> d_kernel_mat_ptr(fitted_kernel_mat.data());
value_t sum = thrust::reduce(thrust_policy,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use raft for reductions.


#include <cuml/decomposition/params.hpp>

#include <raft/linalg/detail/cublas_wrappers.hpp>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I knkow I mentioned earlier not to use these wrappers, and it actually looks like the wrappers aren't being used in your code anywhere. Can you please go through these imports (across your PR) and make sure you are only importanting what is being used?

int algo = 1;
std::vector<float> data_h = {1.0, 2.0, 5.0, 4.0, 2.0, 1.0};

raft::distance::kernels::KernelParams lin_kern = {raft::distance::kernels::LINEAR, 0, 0, 0};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hardcoding the values here, I think we can use our existing PCA as a source of ground truth and validate our results match when we compute the kernel gramm + PCA. That would at least remove a level of hardcoding. This is really hard to maintain. Ideally we'll eventually update our existing PCA tests to compare against a naive PCA solver (like simple manual power iteration) so that we can remove the hardcoding there too.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean doing something like this?

import numpy as np
from sklearn.decomposition import PCA, KernelPCA
from sklearn.metrics.pairwise import rbf_kernel

X = np.array([
    [1.0, 4.0],
    [2.0, 2.0],
    [5.0, 1.0],
    [0.0, 3.0],
    [3.0, 5.0],
    [6.0, 2.0],
    [7.0, 8.0],
    [8.0, 1.0],
    [9.0, 6.0],
    [4.0, 7.0]
])

gamma = 0.5
K = rbf_kernel(X, gamma=gamma)

N = K.shape[0]
one_n = np.ones((N, N)) / N
K_centered = K - one_n @ K - K @ one_n + one_n @ K @ one_n

pca = PCA(n_components=2)
X_pca = pca.fit_transform(K_centered)

kpca = KernelPCA(kernel='rbf', gamma=gamma, n_components=2)
X_kpca = kpca.fit_transform(X)

# Results
print("PCA Transformed Data:", X_pca, "KPCA Transformed Data:", X_kpca, sep="\n",end="\n\n")
print("PCA explained_variance_:", pca.explained_variance_, "KPCA eigenvalues_:", kpca.eigenvalues_, sep="\n", end="\n\n")
print("PCA components_:", pca.components_, "KPCA eigenvectors_:", kpca.eigenvectors_, sep="\n", end="\n\n")

Looking at the results I see that transformed data and eigenvalues aren't the same between pca/kpca. I haven't had the chance to dig deeper, but I'm wondering if there could be differences in the two algorithms that makes this approach not feasible.

As an alternative, I could rewrite the tests, so we can remove about a lot of the duplicate code for better readability.

Maybe another way to make the test cleaner would be to read the input and expected output from a file. The file could be the output from a Python script using sklearn.KernelPCA.

@tomasjoh
Copy link
Author

tomasjoh commented Sep 2, 2024

Thanks for the PR. I've given it a look over and I think it's shaping up well. Most of the things are minor, but important to maintenance and quality. Looking forward to having this feature.

Small question- what data scales are you targeting? The unfortunate thing about the naive implementation of kernel algorithms is that they require n^2 space, whereas there are some ways we can make them scale by computing more on the fly (and even sparsifying the kernel gramm by using nearest neighbors methods). Our Kernel APIs also allow for caching and tiling so if we can use an iterative solver, we can scale to much larger datasets without having to resort to multiple GPUs.

Thank you for the review. I have started incorporating most of the requested changes, but still working one replacing one of the raw kernels and tests. I just hit a busy period, so I probably won't be able to make the updates until beginning of October.

Regarding your question on targeted data scales: I observed the algorithm being memory bound in benchmarking. It would be great to be able to support larger matrices. Could we aim for this for V2 of the algorithm? I might need some guidance on the changes needed for kernel caching + tiling.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake CUDA/C++ feature request New feature or request non-breaking Non-breaking change
Projects
Status: In Progress
Development

Successfully merging this pull request may close these issues.

3 participants