// THCSortUtils.cuh (forked from pytorch/pytorch)
#ifndef THC_SORT_UTILS_INC
#define THC_SORT_UTILS_INC
#include <THC/THCReduceApplyUtils.cuh>
#include <THC/THCTensorTypeUtils.cuh>
#include <THC/THCNumerics.cuh>
#include <c10/macros/Macros.h>
// Collection of kernel sort routines
// "Less than" comparison functor for device code.
//
// With handleNaN == true, the functor reports true whenever `b` is NaN and
// `a` is not (as judged by THCNumerics<T>::isnan). In every other case the
// answer is THCNumerics<T>::lt(a, b).
template <typename T, bool handleNaN = false>
struct LTComp {
  __device__ inline bool operator()(const T& a, const T& b) const {
    // NaN-aware shortcut: a non-NaN left operand precedes a NaN right operand.
    if (handleNaN && THCNumerics<T>::isnan(b) && !THCNumerics<T>::isnan(a)) {
      return true;
    }
    return THCNumerics<T>::lt(a, b);
  }
};
// "Greater than" comparison functor for device code.
//
// With handleNaN == true, the functor reports true whenever `a` is NaN and
// `b` is not (as judged by THCNumerics<T>::isnan). In every other case the
// answer is THCNumerics<T>::gt(a, b).
template <typename T, bool handleNaN = false>
struct GTComp {
  __device__ inline bool operator()(const T& a, const T& b) const {
    // NaN-aware shortcut: a NaN left operand is greater than a non-NaN right
    // operand.
    if (handleNaN && THCNumerics<T>::isnan(a) && !THCNumerics<T>::isnan(b)) {
      return true;
    }
    return THCNumerics<T>::gt(a, b);
  }
};
// Exchanges the contents of `t1` and `t2` through one temporary copy.
// T must be copy-constructible and copy-assignable.
template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T held = t2;
  t2 = t1;
  t1 = held;
}
// Compare-exchange step for one (A, B) pair of key/value/valid triples.
//
// A predicate is formed that is true when B is invalid, or when A is valid
// and comp(kA, kB) holds. If that predicate equals `dir`, the keys, values
// and valid flags of A and B are exchanged; otherwise nothing is modified.
// Treating "B invalid" as a reason to exchange is what makes invalid entries
// sort to the end.
//
//   kA, vA, validA : first element (modified in place on exchange)
//   kB, vB, validB : second element (modified in place on exchange)
//   dir            : exchange happens exactly when the predicate == dir
//   comp           : binary comparator invoked as comp(kA, kB)
template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(K& kA, V& vA, bool& validA,
                                   K& kB, V& vB, bool& validB,
                                   bool dir,
                                   const Comparator& comp) {
  // The comparator is always evaluated, before the valid flags are consulted.
  const bool compHolds = comp(kA, kB);
  const bool predicate = (compHolds && validA) || !validB;

  if (predicate != dir) {
    return;
  }

  swapVars(kA, kB);
  swapVars(vA, vB);
  swapVars(validA, validB);
}
// Block-cooperative bitonic sort over Power2SortSize key/value/valid triples.
//
// Every thread performs one bitonicSwap per step, on the pair
// (lo, lo + gap) where lo is derived from threadIdx.x. A __syncthreads()
// precedes every step and one more follows the last step, so all threads of
// the block must call this function together.
//
// NOTE(review): the indexing suggests the block is expected to run
// Power2SortSize / 2 threads along x and that keys/values/valid live in
// memory shared by the block -- confirm against the callers.
// NOTE(review): Power2SortSize is presumably a power of two (per its name);
// nothing in this function checks that.
//
//   keys, values, valid : arrays sorted in place
//   comp                : comparator forwarded to bitonicSwap
//
// The IndexType template parameter is not used inside this function.
template <typename Comparator, typename K, typename V,
          typename IndexType, int Power2SortSize>
__device__ inline void bitonicSort(K keys[Power2SortSize],
                                   V values[Power2SortSize],
                                   bool valid[Power2SortSize],
                                   const Comparator& comp) {
  // Build phase: run lengths 2, 4, ... below Power2SortSize. The direction
  // passed to bitonicSwap depends on the thread's bit at runLength / 2.
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
  for (unsigned int runLength = 2; runLength < Power2SortSize; runLength *= 2) {
    bool direction = ((threadIdx.x & (runLength / 2)) != 0);

#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
    for (unsigned int gap = runLength / 2; gap > 0; gap /= 2) {
      // Barrier before each step so the previous step's exchanges are done.
      __syncthreads();

      unsigned int lo = 2 * threadIdx.x - (threadIdx.x & (gap - 1));
      unsigned int hi = lo + gap;
      bitonicSwap<Comparator, K, V>(
          keys[lo], values[lo], valid[lo],
          keys[hi], values[hi], valid[hi],
          direction, comp);
    }
  }

  // Final merge phase over the full range, always with dir == false.
#ifndef __HIP_PLATFORM_HCC__
#pragma unroll
#endif
  for (unsigned int gap = Power2SortSize / 2; gap > 0; gap /= 2) {
    __syncthreads();

    unsigned int lo = 2 * threadIdx.x - (threadIdx.x & (gap - 1));
    unsigned int hi = lo + gap;
    bitonicSwap<Comparator, K, V>(
        keys[lo], values[lo], valid[lo],
        keys[hi], values[hi], valid[hi],
        false, comp);
  }

  // Trailing barrier: every thread has finished its last exchange on return.
  __syncthreads();
}
// Declaration only; the definition is not part of this header. The function
// is declared without a __device__ qualifier.
// NOTE(review): by its name this presumably returns the smallest power of two
// that is >= n -- confirm against the definition (including n == 0 and exact
// powers of two).
uint64_t nextHighestPowerOf2(uint64_t n);
#endif // THC_SORT_UTILS_INC