Skip to content

Commit

Permalink
Merge pull request #45 from nicolas-chaulet/bug/ballquery
Browse files Browse the repository at this point in the history
Fix various annoying bugs
  • Loading branch information
nicolas-chaulet authored Jul 8, 2020
2 parents 7994646 + e7ac33c commit 83ac174
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ repos:
rev: stable
hooks:
- id: black
language_version: python3.6
language_version: python3.7
args: ["--config", ".black.toml"]
- repo: local
hooks:
Expand Down
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# 0.6.6
## Additions
- Windows support


## Change
- Develop with python 3.7

## Bug fix
- Fixed bug in region growing related to batching
- Ball query for partial dense data on GPU was returning only the first point. Fixed now


# 0.6.5

## Additions
Expand Down
5 changes: 2 additions & 3 deletions cuda/src/ball_query_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ __global__ void query_ball_point_kernel_partial_dense(
// taken from
// https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu
const ptrdiff_t batch_idx = blockIdx.x;
const ptrdiff_t idx = threadIdx.x;

const ptrdiff_t start_idx_x = batch_x[batch_idx];
const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];
Expand All @@ -68,10 +67,10 @@ __global__ void query_ball_point_kernel_partial_dense(
const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];
float radius2 = radius * radius;

for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE)
for (ptrdiff_t n_y = start_idx_y + threadIdx.x; n_y < end_idx_y; n_y += blockDim.x)
{
int64_t count = 0;
for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++)
for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++)
{
float dist = 0;
for (ptrdiff_t d = 0; d < 3; d++)
Expand Down
16 changes: 12 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def get_ext_modules():
extra_compile_args += ["-DVERSION_GE_1_3"]

ext_src_root = "cuda"
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root))
ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob(
"{}/src/*.cu".format(ext_src_root)
)

ext_modules = []
if CUDA_HOME:
Expand All @@ -37,7 +39,10 @@ def get_ext_modules():
name="torch_points_kernels.points_cuda",
sources=ext_sources,
include_dirs=["{}/include".format(ext_src_root)],
extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,},
extra_compile_args={
"cxx": extra_compile_args,
"nvcc": extra_compile_args,
},
)
)

Expand Down Expand Up @@ -67,7 +72,7 @@ def get_cmdclass():
requirements = ["torch>=1.1.0", "numba", "scikit-learn"]

url = "https://github.com/nicolas-chaulet/torch-points-kernels"
__version__ = "0.6.5"
__version__ = "0.6.6"
setup(
name="torch-points-kernels",
version=__version__,
Expand All @@ -81,5 +86,8 @@ def get_cmdclass():
cmdclass=get_cmdclass(),
long_description=long_description,
long_description_content_type="text/markdown",
classifiers=["Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",],
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
],
)
48 changes: 33 additions & 15 deletions test/test_ballquerry.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,18 @@ def test_cpu_gpu_equality(self):
class TestBallPartial(unittest.TestCase):
@run_if_cuda
def test_simple_gpu(self):
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda()
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [0.2, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda()
y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda()
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
batch_x = torch.from_numpy(np.asarray([0, 0, 0, 1])).long().cuda()
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda()
batch_y = torch.from_numpy(np.asarray([0])).long().cuda()

idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
idx, dist2 = ball_query(0.2, 4, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)

idx = idx.detach().cpu().numpy()
dist2 = dist2.detach().cpu().numpy()

idx_answer = np.asarray([[1, -1]])
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
idx_answer = np.asarray([[1, 2, -1, -1]])
dist2_answer = np.asarray([[0.0100, 0.04, -1, -1]]).astype(np.float32)

npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)
Expand All @@ -98,30 +95,29 @@ def test_simple_cpu(self):
npt.assert_array_almost_equal(idx, idx_answer)
npt.assert_array_almost_equal(dist2, dist2_answer)


def test_breaks(self):
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
y = torch.tensor([[0, 0, 0]]).to(torch.float)

batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
batch_y = torch.from_numpy(np.asarray([0])).long()

with self.assertRaises(RuntimeError):
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)

def test_random_cpu(self):
def test_random_cpu(self, cuda=False):
a = torch.randn(100, 3).to(torch.float)
b = torch.randn(50, 3).to(torch.float)
batch_a = torch.tensor([0 for i in range(a.shape[0] // 2)] + [1 for i in range(a.shape[0] // 2, a.shape[0])])
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
R = 1

idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,)
torch.testing.assert_allclose(idx1, idx)
with self.assertRaises(AssertionError):
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)
torch.testing.assert_allclose(idx1, idx)

self.assertEqual(idx.shape[0], b.shape[0])
Expand All @@ -136,6 +132,28 @@ def test_random_cpu(self):
if p >= 0 and p < len(batch_a):
assert p in idx3_sk[i]

@run_if_cuda
def test_random_gpu(self):
a = torch.randn(100, 3).to(torch.float).cuda()
b = torch.randn(50, 3).to(torch.float).cuda()
batch_a = torch.tensor(
[0 for i in range(a.shape[0] // 2)] + [1 for i in range(a.shape[0] // 2, a.shape[0])]
).cuda()
batch_b = torch.tensor(
[0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])]
).cuda()
R = 1

idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,)

# Comparison to see if we have the same result
tree = KDTree(a.cpu().detach().numpy())
idx3_sk = tree.query_radius(b.cpu().detach().numpy(), r=R)
i = np.random.randint(len(batch_b))
for p in idx[i].cpu().detach().numpy():
if p >= 0 and p < len(batch_a):
assert p in idx3_sk[i]


if __name__ == "__main__":
unittest.main()
12 changes: 11 additions & 1 deletion torch_points_kernels/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,19 @@ def region_grow(
# Build clusters for a given label (ignore other points)
label_mask = labels == l
local_ind = ind[label_mask]

# Remap batch to a continuous sequence
label_batch = batch[label_mask]
unique_in_batch = torch.unique(label_batch)
remaped_batch = torch.empty_like(label_batch)
for new, old in enumerate(unique_in_batch):
mask = label_batch == old
remaped_batch[mask] = new

# Cluster
label_clusters = grow_proximity(
pos[label_mask, :],
batch[label_mask],
remaped_batch,
nsample=nsample,
radius=radius,
min_cluster_size=min_cluster_size,
Expand Down

0 comments on commit 83ac174

Please sign in to comment.