diff --git a/adet/layers/csrc/ml_nms/ml_nms.cu b/adet/layers/csrc/ml_nms/ml_nms.cu
index f1c1a4206..c565da7e5 100644
--- a/adet/layers/csrc/ml_nms/ml_nms.cu
+++ b/adet/layers/csrc/ml_nms/ml_nms.cu
@@ -2,8 +2,9 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
-#include <THC/THC.h>
-#include <THC/THCDeviceUtils.cuh>
+#include <ATen/ceil_div.h>
+#include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDACachingAllocator.h>
 
 #include <vector>
 #include <iostream>
 
@@ -65,7 +66,7 @@ __global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
         t |= 1ULL << i;
       }
     }
-    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
     dev_mask[cur_box_idx * col_blocks + col_start] = t;
   }
 }
@@ -82,20 +83,20 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
 
   int boxes_num = boxes.size(0);
 
-  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+  const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);
 
   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
 
-  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
+  // THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
 
   unsigned long long* mask_dev = NULL;
   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
 
-  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+  mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));
 
-  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+              at::ceil_div(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
   ml_nms_kernel<<<blocks, threads>>>(boxes_num,
                                   nms_overlap_thresh,
@@ -103,7 +104,7 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
                                   mask_dev);
 
   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
+  AT_CUDA_CHECK(cudaMemcpy(&mask_host[0],
                         mask_dev,
                         sizeof(unsigned long long) * boxes_num * col_blocks,
                         cudaMemcpyDeviceToHost));
@@ -128,7 +129,7 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
     }
   }
 
-  THCudaFree(state, mask_dev);
+  c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
   // TODO improve this part
   return std::get<0>(order_t.index({
                        keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
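
Below is a standalone sketch (not part of the patch) of the THC -> ATen/c10 replacements the diff applies, assuming PyTorch >= 1.11, where the THC headers were removed. The function and parameter names (thc_free_roundtrip, n_items) are hypothetical, for illustration only; the API calls (at::ceil_div, AT_CUDA_CHECK, c10::cuda::CUDACachingAllocator::raw_alloc/raw_delete) are the ones the patched file relies on.

// thc_migration_sketch.cu -- hypothetical example, assumes PyTorch >= 1.11.
#include <ATen/ceil_div.h>                  // at::ceil_div (replaces THCCeilDiv)
#include <ATen/cuda/Exceptions.h>           // AT_CUDA_CHECK (replaces THCudaCheck)
#include <c10/cuda/CUDACachingAllocator.h>  // raw_alloc / raw_delete (replace THCudaMalloc / THCudaFree)
#include <cuda_runtime.h>
#include <vector>

// n_items is a hypothetical parameter for illustration.
void thc_free_roundtrip(int n_items) {
  const int threadsPerBlock = sizeof(unsigned long long) * 8;

  // THCCeilDiv(a, b) -> at::ceil_div(a, b): integer ceiling division.
  const int col_blocks = at::ceil_div(n_items, threadsPerBlock);

  // THCudaMalloc(state, bytes) -> raw_alloc(bytes): no THCState argument;
  // the memory comes from PyTorch's CUDA caching allocator.
  const size_t bytes = static_cast<size_t>(col_blocks) * sizeof(unsigned long long);
  auto* mask_dev = static_cast<unsigned long long*>(
      c10::cuda::CUDACachingAllocator::raw_alloc(bytes));

  // THCudaCheck(expr) -> AT_CUDA_CHECK(expr): throws a c10::Error on CUDA failure.
  // (This is also the fix for the raw_alloc(cudaMemcpy(...)) line in circulating
  // versions of this patch: raw_alloc allocates memory, it does not check errors.)
  AT_CUDA_CHECK(cudaMemset(mask_dev, 0, bytes));

  std::vector<unsigned long long> mask_host(col_blocks);
  AT_CUDA_CHECK(cudaMemcpy(mask_host.data(), mask_dev, bytes,
                           cudaMemcpyDeviceToHost));

  // THCudaFree(state, ptr) -> raw_delete(ptr): returns the block to the allocator.
  c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
}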