Merge pull request #10 from nicolas-chaulet/singleapi
Singleapi
nicolas-chaulet authored Jan 13, 2020
2 parents 2884bcc + 3c9a8e4 commit 9c1355e
Showing 6 changed files with 135 additions and 149 deletions.
7 changes: 4 additions & 3 deletions cuda/include/cuda_utils.h
@@ -10,18 +10,19 @@

#include <vector>

#define TOTAL_THREADS 1024
#define TOTAL_THREADS_DENSE 512
#define TOTAL_THREADS_SPARSE 1024

inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

return max(min(1 << pow_2, TOTAL_THREADS), 1);
return max(min(1 << pow_2, TOTAL_THREADS_DENSE), 1);
}

inline dim3 opt_block_config(int x, int y) {
const int x_threads = opt_n_threads(x);
const int y_threads =
max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
max(min(opt_n_threads(y), TOTAL_THREADS_DENSE / x_threads), 1);
dim3 block_config(x_threads, y_threads, 1);

return block_config;
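For context, opt_n_threads rounds the work size down to a power of two and clamps it to the new dense budget, while the sparse kernels below use the larger TOTAL_THREADS_SPARSE budget directly. A minimal Python sketch of the same heuristic (illustration only, not part of the repository):

import math

TOTAL_THREADS_DENSE = 512    # per-block budget for the dense kernels
TOTAL_THREADS_SPARSE = 1024  # per-block budget for the partial/sparse kernels

def opt_n_threads(work_size: int) -> int:
    # Largest power of two <= work_size, clamped to [1, TOTAL_THREADS_DENSE]
    pow_2 = int(math.log(work_size) / math.log(2.0)) if work_size > 0 else 0
    return max(min(1 << pow_2, TOTAL_THREADS_DENSE), 1)

print(opt_n_threads(100))   # 64
print(opt_n_threads(4096))  # 512 (capped by TOTAL_THREADS_DENSE)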
4 changes: 2 additions & 2 deletions cuda/src/ball_query_gpu.cu
@@ -66,7 +66,7 @@ __global__ void query_ball_point_kernel_partial_dense(int size_x,
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
float radius2 = radius * radius;

for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS) {
for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE) {
int64_t count = 0;
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) {
float dist = 0;
@@ -108,7 +108,7 @@ void query_ball_point_kernel_partial_wrapper(long batch_size,
int64_t *idx_out,
float *dist_out) {

query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS>>>(
query_ball_point_kernel_partial_dense<<<batch_size, TOTAL_THREADS_SPARSE>>>(
size_x, size_y, radius, nsample, x, y,
batch_x, batch_y, idx_out, dist_out);

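Note on the hunks above: the kernel is launched with one block per batch entry and a fixed number of threads per block, and each thread then strides over its batch's points in steps of that same block size, so the launch constant and the loop stride are renamed to TOTAL_THREADS_SPARSE together to stay in sync. A rough Python sketch of that striding pattern (illustration only):

T = 1024  # TOTAL_THREADS_SPARSE: threads per block at launch, and the loop stride
start_idx_x, end_idx_x = 0, 2500  # example extent of one batch entry

for idx in range(T):  # one iteration per thread in the block
    # thread `idx` processes points idx, idx + T, idx + 2T, ... within its batch
    points_handled = list(range(start_idx_x + idx, end_idx_x, T))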
13 changes: 10 additions & 3 deletions setup.py
@@ -1,5 +1,10 @@
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, CppExtension
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
CUDA_HOME,
CppExtension,
)
import glob

ext_src_root = "cuda"
@@ -33,12 +38,14 @@
)
)

requirements = ["torch^1.1.0"]

setup(
name="torch_points",
version="0.1.3",
version="0.1.5",
author="Nicolas Chaulet",
packages=find_packages(),
install_requires=[],
install_requires=requirements,
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
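A side note on the new requirements line: the caret in "torch^1.1.0" is Poetry/npm-style shorthand, not a PEP 440 specifier, so plain pip/setuptools may reject or misread it. Assuming the intent is "any 1.x release from 1.1.0 onwards", a PEP 440 equivalent would look like this (sketch only):

requirements = ["torch>=1.1.0,<2.0"]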
82 changes: 74 additions & 8 deletions test/test_ballquerry.py
@@ -1,6 +1,6 @@
import unittest
import torch
from torch_points import ball_query_dense
from torch_points import ball_query
import numpy.testing as npt
import numpy as np

@@ -10,22 +10,88 @@ def test_simple_gpu(self):
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
b = torch.tensor([[[0, 0, 0]]]).to(torch.float).cuda()

npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]]))
npt.assert_array_equal(
ball_query(1, 2, a, b).detach().cpu().numpy(), np.array([[[0, 0]]])
)

def test_simple_cpu(self):
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
b = torch.tensor([[[0, 0, 0]]]).to(torch.float)
npt.assert_array_equal(ball_query_dense(1, 2, a, b).detach().numpy(), np.array([[[0, 0]]]))
def test_larger_gpu(self):
a = torch.randn(32, 4096, 3).to(torch.float).cuda()
idx = ball_query(1, 64, a, a).detach().cpu().numpy()
self.assertGreaterEqual(idx.min(), 0)

def test_cpu_gpu_equality(self):
a = torch.randn(5, 1000, 3)
res_cpu = ball_query_dense(0.1, 17, a, a).detach().numpy()
res_cuda = ball_query_dense(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
res_cpu = ball_query(0.1, 17, a, a).detach().numpy()
res_cuda = ball_query(0.1, 17, a.cuda(), a.cuda()).cpu().detach().numpy()
for i in range(a.shape[0]):
for j in range(a.shape[1]):
# Because the order is not necessarily the same
assert set(res_cpu[i][j]) == set(res_cuda[i][j])


class TestBallPartial(unittest.TestCase):
def test_simple_gpu(self):
x = (
torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]])
.to(torch.float)
.cuda()
)
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()

idx, dist2 = ball_query(
1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y
)

idx = idx.detach().cpu().numpy()
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, 4]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)

def test_simple_cpu(self):
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(
torch.float
)
y = torch.tensor([[0, 0, 0]]).to(torch.float)

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
batch_y = torch.from_numpy(np.asarray([0])).long()

idx, dist2 = ball_query(
1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y
)

idx = idx.detach().cpu().numpy()
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, 1], [0, 1], [1, 1], [1, 1]])
dist2_answer = np.asarray([[-1, -1], [0.01, -1], [-1, -1], [-1, -1]]).astype(
np.float32
)

npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)

def test_random_cpu(self):
a = torch.randn(1000, 3).to(torch.float)
b = torch.randn(1500, 3).to(torch.float)
batch_a = torch.randint(1, (1000,)).sort(0)[0].long()
batch_b = torch.randint(1, (1500,)).sort(0)[0].long()
idx, dist = ball_query(
1.0, 12, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b
)
idx2, dist2 = ball_query(
1.0, 12, b, a, mode="PARTIAL_DENSE", batch_x=batch_b, batch_y=batch_a
)


if __name__ == "__main__":
unittest.main()
69 changes: 0 additions & 69 deletions test/test_ballquerry_partial.py

This file was deleted.

109 changes: 45 additions & 64 deletions torch_points/torchpoints.py
@@ -2,6 +2,7 @@
from torch.autograd import Function
import torch.nn as nn
import sys
from typing import Optional

import torch_points.points_cpu as tpcpu

@@ -251,98 +252,78 @@ def forward(ctx, radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=No
if new_xyz.is_cuda:
return tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
else:
ind, dist = tpcpu.dense_ball_query(new_xyz,
xyz,
radius, nsample, mode=0)
ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0)
return ind

@staticmethod
def backward(ctx, a=None):
return None, None, None, None


def ball_query_dense(radius, nsample, xyz, new_xyz):
r"""
Parameters
----------
radius : float
radius of the balls
nsample : int
maximum number of features in the balls
xyz : torch.Tensor
(B, N, 3) xyz coordinates of the features
new_xyz : torch.Tensor
(B, npoint, 3) centers of the ball query
Returns
-------
torch.Tensor
(B, npoint, nsample) tensor with the indicies of the features that form the query balls
"""
return BallQueryDense.apply(radius, nsample, xyz, new_xyz)


class BallQueryPartialDense(Function):
@staticmethod
def forward(ctx, radius, nsample, x, y, batch_x, batch_y):
# type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
if x.is_cuda:
return tpcuda.ball_query_partial_dense(x, y,
batch_x,
batch_y,
radius, nsample)
return tpcuda.ball_query_partial_dense(
x, y, batch_x, batch_y, radius, nsample
)
else:
ind, dist = tpcpu.batch_ball_query(x, y,
batch_x,
batch_y,
radius, nsample, mode=0)
ind, dist = tpcpu.batch_ball_query(
x, y, batch_x, batch_y, radius, nsample, mode=0
)
return ind, dist

@staticmethod
def backward(ctx, a=None):
return None, None, None, None


def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y):
r"""
Parameters
----------
radius : float
radius of the balls
nsample : int
maximum number of features in the balls
x : torch.Tensor
(M, 3) xyz coordinates of the features (The neighbours are going to be looked for there)
y : torch.Tensor
(N, npoint, 3) centers of the ball query
batch_x : torch.Tensor
(M, ) Contains indexes to indicate within batch it belongs to.
batch_y : torch.Tensor
(N, ) Contains indexes to indicate within batch it belongs to
Returns
-------
torch.Tensor
idx: (N, nsample) Default value: N. It contains the indexes of the element within y at radius distance to x
dist2: (N, nsample) Default value: -1. It contains the square distances of the element within y at radius distance to x
def ball_query(
radius: float,
nsample: int,
x: torch.Tensor,
y: torch.Tensor,
mode: Optional[str] = "dense",
batch_x: Optional[torch.tensor] = None,
batch_y: Optional[torch.tensor] = None,
) -> torch.Tensor:
"""
Arguments:
radius {float} -- radius of the balls
nsample {int} -- maximum number of features in the balls
x {torch.Tensor} --
(M, 3) [partial_dense] or (B, M, 3) [dense] xyz coordinates of the features
y {torch.Tensor} --
(npoint, 3) [partial_dense] or (B, npoint, 3) [dense] centers of the ball query
mode {str} -- switch between the "dense" and "partial_dense" data layouts
Keyword Arguments:
batch_x -- (M, ) [partial_dense] Indexes indicating which batch each point of x belongs to.
batch_y -- (N, ) Indexes indicating which batch each point of y belongs to.
Returns:
idx: (npoint, nsample) [partial_dense] or (B, npoint, nsample) [dense] -- indexes of the elements of x found within radius of each point of y
OPTIONAL[partial_dense] dist2: (N, nsample) -- squared distances to those elements; slots with no neighbour within the radius are set to -1
"""
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)


def ball_query(radius: float, nsample: int, x, y, batch_x=None, batch_y=None, mode=None):
if mode is None:
raise Exception('The mode should be defined within ["PARTIAL_DENSE | DENSE"]')
raise Exception('The mode should be defined within ["partial_dense | dense"]')

if mode.lower() == "partial_dense":
if (batch_x is None) or (batch_y is None):
raise Exception('batch_x and batch_y should be provided')
raise Exception("batch_x and batch_y should be provided")
assert x.size(0) == batch_x.size(0)
assert y.size(0) == batch_y.size(0)
return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y)
assert x.dim() == 2
return BallQueryPartialDense.apply(radius, nsample, x, y, batch_x, batch_y)

elif mode.lower() == "dense":
if (batch_x is not None) or (batch_y is not None):
raise Exception('batch_x and batch_y should not be provided')
return ball_query_dense(radius, nsample, x, y)
raise Exception("batch_x and batch_y should not be provided")
assert x.dim() == 3
return BallQueryDense.apply(radius, nsample, x, y)
else:
raise Exception('unrecognized mode {}'.format(mode))
raise Exception("unrecognized mode {}".format(mode))
