Make neighbor list compatible with float16 and bfloat16 #273

Open · wants to merge 4 commits into base: main
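In short, the CUDA kernels are now dispatched through a DISPATCH_FOR_ALL_FLOAT_TYPES macro that forwards to AT_DISPATCH_FLOATING_TYPES_AND2, so each kernel is also instantiated for at::Half and at::BFloat16; the float/double-only sqrt_ helpers are dropped in favor of ::sqrt, and the CPU path computes distances with torch::linalg::norm. A minimal sketch of the dispatch pattern adopted below (illustrative only; example_kernel, cutoff, blocks, threads and stream are placeholder names, not identifiers from this PR):

    DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "example_kernel", [&] {
        // scalar_t is float, double, at::Half or at::BFloat16 in each instantiation
        const scalar_t cutoff_ = cutoff.to<scalar_t>();
        example_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            get_accessor<scalar_t, 2>(positions), cutoff_);
    });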
33 changes: 23 additions & 10 deletions torchmdnet/extensions/neighbors/common.cuh
@@ -10,6 +10,8 @@
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

using at::BFloat16;
using at::Half;
using c10::cuda::CUDAStreamGuard;
using c10::cuda::getCurrentCUDAStream;
using torch::empty;
@@ -23,6 +25,9 @@ using torch::autograd::AutogradContext;
using torch::autograd::Function;
using torch::autograd::tensor_list;

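// Dispatch helper: forwards to AT_DISPATCH_FLOATING_TYPES_AND2 so the wrapped lambda is
// instantiated for float, double, at::Half and at::BFloat16.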
#define DISPATCH_FOR_ALL_FLOAT_TYPES(...)\
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, __VA_ARGS__)\

template <typename scalar_t, int num_dims>
using Accessor = torch::PackedTensorAccessor32<scalar_t, num_dims, torch::RestrictPtrTraits>;

@@ -34,14 +39,6 @@ inline Accessor<scalar_t, num_dims> get_accessor(const Tensor& tensor) {
return tensor.packed_accessor32<scalar_t, num_dims, torch::RestrictPtrTraits>();
};

template <typename scalar_t> __device__ __forceinline__ scalar_t sqrt_(scalar_t x){};
template <> __device__ __forceinline__ float sqrt_(float x) {
return ::sqrtf(x);
};
template <> __device__ __forceinline__ double sqrt_(double x) {
return ::sqrt(x);
};

template <typename scalar_t> struct vec3 {
using type = void;
};
@@ -54,6 +51,22 @@ template <> struct vec3<double> {
using type = double3;
};

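// CUDA provides float3/double3 but no 3-component vector types for Half/BFloat16,
// so small structs are defined here to back the vec3 specializations below.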
struct Half3 {
Half x, y, z;
};

template <> struct vec3<Half> {
using type = Half3;
};

struct BFloat163 {
BFloat16 x, y, z;
};

template <> struct vec3<BFloat16> {
using type = BFloat163;
};

template <typename scalar_t> using scalar3 = typename vec3<scalar_t>::type;

/*
@@ -194,12 +207,12 @@ __device__ auto apply_pbc(scalar3<scalar_t> delta, const KernelAccessor<scalar_t
return delta;
}

template <typename scalar_t>
__device__ auto compute_distance(scalar3<scalar_t> pos_i, scalar3<scalar_t> pos_j,
bool use_periodic, const KernelAccessor<scalar_t, 2>& box) {
scalar3<scalar_t> delta = {pos_i.x - pos_j.x, pos_i.y - pos_j.y, pos_i.z - pos_j.z};
if (use_periodic) {
        delta = apply_pbc(delta, box);
}
return delta;
}
2 changes: 1 addition & 1 deletion torchmdnet/extensions/neighbors/neighbors_cpu.cpp
@@ -95,7 +95,7 @@ forward(const Tensor& positions, const Tensor& batch, const Tensor& in_box_vecto
deltas.index_put_({Slice(), 0}, deltas.index({Slice(), 0}) -
scale1 * box_vectors.index({pair_batch, 0, 0}));
}
distances = frobenius_norm(deltas, 1);
distances = torch::linalg::norm(deltas, c10::nullopt, 1, false, c10::nullopt);
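    // 2-norm along dim 1, i.e. the Euclidean length of each pairwise delta.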
mask = (distances < cutoff_upper) * (distances >= cutoff_lower);
neighbors = neighbors.index({Slice(), mask});
deltas = deltas.index({mask, Slice()});
4 changes: 2 additions & 2 deletions torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
@@ -40,7 +40,7 @@ __global__ void forward_kernel_brute(uint32_t num_all_pairs, const Accessor<scal
const auto delta = triclinic::compute_distance(pos_i, pos_j, list.use_periodic, box_row);
const scalar_t distance2 = delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
const scalar_t r2 = sqrt_(distance2);
const scalar_t r2 = ::sqrt(distance2);
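            // ::sqrt accepts at::Half and at::BFloat16 through their implicit conversion
            // to float, so the custom sqrt_ overloads could be removed from common.cuh.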
addAtomPairToList(list, row, column, delta, r2, list.include_transpose);
}
}
@@ -100,7 +100,7 @@ forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box
const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL;
const uint64_t num_threads = 128;
const uint64_t num_blocks = std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL);
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
PairListAccessor<scalar_t> list_accessor(list);
auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
8 changes: 4 additions & 4 deletions torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
@@ -116,7 +116,7 @@ static auto sortAtomsByCellIndex(const Tensor& positions, const Tensor& box_size
const int threads = 128;
const int blocks = (num_atoms + threads - 1) / threads;
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "assignHash", [&] {
DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "assignHash", [&] {
scalar_t cutoff_ = cutoff.to<scalar_t>();
scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
box_size[1][1].item<scalar_t>(),
@@ -229,7 +229,7 @@ CellList constructCellList(const Tensor& positions, const Tensor& batch, const T
cl.sorted_batch = batch.index_select(0, cl.sorted_indices);
// Step 3
int3 cell_dim;
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "computeCellDim", [&] {
DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "computeCellDim", [&] {
scalar_t cutoff_ = cutoff.to<scalar_t>();
scalar3<scalar_t> box_size_ = {box_size[0][0].item<scalar_t>(),
box_size[1][1].item<scalar_t>(),
@@ -270,7 +270,7 @@ __device__ void addNeighborPair(PairListAccessor<scalar_t>& list, const int i, c
const int ni = max(i, j);
const int nj = min(i, j);
const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0);
const scalar_t distance = sqrt_(distance2);
const scalar_t distance = ::sqrt(distance2);
delta = {delta_sign * delta.x, delta_sign * delta.y, delta_sign * delta.z};
addAtomPairToList(list, ni, nj, delta, distance, requires_transpose);
}
@@ -368,7 +368,7 @@ forward_cell(const Tensor& positions, const Tensor& batch, const Tensor& in_box_
const auto stream = getCurrentCUDAStream(positions.get_device());
{ // Traverse the cell list to find the neighbors
const CUDAStreamGuard guard(stream);
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "forward", [&] {
DISPATCH_FOR_ALL_FLOAT_TYPES(positions.scalar_type(), "forward", [&] {
const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
TORCH_CHECK(cutoff_upper_ > 0, "Expected cutoff_upper to be positive");
const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
33 changes: 17 additions & 16 deletions torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
@@ -62,7 +62,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor<scalar_
delta.x * delta.x + delta.y * delta.y + delta.z * delta.z;
if (distance2 < cutoff_upper2 && distance2 >= cutoff_lower2) {
const bool requires_transpose = list.include_transpose && !(cur_j == id);
const auto distance = sqrt_(distance2);
const scalar_t distance = ::sqrt(distance2);
addAtomPairToList(list, id, cur_j, delta, distance, requires_transpose);
}
}
@@ -104,21 +104,22 @@ forward_shared(const Tensor& positions, const Tensor& batch, const Tensor& in_bo
const auto stream = getCurrentCUDAStream(positions.get_device());
PairList list(num_pairs, positions.options(), loop, include_transpose, use_periodic);
const CUDAStreamGuard guard(stream);
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() {
const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
constexpr int BLOCKSIZE = 64;
const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1);
const int num_threads = BLOCKSIZE;
const int num_tiles = num_blocks;
PairListAccessor<scalar_t> list_accessor(list);
forward_kernel_shared<BLOCKSIZE><<<num_blocks, num_threads, 0, stream>>>(
num_atoms, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor, num_tiles,
box);
});
DISPATCH_FOR_ALL_FLOAT_TYPES(
positions.scalar_type(), "get_neighbor_pairs_shared_forward", [&]() {
const scalar_t cutoff_upper_ = cutoff_upper.to<scalar_t>();
const scalar_t cutoff_lower_ = cutoff_lower.to<scalar_t>();
auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
TORCH_CHECK(cutoff_upper_ > 0, "Expected \"cutoff\" to be positive");
constexpr int BLOCKSIZE = 64;
const int num_blocks = std::max((num_atoms + BLOCKSIZE - 1) / BLOCKSIZE, 1);
const int num_threads = BLOCKSIZE;
const int num_tiles = num_blocks;
PairListAccessor<scalar_t> list_accessor(list);
forward_kernel_shared<BLOCKSIZE><<<num_blocks, num_threads, 0, stream>>>(
num_atoms, get_accessor<scalar_t, 2>(positions), get_accessor<int64_t, 1>(batch),
cutoff_lower_ * cutoff_lower_, cutoff_upper_ * cutoff_upper_, list_accessor,
num_tiles, box);
});
return {list.neighbors, list.deltas, list.distances, list.i_curr_pair};
}
