diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7adfaff..9e24bb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,7 +12,7 @@ repos: rev: stable hooks: - id: black - language_version: python3.6 + language_version: python3.7 args: ["--config", ".black.toml"] - repo: local hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 16a12f8..72104dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,16 @@ +# 0.6.6 +## Additions +- Windows support + + +## Change +- Develop with python 3.7 + +## Bug fix +- Fixed bug in region growing related to batching +- Ball query for partial dense data on GPU was returning only the first point. Fixed now + + # 0.6.5 ## Additions diff --git a/cuda/src/ball_query_gpu.cu b/cuda/src/ball_query_gpu.cu index fd6cbe7..5966baa 100644 --- a/cuda/src/ball_query_gpu.cu +++ b/cuda/src/ball_query_gpu.cu @@ -59,7 +59,6 @@ __global__ void query_ball_point_kernel_partial_dense( // taken from // https://github.com/rusty1s/pytorch_cluster/blob/master/cuda/radius_kernel.cu const ptrdiff_t batch_idx = blockIdx.x; - const ptrdiff_t idx = threadIdx.x; const ptrdiff_t start_idx_x = batch_x[batch_idx]; const ptrdiff_t end_idx_x = batch_x[batch_idx + 1]; @@ -68,10 +67,10 @@ __global__ void query_ball_point_kernel_partial_dense( const ptrdiff_t end_idx_y = batch_y[batch_idx + 1]; float radius2 = radius * radius; - for (ptrdiff_t n_x = start_idx_x + idx; n_x < end_idx_x; n_x += TOTAL_THREADS_SPARSE) + for (ptrdiff_t n_y = start_idx_y + threadIdx.x; n_y < end_idx_y; n_y += blockDim.x) { int64_t count = 0; - for (ptrdiff_t n_y = start_idx_y; n_y < end_idx_y; n_y++) + for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) { float dist = 0; for (ptrdiff_t d = 0; d < 3; d++) diff --git a/setup.py b/setup.py index a3488d0..f014ade 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,9 @@ def get_ext_modules(): extra_compile_args += ["-DVERSION_GE_1_3"] ext_src_root = "cuda" - ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) + ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob( + "{}/src/*.cu".format(ext_src_root) + ) ext_modules = [] if CUDA_HOME: @@ -37,7 +39,10 @@ def get_ext_modules(): name="torch_points_kernels.points_cuda", sources=ext_sources, include_dirs=["{}/include".format(ext_src_root)], - extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, + extra_compile_args={ + "cxx": extra_compile_args, + "nvcc": extra_compile_args, + }, ) ) @@ -67,7 +72,7 @@ def get_cmdclass(): requirements = ["torch>=1.1.0", "numba", "scikit-learn"] url = "https://github.com/nicolas-chaulet/torch-points-kernels" -__version__ = "0.6.5" +__version__ = "0.6.6" setup( name="torch-points-kernels", version=__version__, @@ -81,5 +86,8 @@ def get_cmdclass(): cmdclass=get_cmdclass(), long_description=long_description, long_description_content_type="text/markdown", - classifiers=["Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",], + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + ], ) diff --git a/test/test_ballquerry.py b/test/test_ballquerry.py index 0b32b96..2139f41 100644 --- a/test/test_ballquerry.py +++ b/test/test_ballquerry.py @@ -61,21 +61,18 @@ def test_cpu_gpu_equality(self): class TestBallPartial(unittest.TestCase): @run_if_cuda def test_simple_gpu(self): - x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda() + x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [0.2, 0, 0], [0.1, 0, 0]]).to(torch.float).cuda() y = torch.tensor([[0, 0, 0]]).to(torch.float).cuda() - batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda() + batch_x = torch.from_numpy(np.asarray([0, 0, 0, 1])).long().cuda() batch_y = torch.from_numpy(np.asarray([0])).long().cuda() - batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long().cuda() - batch_y = torch.from_numpy(np.asarray([0])).long().cuda() - - idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y) + idx, dist2 = ball_query(0.2, 4, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y) idx = idx.detach().cpu().numpy() dist2 = dist2.detach().cpu().numpy() - idx_answer = np.asarray([[1, -1]]) - dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32) + idx_answer = np.asarray([[1, 2, -1, -1]]) + dist2_answer = np.asarray([[0.0100, 0.04, -1, -1]]).astype(np.float32) npt.assert_array_almost_equal(idx, idx_answer) npt.assert_array_almost_equal(dist2, dist2_answer) @@ -98,30 +95,29 @@ def test_simple_cpu(self): npt.assert_array_almost_equal(idx, idx_answer) npt.assert_array_almost_equal(dist2, dist2_answer) - def test_breaks(self): x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float) y = torch.tensor([[0, 0, 0]]).to(torch.float) batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long() batch_y = torch.from_numpy(np.asarray([0])).long() - + with self.assertRaises(RuntimeError): idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y) - def test_random_cpu(self): + def test_random_cpu(self, cuda=False): a = torch.randn(100, 3).to(torch.float) b = torch.randn(50, 3).to(torch.float) batch_a = torch.tensor([0 for i in range(a.shape[0] // 2)] + [1 for i in range(a.shape[0] // 2, a.shape[0])]) batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])]) R = 1 - idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True) - idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True) + idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,) + idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True,) torch.testing.assert_allclose(idx1, idx) with self.assertRaises(AssertionError): - idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False) - idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False) + idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,) + idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,) torch.testing.assert_allclose(idx1, idx) self.assertEqual(idx.shape[0], b.shape[0]) @@ -136,6 +132,28 @@ def test_random_cpu(self): if p >= 0 and p < len(batch_a): assert p in idx3_sk[i] + @run_if_cuda + def test_random_gpu(self): + a = torch.randn(100, 3).to(torch.float).cuda() + b = torch.randn(50, 3).to(torch.float).cuda() + batch_a = torch.tensor( + [0 for i in range(a.shape[0] // 2)] + [1 for i in range(a.shape[0] // 2, a.shape[0])] + ).cuda() + batch_b = torch.tensor( + [0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])] + ).cuda() + R = 1 + + idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False,) + + # Comparison to see if we have the same result + tree = KDTree(a.cpu().detach().numpy()) + idx3_sk = tree.query_radius(b.cpu().detach().numpy(), r=R) + i = np.random.randint(len(batch_b)) + for p in idx[i].cpu().detach().numpy(): + if p >= 0 and p < len(batch_a): + assert p in idx3_sk[i] + if __name__ == "__main__": unittest.main() diff --git a/torch_points_kernels/cluster.py b/torch_points_kernels/cluster.py index eff32fc..6bc64de 100644 --- a/torch_points_kernels/cluster.py +++ b/torch_points_kernels/cluster.py @@ -86,9 +86,19 @@ def region_grow( # Build clusters for a given label (ignore other points) label_mask = labels == l local_ind = ind[label_mask] + + # Remap batch to a continuous sequence + label_batch = batch[label_mask] + unique_in_batch = torch.unique(label_batch) + remaped_batch = torch.empty_like(label_batch) + for new, old in enumerate(unique_in_batch): + mask = label_batch == old + remaped_batch[mask] = new + + # Cluster label_clusters = grow_proximity( pos[label_mask, :], - batch[label_mask], + remaped_batch, nsample=nsample, radius=radius, min_cluster_size=min_cluster_size,