diff --git a/conf/dataset/default.yaml b/conf/dataset/default.yaml index e735428..14dc4db 100644 --- a/conf/dataset/default.yaml +++ b/conf/dataset/default.yaml @@ -4,6 +4,7 @@ cfg: batch_size: ${training.batch_size} num_workers: ${training.num_workers} dataroot: data + conv_type: ${model.model.conv_type} common_transform: aug_transform: @@ -13,4 +14,4 @@ cfg: test_transform: "${dataset.cfg.val_transform}" train_transform: - "${dataset.cfg.aug_transform}" - - "${dataset.cfg.common_transform}" \ No newline at end of file + - "${dataset.cfg.common_transform}" diff --git a/conf/model/segmentation/default.yaml b/conf/model/segmentation/default.yaml index ebf3207..97647af 100644 --- a/conf/model/segmentation/default.yaml +++ b/conf/model/segmentation/default.yaml @@ -13,3 +13,4 @@ model: backbone: input_nc: ${dataset.cfg.feature_dimension} architecture: unet + conv_type: null diff --git a/conf/model/segmentation/kpconv/KPFCNN.yaml b/conf/model/segmentation/kpconv/KPFCNN.yaml new file mode 100644 index 0000000..66c851e --- /dev/null +++ b/conf/model/segmentation/kpconv/KPFCNN.yaml @@ -0,0 +1,85 @@ +# @package model +defaults: + - /model/segmentation/default + +model: + conv_type: "PARTIAL_DENSE" + backbone: + _target_: torch_points3d.applications.kpconv.KPConv + config: + define_constants: + in_grid_size: 0.02 + in_feat: 64 + bn_momentum: 0.2 + max_neighbors: 25 + down_conv: + down_conv_nn: + [ + [[FEAT + 1, in_feat], [in_feat, 2*in_feat]], + [[2*in_feat, 2*in_feat], [2*in_feat, 4*in_feat]], + [[4*in_feat, 4*in_feat], [4*in_feat, 8*in_feat]], + [[8*in_feat, 8*in_feat], [8*in_feat, 16*in_feat]], + [[16*in_feat, 16*in_feat], [16*in_feat, 32*in_feat]], + ] + grid_size: + [ + [in_grid_size, in_grid_size], + [2*in_grid_size, 2*in_grid_size], + [4*in_grid_size, 4*in_grid_size], + [8*in_grid_size, 8*in_grid_size], + [16*in_grid_size, 16*in_grid_size], + ] + prev_grid_size: + [ + [in_grid_size, in_grid_size], + [in_grid_size, 2*in_grid_size], + [2*in_grid_size, 4*in_grid_size], + [4*in_grid_size, 8*in_grid_size], + [8*in_grid_size, 16*in_grid_size], + ] + block_names: + [ + ["SimpleBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ] + has_bottleneck: + [ + [False, True], + [True, True], + [True, True], + [True, True], + [True, True], + ] + deformable: + [ + [False, False], + [False, False], + [False, False], + [False, False], + [False, False], + ] + max_num_neighbors: + [[max_neighbors,max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors]] + module_name: KPDualBlock + up_conv: + module_name: FPModule_PD + up_conv_nn: + [ + [32*in_feat + 16*in_feat, 8*in_feat], + [8*in_feat + 8*in_feat, 4*in_feat], + [4*in_feat + 4*in_feat, 2*in_feat], + [2*in_feat + 2*in_feat, in_feat], + ] + skip: True + up_k: [1,1,1,1] + bn_momentum: + [ + bn_momentum, + bn_momentum, + bn_momentum, + bn_momentum, + bn_momentum, + ] diff --git a/conf/model/segmentation/pointnet2/pointnet2_largemsg.yaml b/conf/model/segmentation/pointnet2/pointnet2_largemsg.yaml new file mode 100644 index 0000000..52a963b --- /dev/null +++ b/conf/model/segmentation/pointnet2/pointnet2_largemsg.yaml @@ -0,0 +1,40 @@ +# @package model +defaults: + - /model/segmentation/default + +model: + conv_type: "DENSE" + backbone: + _target_: torch_points3d.applications.pointnet2.PointNet2 + config: + down_conv: + module_name: PointNetMSGDown + npoint: [1024, 
256, 64, 16]
+        radii: [[0.05, 0.1], [0.1, 0.2], [0.2, 0.4], [0.4, 0.8]]
+        nsamples: [[16, 32], [16, 32], [16, 32], [16, 32]]
+        down_conv_nn:
+          [
+            [[FEAT+3, 16, 16, 32], [FEAT+3, 32, 32, 64]],
+            [[32 + 64+3, 64, 64, 128], [32 + 64+3, 64, 96, 128]],
+            [
+              [128 + 128+3, 128, 196, 256],
+              [128 + 128+3, 128, 196, 256],
+            ],
+            [
+              [256 + 256+3, 256, 256, 512],
+              [256 + 256+3, 256, 384, 512],
+            ],
+          ]
+      up_conv:
+        module_name: DenseFPModule
+        up_conv_nn:
+          [
+            [512 + 512 + 256 + 256, 512, 512],
+            [512 + 128 + 128, 512, 512],
+            [512 + 64 + 32, 256, 256],
+            [256 + FEAT, 128, 128],
+          ]
+        skip: True
+      mlp_cls:
+        nn: [128, 128]
+        dropout: 0.5
diff --git a/conf/model/segmentation/sparseconv3d/Res16UNet34.yaml b/conf/model/segmentation/sparseconv3d/Res16UNet34.yaml
index 7e63894..12283a7 100644
--- a/conf/model/segmentation/sparseconv3d/Res16UNet34.yaml
+++ b/conf/model/segmentation/sparseconv3d/Res16UNet34.yaml
@@ -2,9 +2,9 @@
 defaults:
   - /model/segmentation/ResUNet32
 
-model: 
+model:
   backbone:
     down_conv:
       N: [ 0, 2, 3, 4, 6 ]
     up_conv:
-      N: [ 1, 1, 1, 1, 1 ]
\ No newline at end of file
+      N: [ 1, 1, 1, 1, 1 ]
diff --git a/conf/model/segmentation/sparseconv3d/ResUNet32.yaml b/conf/model/segmentation/sparseconv3d/ResUNet32.yaml
index c6d273b..01869ee 100644
--- a/conf/model/segmentation/sparseconv3d/ResUNet32.yaml
+++ b/conf/model/segmentation/sparseconv3d/ResUNet32.yaml
@@ -2,11 +2,11 @@
 defaults:
   - /model/segmentation/default
 
-model: 
+model:
+  conv_type: "SPARSE"
   backbone:
     _target_: torch_points3d.applications.sparseconv3d.SparseConv3d
     backend: torchsparse
-
     config:
       define_constants:
         in_feat: 32
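The `conv_type: ${model.model.conv_type}` entry added to `conf/dataset/default.yaml` above works through OmegaConf interpolation across config groups, which is why each model config now defines `model.conv_type` (with `conv_type: null` as the fallback in `conf/model/segmentation/default.yaml`). A minimal sketch of how the composed tree resolves, using illustrative values rather than the project's real configs:

```python
from omegaconf import OmegaConf

# Stand-in for the tree Hydra composes from conf/model and conf/dataset.
cfg = OmegaConf.create(
    {
        "model": {"model": {"conv_type": "SPARSE"}},
        "dataset": {"cfg": {"conv_type": "${model.model.conv_type}"}},
    }
)

# Interpolations are resolved lazily, on access.
assert cfg.dataset.cfg.conv_type == "SPARSE"
```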
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..ec46fec
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,52 @@
+from typing import List
+import os
+import os.path as osp
+import pytest
+
+from hydra import compose, initialize
+from hydra.test_utils.test_utils import find_parent_dir_containing
+
+from torch_points3d.trainer import LitTrainer
+from torch_points3d.core.instantiator import HydraInstantiator
+
+
+class ScriptRunner:
+    @staticmethod
+    def find_hydra_conf_dir(config_dir: str = "conf") -> str:
+        """Util function to find the hydra config directory from the main repository for testing.
+
+        Args:
+            config_dir: Name of the config directory.
+
+        Returns:
+            Relative config path.
+        """
+        parent_dir = find_parent_dir_containing(config_dir)
+        relative_conf_dir = osp.relpath(parent_dir, os.path.dirname(__file__))
+        return osp.join(relative_conf_dir, config_dir)
+
+    def train(self, cmd_args: List[str]) -> None:
+        relative_conf_dir = self.find_hydra_conf_dir()
+        with initialize(config_path=relative_conf_dir, job_name="test_app"):
+            cfg = compose(config_name="config", overrides=cmd_args)
+            instantiator = HydraInstantiator()
+            trainer = LitTrainer(
+                instantiator,
+                dataset=cfg.get("dataset"),
+                trainer=cfg.get("trainer"),
+                model=cfg.get("model"))
+            trainer.train()
+
+    def hf_train(self, dataset: str, model: str, num_workers: int = 0, fast_dev_run: int = 1):
+        # fast_dev_run is forwarded on the assumption that the trainer config
+        # group exposes pytorch_lightning.Trainer kwargs, as trainer.max_epochs does.
+        cmd_args = [
+            f"model={model}",
+            f"dataset={dataset}",
+            "trainer.max_epochs=1",
+            f"trainer.fast_dev_run={fast_dev_run}",
+            f"training.num_workers={num_workers}",
+        ]
+        self.train(cmd_args)
+
+
+@pytest.fixture(scope="session")
+def script_runner() -> ScriptRunner:
+    return ScriptRunner()
diff --git a/test/mockdatasets.py b/test/mockdatasets.py
new file mode 100644
index 0000000..70ad85d
--- /dev/null
+++ b/test/mockdatasets.py
@@ -0,0 +1,189 @@
+import numpy as np
+import torch
+from typing import Callable, Optional
+from torch.utils.data import Dataset
+from torch_geometric.data import Data, Batch
+
+from torch_points3d.data.batch import SimpleBatch
+# from torch_points3d.core.data_transform import MultiScaleTransform
+from torch_points3d.data.multiscale_data import MultiScaleBatch
+from torch_points3d.data.pair import Pair, PairBatch, PairMultiScaleBatch, DensePairBatch
+
+
+class MockDatasetConfig(object):
+    def __init__(self):
+        pass
+
+    def keys(self):
+        return []
+
+    def get(self, dataset_name, default):
+        return None
+
+
+class MockDataset(torch.utils.data.Dataset):
+    def __init__(self, feature_size=0, transform=None, num_points=100, panoptic=False, include_box=False, batch_size=2):
+        self.feature_dimension = feature_size
+        self.num_classes = 10
+        self.num_points = num_points
+        self.batch_size = batch_size
+        self.weight_classes = None
+        self.feature_size = feature_size
+        if feature_size > 0:
+            self._feature = torch.tensor([range(feature_size) for i in range(self.num_points)], dtype=torch.float,)
+        else:
+            self._feature = None
+        self._y = torch.randint(0, 10, (self.num_points,))
+        self._category = torch.ones((self.num_points,), dtype=torch.long)
+        self._ms_transform = None
+        self._transform = transform
+        self.mean_size_arr = torch.rand((11, 3))
+        self.include_box = include_box
+        self.panoptic = panoptic
+
+    def __len__(self):
+        return self.num_points
+
+    def _generate_data(self):
+        data = Data(
+            pos=torch.randn((self.num_points, 3)),
+            x=torch.randn((self.num_points, self.feature_size)) if self.feature_size else None,
+            y=torch.randint(0, 10, (self.num_points,)),
+            category=self._category,
+        )
+        if self.include_box:
+            num_boxes = 10
+            data.center_label = torch.randn(num_boxes, 3)
+            data.heading_class_label = torch.zeros((num_boxes,))
+            data.heading_residual_label = torch.randn((num_boxes,))
+            data.size_class_label = torch.randint(0, len(self.mean_size_arr), (num_boxes,))
+            data.size_residual_label = torch.randn(num_boxes, 3)
+            data.sem_cls_label = torch.randint(0, 10, (num_boxes,))
+            # randint upper bound is exclusive, so it must be 2 to get a random boolean mask
+            data.box_label_mask = torch.randint(0, 2, (num_boxes,)).bool()
+            data.vote_label = torch.randn(self.num_points, 9)
+            data.vote_label_mask = torch.randint(0, 2, (self.num_points,)).bool()
+            # box corners are coordinates, not booleans
+            data.instance_box_corners = torch.randn((num_boxes, 8, 3))
+        if self.panoptic:
+            data.num_instances = torch.tensor([10])
+
data.center_label = torch.randn((self.num_points, 3)) + data.y = torch.randint(0, 10, (self.num_points,)) + data.instance_labels = torch.randint(0, 20, (self.num_points,)) + data.instance_mask = torch.rand(self.num_points).bool() + data.vote_label = torch.randn((self.num_points, 3)) + return data + + @property + def datalist(self): + datalist = [self._generate_data() for i in range(self.batch_size)] + if self._transform: + datalist = [self._transform(d.clone()) for d in datalist] + if self._ms_transform: + datalist = [self._ms_transform(d.clone()) for d in datalist] + return datalist + + def __getitem__(self, index): + return SimpleBatch.from_data_list(self.datalist) + + @property + def class_to_segments(self): + return {"class1": [0, 1, 2, 3, 4, 5], "class2": [6, 7, 8, 9]} + + @property + def stuff_classes(self): + return torch.tensor([0, 1]) + + def set_strategies(self, model): + strategies = model.get_spatial_ops() + transform = None + # transform = MultiScaleTransform(strategies) + self._ms_transform = transform + + +class MockDatasetGeometric(MockDataset): + def __getitem__(self, index): + if self._ms_transform: + return MultiScaleBatch.from_data_list(self.datalist) + else: + return Batch.from_data_list(self.datalist) + + +class PairMockDataset(MockDataset): + def __init__(self, feature_size=0, transform=None, num_points=100, is_pair_ind=True, batch_size=2): + super(PairMockDataset, self).__init__(feature_size, transform, num_points, batch_size=batch_size) + if is_pair_ind: + self._pair_ind = torch.tensor([[0, 1], [1, 0]]) + else: + self._pair_ind = None + + @property + def datalist(self): + torch.manual_seed(0) + datalist_source = [ + Data( + pos=torch.randn((self.num_points, 3)), + x=self._feature, + pair_ind=self._pair_ind, + size_pair_ind=torch.tensor([len(self._pair_ind)]), + ) + for i in range(self.batch_size) + ] + datalist_target = [ + Data( + pos=torch.randn((self.num_points, 3)), + x=self._feature, + pair_ind=self._pair_ind, + size_pair_ind=torch.tensor([len(self._pair_ind)]), + ) + for i in range(self.batch_size) + ] + if self._transform: + datalist_source = [self._transform(d.clone()) for d in datalist_source] + datalist_target = [self._transform(d.clone()) for d in datalist_target] + if self._ms_transform: + datalist_source = [self._ms_transform(d.clone()) for d in datalist_source] + datalist_target = [self._ms_transform(d.clone()) for d in datalist_target] + datalist = [Pair.make_pair(datalist_source[i], datalist_target[i]) for i in range(self.batch_size)] + return datalist + + def __getitem__(self, index): + return DensePairBatch.from_data_list(self.datalist) + + +class PairMockDatasetGeometric(PairMockDataset): + def __getitem__(self, index): + + if self._ms_transform: + return PairMultiScaleBatch.from_data_list(self.datalist) + else: + return PairBatch.from_data_list(self.datalist) + + +class SegmentationMockDataset(Dataset): + def __init__(self, train: bool = True, transform: Optional[Callable] = None, size: int = 3, is_same_size: bool = True, num_classes:int = 2): + self.train = train + self.transform = transform + self.size = size + self.is_same_size = is_same_size + self.num_classes = num_classes + + def __len__(self) -> int: + return self.size + + def __getitem__(self, idx: int) -> Data: + size_pt = 1000 + if not self.is_same_size: + size_pt = torch.randint(900, 1100, (1, ))[0].item() + if(self.train): + pos = torch.randn(size_pt, 3).float() + else: + pos = (idx * torch.ones(size_pt, 3)).float() + x = torch.ones(size_pt, 1).float() + y = torch.randint(0, 
self.num_classes, (size_pt,)).float()
+        data = Data(pos=pos, y=y, x=x)
+        if self.transform is not None:
+            data = self.transform(data)
+        return data
diff --git a/test/test_data/test_batch.py b/test/test_data/test_batch.py
new file mode 100644
index 0000000..018a8f2
--- /dev/null
+++ b/test/test_data/test_batch.py
@@ -0,0 +1,26 @@
+import torch
+from torch_geometric.data import Data
+import numpy as np
+
+import os
+import sys
+
+ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..")
+sys.path.append(ROOT)
+
+from torch_points3d.data.batch import SimpleBatch
+
+
+def test_fromlist():
+    nb_points = 100
+    pos = torch.randn((nb_points, 3))
+    y = torch.tensor([range(10) for i in range(pos.shape[0])], dtype=torch.float)
+    d = Data(pos=pos, y=y)
+    b = SimpleBatch.from_data_list([d, d])
+    np.testing.assert_equal(b.pos.size(), (2, 100, 3))
+    np.testing.assert_equal(b.y.size(), (2, 100, 10))
diff --git a/test/test_data/test_msdata.py b/test/test_data/test_msdata.py
new file mode 100644
index 0000000..33fb0dc
--- /dev/null
+++ b/test/test_data/test_msdata.py
@@ -0,0 +1,46 @@
+import torch
+import torch.testing as tt
+import numpy as np
+from torch_geometric.data import Data
+
+from torch_points3d.data.multiscale_data import MultiScaleBatch, MultiScaleData
+
+
+def test_apply():
+    x = torch.tensor([1])
+    pos = torch.tensor([1])
+    d1 = Data(x=2 * x, pos=2 * pos)
+    d2 = Data(x=3 * x, pos=3 * pos)
+    data = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])
+    data.apply(lambda x: 2 * x)
+    np.testing.assert_equal(data.x[0].item(), 2)
+    np.testing.assert_equal(data.pos[0].item(), 2)
+    np.testing.assert_equal(data.multiscale[0].pos[0].item(), 4)
+    np.testing.assert_equal(data.multiscale[0].x[0].item(), 4)
+    np.testing.assert_equal(data.multiscale[1].pos[0].item(), 6)
+    np.testing.assert_equal(data.multiscale[1].x[0].item(), 6)
+
+
+def test_batch():
+    x = torch.tensor([1])
+    pos = x
+    d1 = Data(x=x, pos=pos)
+    d2 = Data(x=4 * x, pos=4 * pos)
+    data1 = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])
+
+    x = torch.tensor([2])
+    pos = x
+    d1 = Data(x=x, pos=pos)
+    d2 = Data(x=4 * x, pos=4 * pos)
+    data2 = MultiScaleData(x=x, pos=pos, multiscale=[d1, d2])
+
+    batch = MultiScaleBatch.from_data_list([data1, data2])
+    tt.assert_allclose(batch.x, torch.tensor([1, 2]))
+    tt.assert_allclose(batch.batch, torch.tensor([0, 1]))
+
+    ms_batches = batch.multiscale
+    tt.assert_allclose(ms_batches[0].batch, torch.tensor([0, 1]))
+    tt.assert_allclose(ms_batches[1].batch, torch.tensor([0, 1]))
+    tt.assert_allclose(ms_batches[1].x, torch.tensor([4, 8]))
diff --git a/test/test_dataset/test_basedataset.py b/test/test_dataset/test_basedataset.py
new file mode 100644
index 0000000..0bd3ca9
--- /dev/null
+++ b/test/test_dataset/test_basedataset.py
@@ -0,0 +1,68 @@
+from typing import Optional, Callable
+from dataclasses import dataclass
+import numpy as np
+import pytest
+import torch
+
+from torch_geometric.data import Data
+
+from torch_points3d.datasets.base_dataset import PointCloudDataModule, PointCloudDataConfig
+from test.mockdatasets import SegmentationMockDataset
+
+
+@dataclass
+class MockConfig(PointCloudDataConfig):
+    batch_size: int = 16
+    num_workers: int = 0
+    size: int = 3
+    is_same_size: bool = False
+    conv_type: str = "dense"
+    multiscale: bool = False
+    num_classes: int = 2
+
+
+class SegmentationMockDataLoader(PointCloudDataModule):
+    def __init__(self, cfg, transform: Optional[Callable] = None):
+        super().__init__(cfg)
+
+        self.ds = {
+            "train": SegmentationMockDataset(train=True, transform=transform, size=self.cfg.size, is_same_size=self.cfg.is_same_size, num_classes=self.cfg.num_classes),
+            "validation": SegmentationMockDataset(train=False, transform=transform, size=self.cfg.size, is_same_size=self.cfg.is_same_size, num_classes=self.cfg.num_classes),
+        }
+
+
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("num_classes", [2, 4, 6, 9, 10])
+@pytest.mark.parametrize("size", [3, 10, 100])
+@pytest.mark.parametrize("conv_type, is_same_size, multiscale",
+                         [pytest.param("dense", True, False),
+                          pytest.param("dense", False, False, marks=pytest.mark.xfail),
+                          pytest.param("partial_dense", True, False),
+                          pytest.param("partial_dense", False, False),
+                          pytest.param("partial_dense", True, True, marks=pytest.mark.xfail),
+                          pytest.param("sparse", True, False),
+                          ])
+def test_dataloader(conv_type, is_same_size, size, multiscale, num_classes, batch_size):
+    cfg = MockConfig(conv_type=conv_type, is_same_size=is_same_size, size=size, multiscale=multiscale, num_classes=num_classes, batch_size=batch_size)
+    dataloader = SegmentationMockDataLoader(cfg)
+
+    train_dataloader = dataloader.train_dataloader()
+    val_dataloader = dataloader.val_dataloader()
+
+    for loader in [train_dataloader, val_dataloader]:
+        # test len
+        np.testing.assert_equal(len(loader.dataset), size)
+        # test batch collate on the loader under test, not only the train loader
+        batch = next(iter(loader))
+        num_samples = PointCloudDataModule.get_num_samples(batch, conv_type)
+        np.testing.assert_equal(num_samples, min(batch_size, size))
+        if is_same_size:
+            if conv_type.lower() == "dense":
+                np.testing.assert_equal(batch.pos.size(1), 1000)
+            else:
+                for i in range(min(batch_size, size)):
+                    np.testing.assert_equal(batch.pos[batch.batch == i].shape[0], 1000)
+        if multiscale:
+            # TODO: check that downsample and upsample data are collated as well
+            pass
diff --git a/test/test_confusion_matrix.py b/test/test_metric/test_confusion_matrix.py
similarity index 100%
rename from test/test_confusion_matrix.py
rename to test/test_metric/test_confusion_matrix.py
diff --git a/test/test_segmentation_tracker.py b/test/test_metric/test_segmentation_tracker.py
similarity index 89%
rename from test/test_segmentation_tracker.py
rename to test/test_metric/test_segmentation_tracker.py
index cb54eb6..23731d9 100644
--- a/test/test_segmentation_tracker.py
+++ b/test/test_metric/test_segmentation_tracker.py
@@ -9,7 +9,7 @@ from torch_geometric.data import Data
 
 DIR = os.path.dirname(os.path.realpath(__file__))
-ROOT = os.path.join(DIR, "..")
+ROOT = os.path.join(DIR, "..", "..")
 sys.path.insert(0, ROOT)
 sys.path.append(".")
 
@@ -36,10 +36,10 @@ class MockModel:
     def __init__(self):
         self.iter = 0
         self.losses = [
-            {"loss_1": 1, "loss_2": 2},
-            {"loss_1": 2, "loss_2": 2},
-            {"loss_1": 1, "loss_2": 2},
-            {"loss_1": 1, "loss_2": 2},
+            {"loss_1": torch.tensor(1.0), "loss_2": torch.tensor(2.0)},
+            {"loss_1": torch.tensor(2.0), "loss_2": torch.tensor(2.0)},
+            {"loss_1": torch.tensor(1.0), "loss_2": torch.tensor(2.0)},
+            {"loss_1": torch.tensor(1.0), "loss_2": torch.tensor(2.0)},
         ]
         self.outputs = [
             torch.tensor([[0, 1], [0, 1]]),
@@ -91,7 +91,7 @@
     np.testing.assert_allclose(metrics["train_miou"], 25, atol=1e-5)
     assert metrics["train_loss_1"] == 1.5
 
-    tracker.reset("test")
+    tracker = SegmentationTracker(num_classes=2, stage="test")
     model.iter += 1
     output = {"preds": model.get_output(), "labels": model.get_labels()}
     losses = model.get_current_losses()
@@ -103,8 +103,7 @@
 
+@pytest.mark.parametrize("finalise", [pytest.param(True), pytest.param(False)])
 def test_ignore_label(finalise):
-    tracker = SegmentationTracker(num_classes=2, ignore_label=-100)
-    tracker.reset("test")
+    tracker = SegmentationTracker(num_classes=2, ignore_label=-100, stage="test")
     model = MockModel()
     model.iter = 3
     output = {"preds": model.get_output(), "labels": model.get_labels()}
diff --git a/test/test_modules/test_api.py b/test/test_modules/test_api.py
new file mode 100644
index 0000000..853cc0c
--- /dev/null
+++ b/test/test_modules/test_api.py
@@ -0,0 +1,145 @@
+from torch_points3d.core.data_transform import GridSampling3D
+from test.mockdatasets import MockDatasetGeometric, MockDataset
+import pytest
+import torch
+from omegaconf import OmegaConf
+
+from torch_points3d.applications.sparseconv3d import SparseConv3d
+from torch_points3d.applications.pointnet2 import PointNet2
+from torch_points3d.applications.kpconv import KPConv
+
+
+seed = 0
+torch.manual_seed(seed)
+device = "cpu"
+
+
+@pytest.mark.parametrize("architecture", [pytest.param("unet"), pytest.param("encoder", marks=pytest.mark.xfail)])
+@pytest.mark.parametrize("input_nc", [0, 3])
+@pytest.mark.parametrize("num_layers", [4])
+@pytest.mark.parametrize("grid_sampling", [0.02, 0.04])
+@pytest.mark.parametrize("in_feat", [32])
+@pytest.mark.parametrize("output_nc", [None, 32])
+def test_kpconv(architecture, input_nc, num_layers, grid_sampling, in_feat, output_nc):
+    if output_nc is not None:
+        model = KPConv(
+            architecture=architecture,
+            input_nc=input_nc,
+            in_feat=in_feat,
+            in_grid_size=grid_sampling,
+            num_layers=num_layers,
+            output_nc=output_nc,
+            config=None,
+        )
+    else:
+        model = KPConv(
+            architecture=architecture,
+            input_nc=input_nc,
+            in_feat=in_feat,
+            in_grid_size=grid_sampling,
+            num_layers=num_layers,
+            config=None,
+        )
+
+    dataset = MockDatasetGeometric(input_nc + 1, transform=GridSampling3D(0.01), num_points=128)
+    assert len(model._modules["down_modules"]) == num_layers + 1
+    assert len(model._modules["inner_modules"]) == 1
+    assert len(model._modules["up_modules"]) == 4
+    if output_nc is None:
+        assert not model.has_mlp_head
+        assert model.output_nc == in_feat
+
+    try:
+        data_out = model.forward(dataset[0])
+        assert data_out.x.shape[1] == in_feat
+    except Exception as e:
+        print("Model failing:")
+        print(model)
+        raise e
+
+
+@pytest.mark.skip("PointNet2 multiscale is not yet implemented")
+def test_pn2():
+
+    input_nc = 2
+    num_layers = 3
+    output_nc = 5
+    model = PointNet2(
+        architecture="unet",
+        input_nc=input_nc,
+        output_nc=output_nc,
+        num_layers=num_layers,
+        multiscale=True,
+        config=None,
+    )
+    dataset = MockDataset(input_nc, num_points=512)
+    assert len(model._modules["down_modules"]) == num_layers - 1
+    assert len(model._modules["inner_modules"]) == 1
+    assert len(model._modules["up_modules"]) == num_layers
+
+    try:
+        data_out = model.forward(dataset[0])
+        assert data_out.x.shape[1] == output_nc
+    except Exception as e:
+        print("Model failing:")
+        print(model)
+        raise e
+
+
+@pytest.mark.skip("RSConv is not yet implemented")
+def test_rsconv():
+    from torch_points3d.applications.rsconv import RSConv
+
+    input_nc = 2
+    num_layers = 4
+    output_nc = 5
+    model = RSConv(
+        architecture="unet",
+        input_nc=input_nc,
+        output_nc=output_nc,
+        num_layers=num_layers,
+        multiscale=True,
+        config=None,
+    )
+    dataset = MockDataset(input_nc, num_points=1024)
+    assert len(model._modules["down_modules"]) == num_layers
+    assert len(model._modules["inner_modules"]) == 2
+    assert len(model._modules["up_modules"]) == num_layers
+
+    try:
+        data_out = model.forward(dataset[0])
+        assert data_out.x.shape[1] == output_nc
+    except Exception as e:
+        print("Model failing:")
+        print(model)
+        raise e
+
+
+@pytest.mark.skip("SparseConv3d test is not yet enabled")
+def test_sparseconv3d():
+
+    input_nc = 3
+    num_layers = 4
+    in_feat = 32
+    out_feat = in_feat * 3
+    model = SparseConv3d(
+        architecture="unet",
+        input_nc=input_nc,
+        in_feat=in_feat,
+        num_layers=num_layers,
+        config=None,
+    )
+    dataset = MockDatasetGeometric(input_nc, transform=GridSampling3D(0.01, quantize_coords=True), num_points=128)
+    assert len(model._modules["down_modules"]) == num_layers + 1
+    assert len(model._modules["inner_modules"]) == 1
+    assert len(model._modules["up_modules"]) == 4 + 1
+    assert not model.has_mlp_head
+    assert model.output_nc == out_feat
+
+    try:
+        data_out = model.forward(dataset[0])
+        assert data_out.x.shape[1] == out_feat
+    except Exception as e:
+        print("Model failing:")
+        print(model)
+        raise e
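The KPConv application factory exercised by `test_kpconv` above can also be driven directly. A minimal stand-alone sketch with illustrative values (it assumes the repository root is on `sys.path` so that `test.mockdatasets` imports, exactly as in `test_api.py`):

```python
import torch

from torch_points3d.applications.kpconv import KPConv
from torch_points3d.core.data_transform import GridSampling3D
from test.mockdatasets import MockDatasetGeometric

torch.manual_seed(0)

# Same constructor arguments as the unet case in test_kpconv above.
model = KPConv(
    architecture="unet",
    input_nc=3,
    in_feat=32,
    in_grid_size=0.02,
    num_layers=4,
    config=None,
)

# MockDatasetGeometric yields a full collated batch per __getitem__ call.
dataset = MockDatasetGeometric(3 + 1, transform=GridSampling3D(0.01), num_points=128)
out = model.forward(dataset[0])
print(out.x.shape)  # last dimension == in_feat when no output_nc head is requested
```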
diff --git a/test/test_modules/test_kpconv.py b/test/test_modules/test_kpconv.py
new file mode 100644
index 0000000..8fea8c3
--- /dev/null
+++ b/test/test_modules/test_kpconv.py
@@ -0,0 +1,35 @@
+import numpy as np
+import numpy.testing as npt
+import torch
+
+
+from torch_points3d.applications.modules.KPConv.losses import repulsion_loss, fitting_loss, permissive_loss
+
+
+def test_permissive_loss():
+    pos_n = np.asarray([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]).astype(np.float64)
+    pos_t = torch.from_numpy(pos_n)
+    loss = permissive_loss(pos_t, 1).item()
+    assert loss == np.sqrt(2)
+
+
+def test_fitting_loss():
+    pos_n = np.asarray([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]).astype(np.float64)
+    target = np.asarray([[0.5, 0.5, 0]])
+    K_points = torch.from_numpy(pos_n)
+    neighbors = torch.from_numpy(target)
+    neighbors = neighbors.repeat([4, 1])
+    differences = neighbors - K_points
+    sq_distances = torch.sum(differences ** 2, dim=-1).unsqueeze(0)
+    loss = fitting_loss(sq_distances, 1).item()
+    assert loss == 0.5
+
+
+def test_repulsion_loss():
+    pos_n = np.asarray([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]).astype(np.float64)
+    K_points = torch.from_numpy(pos_n)
+    loss = repulsion_loss(K_points.unsqueeze(0), 1).item()
+    arr_ = np.asarray([0.25, 0.25, 0.0074]).astype(np.float64)
+    # PyTorch loss is only precise to about 4 decimals here
+    npt.assert_almost_equal(loss, 4 * np.sum(arr_), decimal=3)
diff --git a/test/test_model.py b/test/test_task/segmentation/test_model.py
similarity index 81%
rename from test/test_model.py
rename to test/test_task/segmentation/test_model.py
index b97368b..978675c 100644
--- a/test/test_model.py
+++ b/test/test_task/segmentation/test_model.py
@@ -1,15 +1,9 @@
 import pytest
-import sys
-import os
 import torch
 from omegaconf import OmegaConf
 from torch_geometric.data import Batch
 
-DIR = os.path.dirname(os.path.realpath(__file__))
-ROOT = os.path.join(DIR, "..")
-sys.path.insert(0, ROOT)
-sys.path.append(".")
 
 from torch_points3d.models.segmentation.base_model import SegmentationBaseModel
 from torch_points3d.core.instantiator import HydraInstantiator
@@ -31,3 +25,10 @@ def test_forward(self):
     data = Batch(pos=pos, x=x, batch=batch, y=y, coords=coords)
     model.set_input(data)
     model.forward()
+
+
+@pytest.mark.slow
+def test_s3dis_run(script_runner):
+    model = "segmentation/sparseconv3d/ResUNet32"
+    dataset = "segmentation/s3dis/s3dis1x1"
+    script_runner.hf_train(dataset=dataset,
model=model) diff --git a/torch_points3d/applications/conf/kpconv/encoder_4.yaml b/torch_points3d/applications/conf/kpconv/encoder_4.yaml new file mode 100644 index 0000000..0ea0664 --- /dev/null +++ b/torch_points3d/applications/conf/kpconv/encoder_4.yaml @@ -0,0 +1,68 @@ +class: kpconv.KPConvPaper +conv_type: "PARTIAL_DENSE" +define_constants: + in_grid_size: 0.02 + in_feat: 64 + bn_momentum: 0.2 + output_nc: 256 + max_neighbors: 25 +down_conv: + down_conv_nn: + [ + [[FEAT + 1, in_feat], [in_feat, 2*in_feat]], + [[2*in_feat, 2*in_feat], [2*in_feat, 4*in_feat]], + [[4*in_feat, 4*in_feat], [4*in_feat, 8*in_feat]], + [[8*in_feat, 8*in_feat], [8*in_feat, 16*in_feat]], + [[16*in_feat, 16*in_feat], [16*in_feat, 32 * in_feat]], + ] + grid_size: + [ + [in_grid_size, in_grid_size], + [2*in_grid_size, 2*in_grid_size], + [4*in_grid_size, 4*in_grid_size], + [8*in_grid_size, 8*in_grid_size], + [16*in_grid_size, 16*in_grid_size], + ] + prev_grid_size: + [ + [in_grid_size, in_grid_size], + [in_grid_size, 2*in_grid_size], + [2*in_grid_size, 4*in_grid_size], + [4*in_grid_size, 8*in_grid_size], + [8*in_grid_size, 16*in_grid_size], + ] + block_names: + [ + ["SimpleBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ] + has_bottleneck: + [ + [False, True], + [True, True], + [True, True], + [True, True], + [True, True], + ] + deformable: + [ + [False, False], + [False, False], + [False, False], + [False, False], + [False, False], + ] + max_num_neighbors: + [[max_neighbors,max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors]] + module_name: KPDualBlock +innermost: + module_name: GlobalBaseModule + activation: + name: LeakyReLU + negative_slope: 0.2 + aggr: "mean" + nn: [32 * in_feat + 3, 32 * in_feat] + diff --git a/torch_points3d/applications/conf/kpconv/unet_4.yaml b/torch_points3d/applications/conf/kpconv/unet_4.yaml new file mode 100644 index 0000000..993086e --- /dev/null +++ b/torch_points3d/applications/conf/kpconv/unet_4.yaml @@ -0,0 +1,78 @@ +class: kpconv.KPConvPaper +conv_type: "PARTIAL_DENSE" +define_constants: + in_grid_size: 0.02 + in_feat: 64 + bn_momentum: 0.2 + max_neighbors: 25 +down_conv: + down_conv_nn: + [ + [[FEAT + 1, in_feat], [in_feat, 2*in_feat]], + [[2*in_feat, 2*in_feat], [2*in_feat, 4*in_feat]], + [[4*in_feat, 4*in_feat], [4*in_feat, 8*in_feat]], + [[8*in_feat, 8*in_feat], [8*in_feat, 16*in_feat]], + [[16*in_feat, 16*in_feat], [16*in_feat, 32*in_feat]], + ] + grid_size: + [ + [in_grid_size, in_grid_size], + [2*in_grid_size, 2*in_grid_size], + [4*in_grid_size, 4*in_grid_size], + [8*in_grid_size, 8*in_grid_size], + [16*in_grid_size, 16*in_grid_size], + ] + prev_grid_size: + [ + [in_grid_size, in_grid_size], + [in_grid_size, 2*in_grid_size], + [2*in_grid_size, 4*in_grid_size], + [4*in_grid_size, 8*in_grid_size], + [8*in_grid_size, 16*in_grid_size], + ] + block_names: + [ + ["SimpleBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ] + has_bottleneck: + [ + [False, True], + [True, True], + [True, True], + [True, True], + [True, True], + ] + deformable: + [ + [False, False], + [False, False], + [False, False], + [False, False], + [False, False], + ] + max_num_neighbors: + [[max_neighbors,max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, 
max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors]] + module_name: KPDualBlock +up_conv: + module_name: FPModule_PD + up_conv_nn: + [ + [32*in_feat + 16*in_feat, 8*in_feat], + [8*in_feat + 8*in_feat, 4*in_feat], + [4*in_feat + 4*in_feat, 2*in_feat], + [2*in_feat + 2*in_feat, in_feat], + ] + skip: True + up_k: [1,1,1,1] + bn_momentum: + [ + bn_momentum, + bn_momentum, + bn_momentum, + bn_momentum, + bn_momentum, + ] diff --git a/torch_points3d/applications/conf/pointnet2/encoder_3_ms.yaml b/torch_points3d/applications/conf/pointnet2/encoder_3_ms.yaml new file mode 100644 index 0000000..8653250 --- /dev/null +++ b/torch_points3d/applications/conf/pointnet2/encoder_3_ms.yaml @@ -0,0 +1,23 @@ +conv_type: "DENSE" +define_constants: + in_feat: 64 +down_conv: + module_name: PointNetMSGDown + npoint: [512, 128] + radii: [[0.1, 0.2, 0.4], [0.4, 0.8]] + nsamples: [[32, 64, 128], [64, 128]] + down_conv_nn: + [ + [ + [FEAT + 3, in_feat // 2, in_feat // 2, in_feat], + [FEAT+ 3, in_feat, in_feat, in_feat * 2], + [FEAT+ 3, in_feat, in_feat + in_feat // 2 , in_feat * 2], + ], + [ + [in_feat + in_feat * 2 + in_feat * 2 + 3, in_feat * 2, in_feat * 2, in_feat * 4], + [in_feat + in_feat * 2 + in_feat * 2 + 3, in_feat * 2, in_feat * 3, in_feat * 4], + ], + ] +innermost: + module_name: GlobalDenseBaseModule + nn: [in_feat * 4 * 2 + 3, in_feat * 4, in_feat * 8] \ No newline at end of file diff --git a/torch_points3d/applications/conf/pointnet2/unet_3_ms.yaml b/torch_points3d/applications/conf/pointnet2/unet_3_ms.yaml new file mode 100644 index 0000000..5c768db --- /dev/null +++ b/torch_points3d/applications/conf/pointnet2/unet_3_ms.yaml @@ -0,0 +1,30 @@ +conv_type: "DENSE" +down_conv: + module_name: PointNetMSGDown + npoint: [512, 128] + radii: [[0.1, 0.2, 0.4], [0.4, 0.8]] + nsamples: [[32, 64, 128], [64, 128]] + down_conv_nn: + [ + [ + [FEAT+3, 32, 32, 64], + [FEAT+3, 64, 64, 128], + [FEAT+3, 64, 96, 128], + ], + [ + [64 + 128 + 128+3, 128, 128, 256], + [64 + 128 + 128+3, 128, 196, 256], + ], + ] +innermost: + module_name: GlobalDenseBaseModule + nn: [256 * 2 + 3, 256, 512, 1024] +up_conv: + module_name: DenseFPModule + up_conv_nn: + [ + [1024 + 256*2, 256, 256], + [256 + 128 * 2 + 64, 256, 128], + [128 + FEAT, 128, 128], + ] + skip: True \ No newline at end of file diff --git a/torch_points3d/applications/conf/pointnet2/unet_3_ss.yaml b/torch_points3d/applications/conf/pointnet2/unet_3_ss.yaml new file mode 100644 index 0000000..c5cbb97 --- /dev/null +++ b/torch_points3d/applications/conf/pointnet2/unet_3_ss.yaml @@ -0,0 +1,19 @@ +conv_type: "DENSE" +down_conv: + module_name: PointNetMSGDown + npoint: [512, 128] + radii: [[0.2], [0.4]] + nsamples: [[64], [64]] + down_conv_nn: [[[FEAT + 3, 64, 64, 128]], [[128+3, 128, 128, 256]]] +innermost: + module_name: GlobalDenseBaseModule + nn: [256 + 3, 256, 512, 1024] +up_conv: + module_name: DenseFPModule + up_conv_nn: + [ + [1024 + 256, 256, 256], + [256 + 128, 256, 128], + [128 + FEAT, 128, 128, 128], + ] + skip: True diff --git a/torch_points3d/applications/conf/pointnet2/unet_4_ss.yaml b/torch_points3d/applications/conf/pointnet2/unet_4_ss.yaml new file mode 100644 index 0000000..c3015ff --- /dev/null +++ b/torch_points3d/applications/conf/pointnet2/unet_4_ss.yaml @@ -0,0 +1,26 @@ +conv_type: "DENSE" +define_constants: + in_feat: 64 +down_conv: + module_name: PointNetMSGDown + npoint: [2048, 1024, 512, 256] + radii: [[0.2], [0.4], [0.8], [1.2]] + nsamples: [[64], [32], [16], [16]] + down_conv_nn: [[[FEAT + 3, in_feat, in_feat, in_feat * 
2]], + [[in_feat * 2 + 3, in_feat * 2, in_feat * 2, in_feat * 4]], + [[in_feat * 4 + 3, in_feat * 2, in_feat * 2, in_feat * 4]]] + save_sampling_id: [True, False, False, False] + normalize_xyz: [True, True, True, True] +innermost: + module_name: GlobalDenseBaseModule + nn: [ in_feat * 4 + 3, in_feat * 8, in_feat * 16] +up_conv: + module_name: DenseFPModule + up_conv_nn: + [ + [in_feat * 16 + in_feat * 4, in_feat * 8, in_feat * 8], + [in_feat * 8 + in_feat * 4, in_feat * 8, in_feat * 8], + [in_feat * 8 + in_feat * 2, in_feat * 4, in_feat * 4], + [in_feat * 4 + FEAT, in_feat * 2, in_feat * 2] + ] + skip: True \ No newline at end of file diff --git a/torch_points3d/applications/conf/rsconv/encoder_4.yaml b/torch_points3d/applications/conf/rsconv/encoder_4.yaml new file mode 100644 index 0000000..8bbf9fe --- /dev/null +++ b/torch_points3d/applications/conf/rsconv/encoder_4.yaml @@ -0,0 +1,32 @@ +conv_type: "DENSE" +define_constants: + in_feat: 64 +down_conv: + module_name: RSConvOriginalMSGDown + npoint: [1024, 256, 64, 16] + radii: + [ + [0.075, 0.1, 0.125], + [0.1, 0.15, 0.2], + [0.2, 0.3, 0.4], + [0.4, 0.6, 0.8], + ] + nsamples: [[16, 32, 48], [16, 48, 64], [16, 32, 48], [16, 24, 32]] + down_conv_nn: + [ + [[10, in_feat//2, 16], [FEAT + 3, 16]], + [10, in_feat//2, in_feat * 3 + 3], + [10, in_feat, (in_feat * 2) * 3 + 3], + [10, 2 * in_feat, (in_feat * 4) * 3 + 3], + ] + channel_raising_nn: + [ + [16, in_feat], + [in_feat * 3 + 3, (in_feat * 2)], + [(in_feat * 2) * 3 + 3, (in_feat * 4)], + [(in_feat * 4) * 3 + 3, (in_feat * 8)], + ] +innermost: + module_name: GlobalDenseBaseModule + nn: [(in_feat * 8) * 3 + 3, in_feat * 8] + aggr: "mean" diff --git a/torch_points3d/applications/conf/rsconv/unet_4.yaml b/torch_points3d/applications/conf/rsconv/unet_4.yaml new file mode 100644 index 0000000..e564a3f --- /dev/null +++ b/torch_points3d/applications/conf/rsconv/unet_4.yaml @@ -0,0 +1,45 @@ +conv_type: "DENSE" +down_conv: + module_name: RSConvOriginalMSGDown + npoint: [1024, 256, 64, 16] + radii: + [ + [0.075, 0.1, 0.125], + [0.1, 0.15, 0.2], + [0.2, 0.3, 0.4], + [0.4, 0.6, 0.8], + ] + nsamples: [[16, 32, 48], [16, 48, 64], [16, 32, 48], [16, 24, 32]] + down_conv_nn: + [ + [[10, 64//2, 16], [FEAT + 3, 16]], + [10, 128//4, 64 * 3 + 3], + [10, 256//4, 128 * 3 + 3], + [10, 512//4, 256 * 3 + 3], + ] + channel_raising_nn: + [ + [16, 64], + [64 * 3 + 3, 128], + [128 * 3 + 3, 256], + [256 * 3 + 3, 512], + ] +innermost: + - module_name: GlobalDenseBaseModule + nn: [512 * 3 + 3, 128] + aggr: "mean" + - module_name: GlobalDenseBaseModule + nn: [256 * 3 + 3, 128] + aggr: "mean" +up_conv: + bn: True + bias: False + module_name: DenseFPModule + up_conv_nn: + [ + [512 * 3 + 256 * 3, 512, 512], + [128 * 3 + 512, 512, 512], + [64 * 3 + 512, 256, 256], + [256 + FEAT , 128, 128], + ] + skip: True \ No newline at end of file diff --git a/torch_points3d/applications/conf/sparseconv3d/encoder_2.yaml b/torch_points3d/applications/conf/sparseconv3d/encoder_2.yaml new file mode 100644 index 0000000..acb2ef4 --- /dev/null +++ b/torch_points3d/applications/conf/sparseconv3d/encoder_2.yaml @@ -0,0 +1,18 @@ +conv_type: "SPARSE" +define_constants: + in_feat: 32 + block: ResBlock # Can be any of the blocks in modules/SparseConv3d/modules.py +down_conv: + module_name: ResNetDown + block: block + N: [0, 1, 2] + down_conv_nn: [[FEAT, in_feat], [in_feat, in_feat], [in_feat, 2*in_feat]] + kernel_size: [3, 3, 3] + stride: [1, 2, 2] +innermost: + module_name: GlobalBaseModule + activation: + name: LeakyReLU + negative_slope: 0.2 + 
aggr: "mean"
+  nn: [2*in_feat, 2*in_feat]
diff --git a/torch_points3d/applications/conf/sparseconv3d/encoder_4.yaml b/torch_points3d/applications/conf/sparseconv3d/encoder_4.yaml
new file mode 100644
index 0000000..e413bd1
--- /dev/null
+++ b/torch_points3d/applications/conf/sparseconv3d/encoder_4.yaml
@@ -0,0 +1,25 @@
+conv_type: "SPARSE"
+define_constants:
+  in_feat: 32
+  block: ResBlock # Can be any of the blocks in modules/SparseConv3d/modules.py
+down_conv:
+  module_name: ResNetDown
+  block: block
+  N: [0, 1, 2, 2, 3]
+  down_conv_nn:
+    [
+      [FEAT, in_feat],
+      [in_feat, in_feat],
+      [in_feat, 2*in_feat],
+      [2*in_feat, 4*in_feat],
+      [4*in_feat, 8*in_feat],
+    ]
+  kernel_size: [3, 3, 3, 3, 3]
+  stride: [1, 2, 2, 2, 2]
+innermost:
+  module_name: GlobalBaseModule
+  activation:
+    name: LeakyReLU
+    negative_slope: 0.2
+  aggr: "mean"
+  nn: [8*in_feat, 8*in_feat]
diff --git a/torch_points3d/applications/conf/sparseconv3d/unet_2.yaml b/torch_points3d/applications/conf/sparseconv3d/unet_2.yaml
new file mode 100644
index 0000000..813fb44
--- /dev/null
+++ b/torch_points3d/applications/conf/sparseconv3d/unet_2.yaml
@@ -0,0 +1,23 @@
+conv_type: "SPARSE"
+define_constants:
+  in_feat: 32
+  block: ResBlock # Can be any of the blocks in modules/SparseConv3d/modules.py
+down_conv:
+  block: block
+  module_name: ResNetDown
+  N: [0, 1, 2]
+  down_conv_nn: [[FEAT, in_feat], [in_feat, in_feat], [in_feat, 2*in_feat]]
+  kernel_size: [2, 2]
+  stride: [1, 2, 2]
+up_conv:
+  block: block
+  module_name: ResNetUp
+  N: [1, 1, 0]
+  up_conv_nn:
+    [
+      [4*in_feat + 2*in_feat, 3*in_feat],
+      [3*in_feat + in_feat, 3*in_feat],
+      [3*in_feat + in_feat, 3*in_feat],
+    ]
+  kernel_size: [2, 2, 3]
+  stride: [2, 2, 1]
diff --git a/torch_points3d/applications/conf/sparseconv3d/unet_4.yaml b/torch_points3d/applications/conf/sparseconv3d/unet_4.yaml
new file mode 100644
index 0000000..66b74b4
--- /dev/null
+++ b/torch_points3d/applications/conf/sparseconv3d/unet_4.yaml
@@ -0,0 +1,32 @@
+conv_type: "SPARSE"
+define_constants:
+  in_feat: 32
+  block: ResBlock # Can be any of the blocks in modules/SparseConv3d/modules.py
+down_conv:
+  module_name: ResNetDown
+  block: block
+  N: [0, 1, 2, 2, 3]
+  down_conv_nn:
+    [
+      [FEAT, in_feat],
+      [in_feat, in_feat],
+      [in_feat, 2*in_feat],
+      [2*in_feat, 4*in_feat],
+      [4*in_feat, 8*in_feat],
+    ]
+  kernel_size: [3, 3, 3, 3, 3]
+  stride: [1, 2, 2, 2, 2]
+up_conv:
+  block: block
+  module_name: ResNetUp
+  N: [1, 1, 1, 1, 0]
+  up_conv_nn:
+    [
+      [8*in_feat, 4*in_feat],
+      [4*in_feat + 4*in_feat, 4*in_feat],
+      [4*in_feat + 2*in_feat, 3*in_feat],
+      [3*in_feat + in_feat, 3*in_feat],
+      [3*in_feat + in_feat, 3*in_feat],
+    ]
+  kernel_size: [3, 3, 3, 3, 3]
+  stride: [2, 2, 2, 2, 1]
diff --git a/torch_points3d/applications/kpconv.py b/torch_points3d/applications/kpconv.py
new file mode 100644
index 0000000..628927b
--- /dev/null
+++ b/torch_points3d/applications/kpconv.py
@@ -0,0 +1,183 @@
+import os
+import sys  # required below by sys.modules[__name__]
+from omegaconf import DictConfig, OmegaConf
+import logging
+
+from torch_points3d.applications.modelfactory import ModelFactory
+
+from torch_points3d.core.common_modules import FastBatchNorm1d
+from torch_points3d.modules.KPConv import *
+from torch_points3d.core.base_conv.partial_dense import *
+from torch_points3d.applications.base_architectures.unet import UnwrappedUnetBasedModel
+from torch_points3d.core.common_modules.base_modules import MLP
+
+from torch_points3d.data.multiscale_data import MultiScaleBatch
+
+from .utils import extract_output_nc
+
+
+CUR_FILE = os.path.realpath(__file__)
+DIR_PATH =
os.path.dirname(os.path.realpath(__file__)) +PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/kpconv") + +log = logging.getLogger(__name__) + + +def KPConv( + architecture: str = None, input_nc: int = None, num_layers: int = None, config: DictConfig = None, *args, **kwargs +): + """Create a KPConv backbone model based on the architecture proposed in + https://arxiv.org/abs/1904.08889 + + Parameters + ---------- + architecture : str, optional + Architecture of the model, choose from unet, encoder and decoder + input_nc : int, optional + Number of channels for the input + output_nc : int, optional + If specified, then we add a fully connected head at the end of the network to provide the requested dimension + num_layers : int, optional + Depth of the network + in_grid_size : float, optional + Size of the grid at the entry of the network. It is divided by two at each layer + in_feat : int, optional + Number of channels after the first convolution. Doubles at each layer + config : DictConfig, optional + Custom config, overrides the num_layers and architecture parameters + """ + factory = KPConvFactory( + architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config, **kwargs + ) + return factory.build() + + +class KPConvFactory(ModelFactory): + def _build_unet(self): + if self._config: + model_config = self._config + else: + path_to_model = os.path.join(PATH_TO_CONFIG, "unet_{}.yaml".format(self.num_layers)) + model_config = OmegaConf.load(path_to_model) + ModelFactory.resolve_model(model_config, self.num_features, self._kwargs) + modules_lib = sys.modules[__name__] + return KPConvUnet(model_config, None, modules_lib, **self.kwargs) + + def _build_encoder(self): + if self._config: + model_config = self._config + else: + path_to_model = os.path.join(PATH_TO_CONFIG, "encoder_{}.yaml".format(self.num_layers)) + model_config = OmegaConf.load(path_to_model) + ModelFactory.resolve_model(model_config, self.num_features, self._kwargs) + modules_lib = sys.modules[__name__] + return KPConvEncoder(model_config, None, modules_lib, **self.kwargs) + + +class BaseKPConv(UnwrappedUnetBasedModel): + CONV_TYPE = "partial_dense" + + def __init__(self, model_config, model_type, modules, *args, **kwargs): + super(BaseKPConv, self).__init__(model_config, model_type, modules) + try: + default_output_nc = extract_output_nc(model_config) + except: + default_output_nc = -1 + log.warning("Could not resolve number of output channels") + + self._output_nc = default_output_nc + self._has_mlp_head = False + if "output_nc" in kwargs: + self._has_mlp_head = True + self._output_nc = kwargs["output_nc"] + self.mlp = MLP([default_output_nc, self.output_nc], activation=torch.nn.LeakyReLU(0.2), bias=False) + + @property + def has_mlp_head(self): + return self._has_mlp_head + + @property + def output_nc(self): + return self._output_nc + + def _set_input(self, data): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters + ----------- + data: + a dictionary that contains the data itself and its metadata information. + """ + if isinstance(data, MultiScaleBatch): + self.pre_computed = data.multiscale + self.upsample = data.upsample + del data.upsample + del data.multiscale + else: + self.upsample = None + self.pre_computed = None + + self.input = data + + +class KPConvEncoder(BaseKPConv): + def forward(self, data, *args, **kwargs): + """ + Parameters + ----------- + data: + A dictionary that contains the data itself and its metadata information. 
Should contain + - pos [N, 3] + - x [N, C] + - multiscale (optional) precomputed data for the down convolutions + - upsample (optional) precomputed data for the up convolutions + + Returns + -------- + data: + - pos [1, 3] - Dummy pos + - x [1, output_nc] + """ + self._set_input(data) + data = self.input + stack_down = [data] + for i in range(len(self.down_modules) - 1): + data = self.down_modules[i](data) + stack_down.append(data) + data = self.down_modules[-1](data) + + if not isinstance(self.inner_modules[0], Identity): + stack_down.append(data) + data = self.inner_modules[0](data) + + if self.has_mlp_head: + data.x = self.mlp(data.x) + return data + + +class KPConvUnet(BaseKPConv): + def forward(self, data, *args, **kwargs): + """Run forward pass. + Input --- D1 -- D2 -- D3 -- U1 -- U2 -- output + | |_________| | + |______________________| + + Parameters + ----------- + data: + A dictionary that contains the data itself and its metadata information. Should contain + - pos [N, 3] + - x [N, C] + - multiscale (optional) precomputed data for the down convolutions + - upsample (optional) precomputed data for the up convolutions + + Returns + -------- + data: + - pos [N, 3] + - x [N, output_nc] + """ + self._set_input(data) + data = super().forward(self.input, precomputed_down=self.pre_computed, precomputed_up=self.upsample) + if self.has_mlp_head: + data.x = self.mlp(data.x) + return data diff --git a/torch_points3d/applications/modelfactory.py b/torch_points3d/applications/modelfactory.py index 8763cb5..284f482 100644 --- a/torch_points3d/applications/modelfactory.py +++ b/torch_points3d/applications/modelfactory.py @@ -2,7 +2,7 @@ from omegaconf import DictConfig import logging -from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve + from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve log = logging.getLogger(__name__) diff --git a/torch_points3d/applications/pointnet2.py b/torch_points3d/applications/pointnet2.py new file mode 100644 index 0000000..18a04dd --- /dev/null +++ b/torch_points3d/applications/pointnet2.py @@ -0,0 +1,195 @@ +import os +import sys +from omegaconf import DictConfig, OmegaConf +from typing import Optional +import logging + +from torch_points3d.applications.modelfactory import ModelFactory +from torch_points3d.modules.pointnet2 import * +from torch_points3d.core.base_conv.dense import DenseFPModule + +from torch_points3d.applications.base_architectures.unet import UnwrappedUnetBasedModel + +# Must add multiscale batch +from torch_points3d.data.multiscale_data import MultiScaleBatch + +from torch_points3d.core.common_modules.dense_modules import Conv1D +from torch_points3d.core.common_modules.base_modules import Seq +from .utils import extract_output_nc + +CUR_FILE = os.path.realpath(__file__) +DIR_PATH = os.path.dirname(os.path.realpath(__file__)) +PATH_TO_CONFIG = os.path.join(DIR_PATH, "conf/pointnet2") + +log = logging.getLogger(__name__) + + +def PointNet2( + architecture: Optional[str] = None, + input_nc: Optional[int] = None, + num_layers: Optional[int] = None, + config: Optional[DictConfig] = None, + multiscale: bool = False, + *args, + **kwargs +): + """Create a PointNet2 backbone model based on the architecture proposed in + https://arxiv.org/abs/1706.02413 + + Parameters + ---------- + architecture : str, optional + Architecture of the model, choose from unet, encoder and decoder + input_nc : int, optional + Number of channels for the input + output_nc : int, optional + If specified, 
then we add a fully connected head at the end of the network to provide the requested dimension
+    num_layers : int, optional
+        Depth of the network
+    config : DictConfig, optional
+        Custom config, overrides the num_layers and architecture parameters
+    """
+    factory = PointNet2Factory(
+        architecture=architecture,
+        num_layers=num_layers,
+        input_nc=input_nc,
+        multiscale=multiscale,
+        config=config,
+        **kwargs
+    )
+    return factory.build()
+
+
+class PointNet2Factory(ModelFactory):
+    def _build_unet(self):
+        if self._config:
+            model_config = self._config
+        else:
+            path_to_model = os.path.join(
+                PATH_TO_CONFIG, "unet_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss")
+            )
+            model_config = OmegaConf.load(path_to_model)
+        ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
+        modules_lib = sys.modules[__name__]
+        # dataset argument is None, matching the signature used in _build_encoder below
+        return PointNet2Unet(model_config, None, None, modules_lib, **self.kwargs)
+
+    def _build_encoder(self):
+        if self._config:
+            model_config = self._config
+        else:
+            path_to_model = os.path.join(
+                PATH_TO_CONFIG,
+                "encoder_{}_{}.yaml".format(self.num_layers, "ms" if self.kwargs["multiscale"] else "ss"),
+            )
+            model_config = OmegaConf.load(path_to_model)
+        ModelFactory.resolve_model(model_config, self.num_features, self._kwargs)
+        modules_lib = sys.modules[__name__]
+        return PointNet2Encoder(model_config, None, None, modules_lib, **self.kwargs)
+
+
+class BasePointnet2(UnwrappedUnetBasedModel):
+
+    CONV_TYPE = "dense"
+
+    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
+        super(BasePointnet2, self).__init__(model_config, model_type, dataset, modules)
+
+        try:
+            default_output_nc = extract_output_nc(model_config)
+        except Exception:
+            default_output_nc = -1
+            log.warning("Could not resolve number of output channels")
+
+        self._has_mlp_head = False
+        self._output_nc = default_output_nc
+        if "output_nc" in kwargs:
+            self._has_mlp_head = True
+            self._output_nc = kwargs["output_nc"]
+            self.mlp = Seq()
+            self.mlp.append(Conv1D(default_output_nc, self._output_nc, bn=True, bias=False))
+
+    @property
+    def has_mlp_head(self):
+        return self._has_mlp_head
+
+    @property
+    def output_nc(self):
+        return self._output_nc
+
+    def _set_input(self, data):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps."""
+        assert len(data.pos.shape) == 3
+        data = data.to(self.device)
+        if data.x is not None:
+            data.x = data.x.transpose(1, 2).contiguous()
+        else:
+            data.x = None
+        self.input = data
+
+
+class PointNet2Encoder(BasePointnet2):
+    def forward(self, data, *args, **kwargs):
+        """
+        Parameters:
+        -----------
+        data
+            A dictionary that contains the data itself and its metadata information.
Should contain + x -- Features [B, N, C] + pos -- Points [B, N, 3] + """ + self._set_input(data) + data = self.input + stack_down = [data] + for i in range(len(self.down_modules) - 1): + data = self.down_modules[i](data) + stack_down.append(data) + data = self.down_modules[-1](data) + + if not isinstance(self.inner_modules[0], Identity): + stack_down.append(data) + data = self.inner_modules[0](data) + + if self.has_mlp_head: + data.x = self.mlp(data.x) + return data + + +class PointNet2Unet(BasePointnet2): + def forward(self, data, *args, **kwargs): + """This method does a forward on the Unet assuming symmetrical skip connections + Input --- D1 -- D2 -- I -- U1 -- U2 -- U3 -- output + | | |________| | | + | |______________________| | + |___________________________________| + + Parameters: + ----------- + data + A dictionary that contains the data itself and its metadata information. Should contain + x -- Features [B, N, C] + pos -- Points [B, N, 3] + """ + self._set_input(data) + data = self.input + stack_down = [data] + for i in range(len(self.down_modules) - 1): + data = self.down_modules[i](data) + stack_down.append(data) + data = self.down_modules[-1](data) + + if not isinstance(self.inner_modules[0], Identity): + stack_down.append(data) + data = self.inner_modules[0](data) + + sampling_ids = self._collect_sampling_ids(stack_down) + + for i in range(len(self.up_modules)): + data = self.up_modules[i]((data, stack_down.pop())) + + for key, value in sampling_ids.items(): + setattr(data, key, value) + + if self.has_mlp_head: + data.x = self.mlp(data.x) + + return data diff --git a/torch_points3d/applications/sparseconv3d.py b/torch_points3d/applications/sparseconv3d.py index 84234da..2e5d129 100644 --- a/torch_points3d/applications/sparseconv3d.py +++ b/torch_points3d/applications/sparseconv3d.py @@ -6,8 +6,8 @@ from torch_geometric.data import Batch from torch_points3d.applications.modelfactory import ModelFactory -import torch_points3d.applications.modules.SparseConv3d as sp3d -from torch_points3d.applications.modules.SparseConv3d.modules import * +import torch_points3d.modules.SparseConv3d as sp3d +from torch_points3d.modules.SparseConv3d.modules import * # from torch_points3d.core.base_conv.message_passing import * # from torch_points3d.core.base_conv.partial_dense import * diff --git a/torch_points3d/core/base_conv/__init__.py b/torch_points3d/core/base_conv/__init__.py new file mode 100644 index 0000000..4f9db98 --- /dev/null +++ b/torch_points3d/core/base_conv/__init__.py @@ -0,0 +1 @@ +from .base_conv import * diff --git a/torch_points3d/core/base_conv/base_conv.py b/torch_points3d/core/base_conv/base_conv.py new file mode 100644 index 0000000..035c90d --- /dev/null +++ b/torch_points3d/core/base_conv/base_conv.py @@ -0,0 +1,10 @@ +from abc import ABC + +from torch_points3d.core.common_modules.base_modules import BaseModule + + +class BaseConvolution(ABC, BaseModule): + def __init__(self, sampler, neighbour_finder, *args, **kwargs): + BaseModule.__init__(self) + self.sampler = sampler + self.neighbour_finder = neighbour_finder diff --git a/torch_points3d/core/base_conv/dense.py b/torch_points3d/core/base_conv/dense.py new file mode 100644 index 0000000..2ba5918 --- /dev/null +++ b/torch_points3d/core/base_conv/dense.py @@ -0,0 +1,187 @@ +import numpy as np +import torch +from torch.nn import ( + Linear as Lin, + ReLU, + LeakyReLU, + BatchNorm1d as BN, + Dropout, +) +from torch_geometric.nn import ( + knn_interpolate, + fps, + radius, + global_max_pool, + global_mean_pool, + knn, +) 
+from torch_geometric.data import Data
+import torch_points_kernels as tp
+
+from torch_points3d.core.spatial_ops import BaseMSNeighbourFinder
+from torch_points3d.core.base_conv import BaseConvolution
+from torch_points3d.core.common_modules.dense_modules import MLP2D
+
+from torch_points3d.utils.enums import ConvolutionFormat
+from torch_points3d.utils.model_building_utils.activation_resolver import get_activation
+
+#################### THESE MODULES IMPLEMENT THE BASE DENSE CONV API ############################
+
+
+class BaseDenseConvolutionDown(BaseConvolution):
+    """Multiscale convolution down (also supports single scale). The convolution kernel is shared across the scales.
+
+    Arguments:
+        sampler -- Strategy for sampling the input clouds
+        neighbour_finder -- Multiscale strategy for finding neighbours
+    """
+
+    CONV_TYPE = ConvolutionFormat.DENSE.value
+
+    def __init__(self, sampler, neighbour_finder: BaseMSNeighbourFinder, *args, **kwargs):
+        super(BaseDenseConvolutionDown, self).__init__(sampler, neighbour_finder, *args, **kwargs)
+        self._index = kwargs.get("index", None)
+        self._save_sampling_id = kwargs.get("save_sampling_id", None)
+
+    def conv(self, x, pos, new_pos, radius_idx, scale_idx):
+        """Implements a dense convolution where radius_idx represents
+        the indexes of the points in x and pos to be aggregated into the new feature
+        for each point in new_pos
+
+        Arguments:
+            x -- Previous features [B, C, N]
+            pos -- Previous positions [B, N, 3]
+            new_pos -- Sampled positions [B, npoints, 3]
+            radius_idx -- Indexes to group [B, npoints, nsample]
+            scale_idx -- Scale index in multiscale convolutional layers
+        """
+        raise NotImplementedError
+
+    def forward(self, data, sample_idx=None, **kwargs):
+        """
+        Parameters
+        ----------
+        data: Data
+            x -- Previous features [B, C, N]
+            pos -- Previous positions [B, N, 3]
+        sample_idx: Optional[torch.Tensor]
+            can be used to shortcut the sampler [B, K]
+        """
+        x, pos = data.x, data.pos
+        # `if sample_idx:` would raise on a multi-element tensor; test for None instead
+        if sample_idx is not None:
+            idx = sample_idx
+        else:
+            idx = self.sampler(pos)
+        idx = idx.unsqueeze(-1).repeat(1, 1, pos.shape[-1]).long()
+        new_pos = pos.gather(1, idx)
+
+        ms_x = []
+        for scale_idx in range(self.neighbour_finder.num_scales):
+            radius_idx = self.neighbour_finder(pos, new_pos, scale_idx=scale_idx)
+            ms_x.append(self.conv(x, pos, new_pos, radius_idx, scale_idx))
+        new_x = torch.cat(ms_x, 1)
+
+        new_data = Data(pos=new_pos, x=new_x)
+        if self._save_sampling_id:
+            setattr(new_data, "sampling_id_{}".format(self._index), idx[:, :, 0])
+        return new_data
+
+
+class BaseDenseConvolutionUp(BaseConvolution):
+
+    CONV_TYPE = ConvolutionFormat.DENSE.value
+
+    def __init__(self, neighbour_finder, *args, **kwargs):
+        super(BaseDenseConvolutionUp, self).__init__(None, neighbour_finder, *args, **kwargs)
+        self._index = kwargs.get("index", None)
+        self._skip = kwargs.get("skip", True)
+
+    def conv(self, pos, pos_skip, x):
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        """Propagates features from one layer to the next.
+class BaseDenseConvolutionUp(BaseConvolution):
+
+    CONV_TYPE = ConvolutionFormat.DENSE.value
+
+    def __init__(self, neighbour_finder, *args, **kwargs):
+        super(BaseDenseConvolutionUp, self).__init__(None, neighbour_finder, *args, **kwargs)
+        self._index = kwargs.get("index", None)
+        self._skip = kwargs.get("skip", True)
+
+    def conv(self, pos, pos_skip, x):
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        """Propagates features from one layer to the next;
+        data_skip carries the information saved by the corresponding down conv.
+
+        Arguments:
+            data -- (data, data_skip)
+        """
+        data, data_skip = data
+        pos, x = data.pos, data.x
+        pos_skip, x_skip = data_skip.pos, data_skip.x
+
+        new_features = self.conv(pos, pos_skip, x)
+
+        if x_skip is not None:
+            new_features = torch.cat([new_features, x_skip], dim=1)  # (B, C2 + C1, n)
+
+        new_features = new_features.unsqueeze(-1)
+
+        if hasattr(self, "nn"):
+            new_features = self.nn(new_features)
+
+        return Data(x=new_features.squeeze(-1), pos=pos_skip)
+
+
+class DenseFPModule(BaseDenseConvolutionUp):
+    def __init__(self, up_conv_nn, bn=True, bias=False, activation=torch.nn.LeakyReLU(negative_slope=0.01), **kwargs):
+        super(DenseFPModule, self).__init__(None, **kwargs)
+
+        self.nn = MLP2D(up_conv_nn, bn=bn, activation=activation, bias=bias)  # forward the bias argument instead of hardcoding False
+
+    def conv(self, pos, pos_skip, x):
+        assert pos_skip.shape[2] == 3
+
+        if pos is not None:
+            dist, idx = tp.three_nn(pos_skip, pos)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+            interpolated_feats = tp.three_interpolate(x, idx, weight)
+        else:  # innermost layer: broadcast the global feature to every skip position
+            interpolated_feats = x.expand(*(x.size()[0:2] + (pos_skip.size(1),)))
+
+        return interpolated_feats
+
+    def __repr__(self):
+        return "{}: {} ({})".format(self.__class__.__name__, self.nb_params, self.nn)
+
+
+class GlobalDenseBaseModule(torch.nn.Module):
+    def __init__(self, nn, aggr="max", bn=True, activation=torch.nn.LeakyReLU(negative_slope=0.01), **kwargs):
+        super(GlobalDenseBaseModule, self).__init__()
+        self.nn = MLP2D(nn, bn=bn, activation=activation, bias=False)
+        if aggr.lower() not in ["mean", "max"]:
+            raise Exception("The aggregation provided is unrecognized {}".format(aggr))
+        self._aggr = aggr.lower()
+
+    @property
+    def nb_params(self):
+        """Number of trainable parameters of the module. Useful for debugging and reproducibility.
+
+        Returns:
+            int -- number of trainable parameters
+        """
+        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+        self._nb_params = sum([np.prod(p.size()) for p in model_parameters])
+        return self._nb_params
+
+    def forward(self, data, **kwargs):
+        x, pos = data.x, data.pos
+        pos_flipped = pos.transpose(1, 2).contiguous()
+
+        x = self.nn(torch.cat([x, pos_flipped], dim=1).unsqueeze(-1))
+
+        if self._aggr == "max":
+            x = x.squeeze(-1).max(-1)[0]
+        elif self._aggr == "mean":
+            x = x.squeeze(-1).mean(-1)
+        else:
+            raise NotImplementedError("The following aggregation {} is not recognized".format(self._aggr))
+
+        pos = None  # pos.mean(1).unsqueeze(1)
+        x = x.unsqueeze(-1)
+        return Data(x=x, pos=pos)
+
+    def __repr__(self):
+        return "{}: {} (aggr={}, {})".format(self.__class__.__name__, self.nb_params, self._aggr, self.nn)
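For reference, the weighting that `DenseFPModule.conv` delegates to `tp.three_nn` / `tp.three_interpolate` can be written in plain PyTorch. This is only an illustrative sketch of the normalised inverse-distance scheme, not the CUDA kernel actually used:

import torch

def inverse_distance_interpolate(pos_skip, pos, x, k=3, eps=1e-8):
    # pos_skip: [B, n, 3] target positions, pos: [B, m, 3] source positions, x: [B, C, m]
    dist = torch.cdist(pos_skip, pos)                   # [B, n, m]
    dist, idx = dist.topk(k, dim=-1, largest=False)     # k nearest sources per target point
    weight = 1.0 / (dist + eps)
    weight = weight / weight.sum(dim=-1, keepdim=True)  # weights sum to one per target point
    # gather the features of the selected sources: [B, C, n, k]
    gathered = x.unsqueeze(2).expand(-1, -1, idx.shape[1], -1).gather(
        3, idx.unsqueeze(1).expand(-1, x.shape[1], -1, -1)
    )
    return (gathered * weight.unsqueeze(1)).sum(-1)     # [B, C, n]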
diff --git a/torch_points3d/core/base_conv/message_passing.py b/torch_points3d/core/base_conv/message_passing.py
new file mode 100644
index 0000000..ff10535
--- /dev/null
+++ b/torch_points3d/core/base_conv/message_passing.py
@@ -0,0 +1,261 @@
+from abc import abstractmethod
+from typing import *
+import torch
+from torch.nn import (
+    Linear as Lin,
+    ReLU,
+    LeakyReLU,
+    BatchNorm1d as BN,
+    Dropout,
+)
+from torch_geometric.nn import (
+    knn_interpolate,
+    fps,
+    radius,
+    global_max_pool,
+    global_mean_pool,
+    knn,
+)
+from torch_geometric.data import Batch
+
+from torch_points3d.core.base_conv.base_conv import *
+from torch_points3d.core.common_modules import *
+from torch_points3d.core.spatial_ops import *
+
+
+def copy_from_to(data, batch):
+    for key in data.keys:
+        if key not in batch.keys:
+            setattr(batch, key, getattr(data, key, None))
+
+
+#################### THESE MODULES IMPLEMENT THE BASE MESSAGE_PASSING CONV API ############################
+
+
+class BaseConvolutionDown(BaseConvolution):
+    def __init__(self, sampler, neighbour_finder, *args, **kwargs):
+        super(BaseConvolutionDown, self).__init__(sampler, neighbour_finder, *args, **kwargs)
+
+        self._index = kwargs.get("index", None)
+
+    def conv(self, x, pos, edge_index, batch):
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x, pos, batch = data.x, data.pos, data.batch
+        idx = self.sampler(pos, batch)
+        row, col = self.neighbour_finder(pos, pos[idx], batch_x=batch, batch_y=batch[idx])
+        edge_index = torch.stack([col, row], dim=0)
+        batch_obj.idx = idx
+        batch_obj.edge_index = edge_index
+
+        batch_obj.x = self.conv(x, (pos[idx], pos), edge_index, batch)
+
+        batch_obj.pos = pos[idx]
+        batch_obj.batch = batch[idx]
+        copy_from_to(data, batch_obj)
+        return batch_obj
+
+
+class BaseMSConvolutionDown(BaseConvolution):
+    """Multiscale convolution down (also supports single scale).
+    The convolution kernel is shared across the scales.
+
+    Arguments:
+        sampler -- Strategy for sampling the input clouds
+        neighbour_finder -- Multiscale strategy for finding neighbours
+    """
+
+    def __init__(self, sampler, neighbour_finder: BaseMSNeighbourFinder, *args, **kwargs):
+        super(BaseMSConvolutionDown, self).__init__(sampler, neighbour_finder, *args, **kwargs)
+
+        self._index = kwargs.get("index", None)
+
+    def conv(self, x, pos, edge_index, batch):
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x, pos, batch = data.x, data.pos, data.batch
+        idx = self.sampler(pos, batch)
+        batch_obj.idx = idx
+
+        ms_x = []
+        for scale_idx in range(self.neighbour_finder.num_scales):
+            row, col = self.neighbour_finder(
+                pos,
+                pos[idx],
+                batch_x=batch,
+                batch_y=batch[idx],
+                scale_idx=scale_idx,
+            )
+            edge_index = torch.stack([col, row], dim=0)
+
+            ms_x.append(self.conv(x, (pos, pos[idx]), edge_index, batch))
+
+        batch_obj.x = torch.cat(ms_x, -1)
+        batch_obj.pos = pos[idx]
+        batch_obj.batch = batch[idx]
+        copy_from_to(data, batch_obj)
+        return batch_obj
+
+
+class BaseConvolutionUp(BaseConvolution):
+    def __init__(self, neighbour_finder, *args, **kwargs):
+        super(BaseConvolutionUp, self).__init__(None, neighbour_finder, *args, **kwargs)
+
+        self._index = kwargs.get("index", None)
+        self._skip = kwargs.get("skip", True)
+
+    def conv(self, x, pos, pos_skip, batch, batch_skip, edge_index):
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        data, data_skip = data
+        x, pos, batch = data.x, data.pos, data.batch
+        x_skip, pos_skip, batch_skip = data_skip.x, data_skip.pos, data_skip.batch
+
+        if self.neighbour_finder is not None:
+            row, col = self.neighbour_finder(pos, pos_skip, batch, batch_skip)
+            edge_index = torch.stack([col, row], dim=0)
+        else:
+            edge_index = None
+
+        x = self.conv(x, pos, pos_skip, batch, batch_skip, edge_index)
+
+        if x_skip is not None and self._skip:
+            x = torch.cat([x, x_skip], dim=1)
+
+        if hasattr(self, "nn"):
+            batch_obj.x = self.nn(x)
+        else:
+            batch_obj.x = x
+        copy_from_to(data_skip, batch_obj)
+        return batch_obj
+
+
+class GlobalBaseModule(torch.nn.Module):
+    def __init__(self, nn, aggr="max", *args, **kwargs):
+        super(GlobalBaseModule, self).__init__()
+        self.nn = MLP(nn)
+        self.pool = global_max_pool if aggr == "max" else global_mean_pool
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x, pos, batch = data.x, data.pos, data.batch
+        if pos is not None:
+            x = self.nn(torch.cat([x, pos], dim=1))
+        else:
+            x = self.nn(x)
+        x = self.pool(x, batch)
+        batch_obj.x = x
+        if pos is not None:
+            batch_obj.pos = pos.new_zeros((x.size(0), 3))
+        batch_obj.batch = torch.arange(x.size(0), device=batch.device)
+        copy_from_to(data, batch_obj)
+        return batch_obj
+
+
+#################### COMMON MODULE ########################
+
+
+class FPModule(BaseConvolutionUp):
+    """Upsampling module from PointNet++
+
+    Arguments:
+        up_k [int] -- number of nearest neighbours used for the interpolation
+        up_conv_nn [List[int]] -- list of feature sizes for the upconv mlp
+    """
+
+    def __init__(self, up_k, up_conv_nn, *args, **kwargs):
+        super(FPModule, self).__init__(None)
+
+        self.k = up_k
+        bn_momentum = kwargs.get("bn_momentum", 0.1)
+        self.nn = MLP(up_conv_nn, bn_momentum=bn_momentum, bias=False)
+
+    def conv(self, x, pos, pos_skip, batch, batch_skip, *args):
+        return knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)
+
+    def extra_repr(self):
+        return "Nb parameters: %i" % self.nb_params
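`FPModule.conv` is a thin wrapper around `torch_geometric.nn.knn_interpolate`; a quick standalone check of the shapes involved:

import torch
from torch_geometric.nn import knn_interpolate

# Two resolutions of the same cloud: 32 coarse points with 16 features, upsampled to 128 points.
x = torch.randn(32, 16)
pos = torch.rand(32, 3)          # coarse positions (where the features live)
pos_skip = torch.rand(128, 3)    # fine positions from the skip connection
batch = torch.zeros(32, dtype=torch.long)
batch_skip = torch.zeros(128, dtype=torch.long)

up = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=3)
print(up.shape)  # torch.Size([128, 16])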
+########################## BASE RESIDUAL DOWN #####################
+
+
+class BaseResnetBlockDown(BaseConvolutionDown):
+    def __init__(self, sampler, neighbour_finder, *args, **kwargs):
+        super(BaseResnetBlockDown, self).__init__(sampler, neighbour_finder, *args, **kwargs)
+
+        in_features, out_features, conv_features = kwargs.get("down_conv_nn", None)
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.conv_features = conv_features
+
+        self.features_downsample_nn = MLP([self.in_features, self.conv_features])
+
+        self.features_upsample_nn = MLP([self.conv_features, self.out_features])
+        self.shortcut_feature_resize_nn = MLP([self.in_features, self.out_features])
+
+    def convs(self, x, pos, edge_index):
+        raise NotImplementedError
+
+    def conv(self, x, pos, edge_index):
+        shortcut = x
+        x = self.features_downsample_nn(x)
+        x, pos, edge_index, idx = self.convs(x, pos, edge_index)
+        x = self.features_upsample_nn(x)
+        if idx is not None:
+            shortcut = shortcut[idx]
+        shortcut = self.shortcut_feature_resize_nn(shortcut)
+        x = shortcut + x
+        return x
+
+
+class BaseResnetBlock(torch.nn.Module):
+    def __init__(self, indim, outdim, convdim):
+        """
+        indim: size of x at the input
+        outdim: desired size of x at the output
+        convdim: size of x following convolution
+        """
+        torch.nn.Module.__init__(self)
+
+        self.indim = indim
+        self.outdim = outdim
+        self.convdim = convdim
+
+        self.features_downsample_nn = MLP([self.indim, self.outdim // 4])
+        self.features_upsample_nn = MLP([self.convdim, self.outdim])
+
+        self.shortcut_feature_resize_nn = MLP([self.indim, self.outdim])
+
+        self.activation = ReLU()
+
+    @property
+    @abstractmethod
+    def convs(self):
+        pass
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x = data.x  # (N, indim)
+        shortcut = x  # (N, indim)
+        x = self.features_downsample_nn(x)  # (N, outdim//4)
+        # if this is an identity resnet block, idx will be None
+        data = self.convs(data)  # (N', convdim)
+        x = data.x
+        idx = data.idx
+        x = self.features_upsample_nn(x)  # (N', outdim)
+        if idx is not None:
+            shortcut = shortcut[idx]  # (N', indim)
+        shortcut = self.shortcut_feature_resize_nn(shortcut)  # (N', outdim)
+        x = shortcut + x
+        batch_obj.x = x
+        batch_obj.pos = data.pos
+        batch_obj.batch = data.batch
+        copy_from_to(data, batch_obj)
+        return batch_obj
diff --git a/torch_points3d/core/base_conv/partial_dense.py b/torch_points3d/core/base_conv/partial_dense.py
new file mode 100644
index 0000000..37b4fbf
--- /dev/null
+++ b/torch_points3d/core/base_conv/partial_dense.py
@@ -0,0 +1,146 @@
+from typing import *
+import torch
+from torch.nn import (
+    Linear as Lin,
+    ReLU,
+    LeakyReLU,
+    BatchNorm1d as BN,
+    Dropout,
+)
+from torch_geometric.nn import (
+    knn_interpolate,
+    fps,
+    radius,
+    global_max_pool,
+    global_mean_pool,
+    knn,
+)
+from torch_geometric.data import Batch
+
+from torch_points3d.core.spatial_ops import *
+from .base_conv import BaseConvolution
+from torch_points3d.core.common_modules.base_modules import BaseModule
+from torch_points3d.core.common_modules import MLP
+
+
+#################### THESE MODULES IMPLEMENT THE BASE PARTIAL_DENSE CONV API ############################
+
+
+def copy_from_to(data, batch):
+    for key in data.keys:
+        if key not in batch.keys:
+            setattr(batch, key, getattr(data, key, None))
+
+
+class BasePartialDenseConvolutionDown(BaseConvolution):
+
+    CONV_TYPE = ConvolutionFormat.PARTIAL_DENSE.value
+
+    def __init__(self, sampler, neighbour_finder, *args, **kwargs):
+        super(BasePartialDenseConvolutionDown, self).__init__(sampler, neighbour_finder, *args, **kwargs)
+
+        self._index = kwargs.get("index", None)
+        # Fill values for the shadow point appended in forward(); defaults are an
+        # assumption here, but without them the attribute lookups below would raise
+        self.shadow_features_fill = kwargs.get("shadow_features_fill", 0.0)
+        self.shadow_points_fill_ = kwargs.get("shadow_points_fill_", 0.0)
+
+    def conv(self, x, pos, x_neighbour, pos_centered_neighbour, idx_neighbour, idx_sampler):
+        """Generic down convolution for partial dense data
+
+        Arguments:
+            x [N, C] -- features
+            pos [N, 3] -- positions
+            x_neighbour [N, n_neighbours, C] -- features of the neighbours of each point in x
+            pos_centered_neighbour [N, n_neighbours, 3] -- position of the neighbours of x_i centred on x_i
+            idx_neighbour [N, n_neighbours] -- indices of the neighbours of each point in pos
+            idx_sampler [n] -- points to keep for the output
+
+        Raises:
+            NotImplementedError: must be implemented by the concrete convolution
+        """
+        raise NotImplementedError
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x, pos, batch = data.x, data.pos, data.batch
+        idx_sampler = self.sampler(pos=pos, x=x, batch=batch)
+
+        idx_neighbour = self.neighbour_finder(pos, pos, batch_x=batch, batch_y=batch)
+
+        shadow_x = torch.full((1,) + x.shape[1:], self.shadow_features_fill).to(x.device)
+        shadow_points = torch.full((1,) + pos.shape[1:], self.shadow_points_fill_).to(x.device)
+
+        x = torch.cat([x, shadow_x], dim=0)
+        pos = torch.cat([pos, shadow_points], dim=0)
+
+        x_neighbour = x[idx_neighbour]
+        pos_centered_neighbour = pos[idx_neighbour] - pos[:-1].unsqueeze(1)  # Centre the neighbours on each point; the shadow point is excluded
+
+        batch_obj.x = self.conv(x, pos, x_neighbour, pos_centered_neighbour, idx_neighbour, idx_sampler)
+
+        batch_obj.pos = pos[idx_sampler]
+        batch_obj.batch = batch[idx_sampler]
+        copy_from_to(data, batch_obj)
+        return batch_obj
+
+
+class GlobalPartialDenseBaseModule(torch.nn.Module):
+    def __init__(self, nn, aggr="max", *args, **kwargs):
+        super(GlobalPartialDenseBaseModule, self).__init__()
+
+        self.nn = MLP(nn)
+        self.pool = global_max_pool if aggr == "max" else global_mean_pool
+
+    def forward(self, data, **kwargs):
+        batch_obj = Batch()
+        x, pos, batch = data.x, data.pos, data.batch
+        x = self.nn(torch.cat([x, pos], dim=1))
+        x = self.pool(x, batch)
+        batch_obj.x = x
+        batch_obj.pos = pos.new_zeros((x.size(0), 3))
+        batch_obj.batch = torch.arange(x.size(0), device=x.device)
+        copy_from_to(data, batch_obj)
+        return batch_obj
+
+
+class FPModule_PD(BaseModule):
+    """Upsampling module from PointNet++
+    Arguments:
+        up_k [int] -- number of nearest neighbours used for the interpolation
+        up_conv_nn [List[int]] -- list of feature sizes for the upconv mlp
+    """
+
+    def __init__(self, up_k, up_conv_nn, *args, **kwargs):
+        super(FPModule_PD, self).__init__()
+        self.upsample_op = KNNInterpolate(up_k)
+        bn_momentum = kwargs.get("bn_momentum", 0.1)
+        self.nn = MLP(up_conv_nn, bn_momentum=bn_momentum, bias=False)
+
+    def forward(self, data, precomputed=None, **kwargs):
+        data, data_skip = data
+        batch_out = data_skip.clone()
+        x_skip = data_skip.x
+
+        has_innermost = len(data.x) == data.batch.max() + 1
+
+        if precomputed and not has_innermost:
+            if not hasattr(data, "up_idx"):
+                setattr(batch_out, "up_idx", 0)
+            else:
+                setattr(batch_out, "up_idx", data.up_idx)
+
+            pre_data = precomputed[batch_out.up_idx]
+            batch_out.up_idx = batch_out.up_idx + 1
+        else:
+            pre_data = None
+
+        if has_innermost:
+            x = torch.gather(data.x, 0, data_skip.batch.unsqueeze(-1).repeat((1, data.x.shape[-1])))
+        else:
+            x = self.upsample_op(data, data_skip, precomputed=pre_data)
+
+        if x_skip is not None:
+            x = torch.cat([x, x_skip], dim=1)
+
+        if hasattr(self, "nn"):
+            batch_out.x = self.nn(x)
+        else:
+            batch_out.x = x
+        return batch_out
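The shadow-point convention used by the partial dense format, in isolation: neighbour lists are padded with index N (one past the last real point), and an extra fill row makes the gather safe. The zero fill value is an assumption in this sketch:

import torch

# 5 real points with 2 features; neighbour lists padded with index 5 (the shadow slot).
x = torch.randn(5, 2)
idx_neighbour = torch.tensor([[1, 2], [0, 5], [3, 5], [4, 1], [5, 5]])

shadow_x = torch.zeros(1, 2)               # fill value for out-of-range neighbours
x_padded = torch.cat([x, shadow_x], dim=0)
x_neighbour = x_padded[idx_neighbour]      # [5, 2, 2]; shadow entries come out as zeros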
diff --git a/torch_points3d/core/spatial_ops/__init__.py b/torch_points3d/core/spatial_ops/__init__.py
new file mode 100644
index 0000000..f8ac094
--- /dev/null
+++ b/torch_points3d/core/spatial_ops/__init__.py
@@ -0,0 +1,3 @@
+from .neighbour_finder import *
+from .sampling import *
+from .interpolate import *
diff --git a/torch_points3d/core/spatial_ops/interpolate.py b/torch_points3d/core/spatial_ops/interpolate.py
new file mode 100644
index 0000000..aef310c
--- /dev/null
+++ b/torch_points3d/core/spatial_ops/interpolate.py
@@ -0,0 +1,69 @@
+import torch
+from torch_geometric.nn import knn_interpolate, knn
+from torch_scatter import scatter_add
+from torch_geometric.data import Data
+
+
+class KNNInterpolate:
+    def __init__(self, k):
+        self.k = k
+
+    def precompute(self, query, support):
+        """Precomputes a data structure that can be used in the transform itself to speed things up"""
+        pos_x, pos_y = query.pos, support.pos
+        if hasattr(support, "batch"):
+            batch_y = support.batch
+        else:
+            batch_y = torch.zeros((support.num_nodes,), dtype=torch.long)
+        if hasattr(query, "batch"):
+            batch_x = query.batch
+        else:
+            batch_x = torch.zeros((query.num_nodes,), dtype=torch.long)
+
+        with torch.no_grad():
+            assign_index = knn(pos_x, pos_y, self.k, batch_x=batch_x, batch_y=batch_y)
+            y_idx, x_idx = assign_index
+            diff = pos_x[x_idx] - pos_y[y_idx]
+            squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
+            weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
+            normalisation = scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0))
+
+        return Data(num_nodes=support.num_nodes, x_idx=x_idx, y_idx=y_idx, weights=weights, normalisation=normalisation)
+
+    def __call__(self, query, support, precomputed: Data = None):
+        """Computes a new set of features going from the query resolution position to the support
+        resolution position
+        Args:
+            - query: data structure that holds the low res data (position + features)
+            - support: data structure that holds the position to which we will interpolate
+        Returns:
+            - torch.tensor: interpolated features
+        """
+        if precomputed:
+            num_points = support.pos.size(0)
+            if num_points != precomputed.num_nodes:
+                raise ValueError("Precomputed indices do not match with the data given to the transform")
+
+            x = query.x
+            x_idx, y_idx, weights, normalisation = (
+                precomputed.x_idx,
+                precomputed.y_idx,
+                precomputed.weights,
+                precomputed.normalisation,
+            )
+            y = scatter_add(x[x_idx] * weights, y_idx, dim=0, dim_size=num_points)
+            y = y / normalisation
+            return y
+
+        x, pos = query.x, query.pos
+        pos_support = support.pos
+        if hasattr(support, "batch"):
+            batch_support = support.batch
+        else:
+            batch_support = torch.zeros((support.num_nodes,), dtype=torch.long)
+        if hasattr(query, "batch"):
+            batch = query.batch
+        else:
+            batch = torch.zeros((query.num_nodes,), dtype=torch.long)
+
+        return knn_interpolate(x, pos, pos_support, batch, batch_support, k=self.k)
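A small usage sketch of `KNNInterpolate`, showing the `precompute` fast path; the shapes are assumptions for illustration:

import torch
from torch_geometric.data import Data
from torch_points3d.core.spatial_ops import KNNInterpolate

interp = KNNInterpolate(k=3)
support = Data(pos=torch.rand(128, 3))                       # fine resolution (interpolation target)
query = Data(pos=torch.rand(32, 3), x=torch.randn(32, 16))   # coarse resolution with features

pre = interp.precompute(query, support)    # reusable across forward passes on the same clouds
up = interp(query, support, precomputed=pre)
print(up.shape)  # torch.Size([128, 16])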
diff --git a/torch_points3d/core/spatial_ops/neighbour_finder.py b/torch_points3d/core/spatial_ops/neighbour_finder.py
new file mode 100644
index 0000000..51b835f
--- /dev/null
+++ b/torch_points3d/core/spatial_ops/neighbour_finder.py
@@ -0,0 +1,182 @@
+from abc import ABC, abstractmethod
+from typing import List, Union, cast
+import torch
+from torch_geometric.nn import knn, radius
+import torch_points_kernels as tp
+
+from torch_points3d.utils.config import is_list
+from torch_points3d.utils.enums import ConvolutionFormat
+
+from torch_points3d.utils.debugging_vars import DEBUGGING_VARS, DistributionNeighbour
+
+
+class BaseNeighbourFinder(ABC):
+    def __call__(self, x, y, batch_x, batch_y):
+        return self.find_neighbours(x, y, batch_x, batch_y)
+
+    @abstractmethod
+    def find_neighbours(self, x, y, batch_x, batch_y):
+        pass
+
+    def __repr__(self):
+        return str(self.__class__.__name__) + " " + str(self.__dict__)
+
+
+class RadiusNeighbourFinder(BaseNeighbourFinder):
+    def __init__(self, radius: float, max_num_neighbors: int = 64, conv_type=ConvolutionFormat.MESSAGE_PASSING.value):
+        self._radius = radius
+        self._max_num_neighbors = max_num_neighbors
+        self._conv_type = conv_type.lower()
+
+    def find_neighbours(self, x, y, batch_x=None, batch_y=None):
+        if self._conv_type == ConvolutionFormat.MESSAGE_PASSING.value:
+            return radius(x, y, self._radius, batch_x, batch_y, max_num_neighbors=self._max_num_neighbors)
+        elif self._conv_type in (ConvolutionFormat.DENSE.value, ConvolutionFormat.PARTIAL_DENSE.value):
+            # note: the comparison must test both formats explicitly; `a == b or c` is always truthy
+            return tp.ball_query(
+                self._radius, self._max_num_neighbors, x, y, mode=self._conv_type, batch_x=batch_x, batch_y=batch_y
+            )[0]
+        else:
+            raise NotImplementedError
+
+
+class KNNNeighbourFinder(BaseNeighbourFinder):
+    def __init__(self, k):
+        self.k = k
+
+    def find_neighbours(self, x, y, batch_x, batch_y):
+        return knn(x, y, self.k, batch_x, batch_y)
+
+
+class DilatedKNNNeighbourFinder(BaseNeighbourFinder):
+    def __init__(self, k, dilation):
+        self.k = k
+        self.dilation = dilation
+        self.initialFinder = KNNNeighbourFinder(k * dilation)
+
+    def find_neighbours(self, x, y, batch_x, batch_y):
+        # find the self.k * self.dilation closest neighbours in x for each y
+        row, col = self.initialFinder.find_neighbours(x, y, batch_x, batch_y)
+
+        # for each point in y, randomly select k of its neighbours
+        index = torch.randint(
+            self.k * self.dilation,
+            (len(y), self.k),
+            device=row.device,
+            dtype=torch.long,
+        )
+
+        arange = torch.arange(len(y), dtype=torch.long, device=row.device)
+        arange = arange * (self.k * self.dilation)
+        index = (index + arange.view(-1, 1)).view(-1)
+        row, col = row[index], col[index]
+
+        return row, col
+
+
+class BaseMSNeighbourFinder(ABC):
+    def __call__(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
+        return self.find_neighbours(x, y, batch_x=batch_x, batch_y=batch_y, scale_idx=scale_idx)
+
+    @abstractmethod
+    def find_neighbours(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
+        pass
+
+    @property
+    @abstractmethod
+    def num_scales(self):
+        pass
+
+    @property
+    def dist_meters(self):
+        return getattr(self, "_dist_meters", None)
+
+
+class MultiscaleRadiusNeighbourFinder(BaseMSNeighbourFinder):
+    """Radius search with support for multiscale for sparse graphs
+
+    Arguments:
+        radius {Union[float, List[float]]}
+
+    Keyword Arguments:
+        max_num_neighbors {Union[int, List[int]]} (default: {64})
+
+    Raises:
+        ValueError: if radius and max_num_neighbors are lists of different lengths
+    """
+
+    def __init__(
+        self,
+        radius: Union[float, List[float]],
+        max_num_neighbors: Union[int, List[int]] = 64,
+    ):
+        if DEBUGGING_VARS["FIND_NEIGHBOUR_DIST"]:
+            # in debugging mode, track the neighbour distribution with a large neighbour cap
+            if not isinstance(radius, list):
+                radius = [radius]
+            self._dist_meters = [DistributionNeighbour(r) for r in radius]
+            if not isinstance(max_num_neighbors, list):
+                max_num_neighbors = [max_num_neighbors]
+            max_num_neighbors = [256 for _ in max_num_neighbors]
+
+        if not is_list(max_num_neighbors) and is_list(radius):
+            self._radius = cast(list, radius)
+            max_num_neighbors = cast(int, max_num_neighbors)
+            self._max_num_neighbors = [max_num_neighbors for i in range(len(self._radius))]
+            return
+
+        if not is_list(radius) and is_list(max_num_neighbors):
+            self._max_num_neighbors = cast(list, max_num_neighbors)
+            radius = cast(int, radius)
+            self._radius = [radius for i in range(len(self._max_num_neighbors))]
+            return
+
+        if is_list(max_num_neighbors):
+            max_num_neighbors = cast(list, max_num_neighbors)
+            radius = cast(list, radius)
+            if len(max_num_neighbors) != len(radius):
+                raise ValueError("Both lists max_num_neighbors and radius should be of the same length")
+            self._max_num_neighbors = max_num_neighbors
+            self._radius = radius
+            return
+
+        self._max_num_neighbors = [cast(int, max_num_neighbors)]
+        self._radius = [cast(int, radius)]
+
+    def find_neighbours(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
+        if scale_idx >= self.num_scales:
+            raise ValueError("Scale %i is out of bounds %i" % (scale_idx, self.num_scales))
+
+        radius_idx = radius(
+            x, y, self._radius[scale_idx], batch_x, batch_y, max_num_neighbors=self._max_num_neighbors[scale_idx]
+        )
+        return radius_idx
+
+    @property
+    def num_scales(self):
+        return len(self._radius)
+
+    def __call__(self, x, y, batch_x=None, batch_y=None, scale_idx=0):
+        """Sparse interface of the neighbourhood finder"""
+        return self.find_neighbours(x, y, batch_x, batch_y, scale_idx)
+
+
+class DenseRadiusNeighbourFinder(MultiscaleRadiusNeighbourFinder):
+    """Multiscale radius search for dense graphs"""
+
+    def find_neighbours(self, x, y, scale_idx=0):
+        if scale_idx >= self.num_scales:
+            raise ValueError("Scale %i is out of bounds %i" % (scale_idx, self.num_scales))
+        num_neighbours = self._max_num_neighbors[scale_idx]
+        neighbours = tp.ball_query(self._radius[scale_idx], num_neighbours, x, y)[0]
+
+        if DEBUGGING_VARS["FIND_NEIGHBOUR_DIST"]:
+            for i in range(neighbours.shape[0]):
+                start = neighbours[i, :, 0]
+                valid_neighbours = (neighbours[i, :, 1:] != start.view((-1, 1)).repeat(1, num_neighbours - 1)).sum(
+                    1
+                ) + 1
+                self._dist_meters[scale_idx].add_valid_neighbours(valid_neighbours)
+        return neighbours
+
+    def __call__(self, x, y, scale_idx=0, **kwargs):
+        """Dense interface of the neighbourhood finder"""
+        return self.find_neighbours(x, y, scale_idx)
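Quick sanity check of the sparse (message-passing) interface; this sketch uses the default conv_type and the same `(row, col)` convention as torch_geometric's `radius()`:

import torch
from torch_points3d.core.spatial_ops import RadiusNeighbourFinder

pos = torch.rand(100, 3)
batch = torch.zeros(100, dtype=torch.long)

finder = RadiusNeighbourFinder(0.2, max_num_neighbors=16)  # defaults to the message-passing format
row, col = finder(pos, pos, batch_x=batch, batch_y=batch)  # edge list of neighbours within radius 0.2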
diff --git a/torch_points3d/core/spatial_ops/sampling.py b/torch_points3d/core/spatial_ops/sampling.py
new file mode 100644
index 0000000..7198d3e
--- /dev/null
+++ b/torch_points3d/core/spatial_ops/sampling.py
@@ -0,0 +1,126 @@
+from abc import ABC, abstractmethod
+import math
+import torch
+from torch_geometric.nn import voxel_grid
+from torch_geometric.nn.pool.consecutive import consecutive_cluster
+from torch_geometric.nn.pool.pool import pool_pos, pool_batch
+import torch_points_kernels as tp
+
+from torch_points3d.utils.config import is_list
+from torch_points3d.utils.enums import ConvolutionFormat
+
+
+class BaseSampler(ABC):
+    """If num_to_sample is provided, sample exactly num_to_sample points.
+    Otherwise sample floor(pos.shape[0] * ratio) points.
+    """
+
+    def __init__(self, ratio=None, num_to_sample=None, subsampling_param=None):
+        if num_to_sample is not None:
+            if (ratio is not None) or (subsampling_param is not None):
+                raise ValueError("Can only specify one of ratio, num_to_sample and subsampling_param")
+            self._num_to_sample = num_to_sample
+
+        elif ratio is not None:
+            self._ratio = ratio
+
+        elif subsampling_param is not None:
+            self._subsampling_param = subsampling_param
+
+        else:
+            raise Exception("At least one of ratio, num_to_sample or subsampling_param should be defined")
+
+    def __call__(self, pos, x=None, batch=None):
+        return self.sample(pos, batch=batch, x=x)
+
+    def _get_num_to_sample(self, num_points) -> int:
+        if hasattr(self, "_num_to_sample"):
+            return self._num_to_sample
+        else:
+            return math.floor(num_points * self._ratio)
+
+    def _get_ratio_to_sample(self, num_points) -> float:
+        if hasattr(self, "_ratio"):
+            return self._ratio
+        else:
+            return self._num_to_sample / float(num_points)
+
+    @abstractmethod
+    def sample(self, pos, x=None, batch=None):
+        pass
+
+
+class FPSSampler(BaseSampler):
+    """If num_to_sample is provided, sample exactly num_to_sample points.
+    Otherwise sample floor(pos.shape[0] * ratio) points.
+    """
+
+    def sample(self, pos, batch, **kwargs):
+        from torch_geometric.nn import fps
+
+        if len(pos.shape) != 2:
+            raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")
+        return fps(pos, batch, ratio=self._get_ratio_to_sample(pos.shape[0]))
+
+
+class GridSampler(BaseSampler):
+    """Pools the cloud on a regular voxel grid of size subsampling_param."""
+
+    def sample(self, pos=None, x=None, batch=None):
+        if len(pos.shape) != 2:
+            raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")
+
+        pool = voxel_grid(pos, batch, self._subsampling_param)
+        pool, perm = consecutive_cluster(pool)
+        batch = pool_batch(perm, batch)
+        if x is not None:
+            return pool_pos(pool, x), pool_pos(pool, pos), batch
+        else:
+            return None, pool_pos(pool, pos), batch
+
+
+class DenseFPSSampler(BaseSampler):
+    """If num_to_sample is provided, sample exactly num_to_sample points.
+    Otherwise sample floor(pos.shape[1] * ratio) points.
+    """
+
+    def sample(self, pos, **kwargs):
+        """Sample pos
+
+        Arguments:
+            pos -- [B, N, 3]
+
+        Returns:
+            indexes -- [B, num_sample]
+        """
+        if len(pos.shape) != 3:
+            raise ValueError("This class is for dense data and expects the pos tensor to be of dimension 3")
+        return tp.furthest_point_sample(pos, self._get_num_to_sample(pos.shape[1]))
+
+
+class RandomSampler(BaseSampler):
+    """If num_to_sample is provided, sample exactly num_to_sample points.
+    Otherwise sample floor(pos.shape[0] * ratio) points.
+    """
+
+    def sample(self, pos, batch, **kwargs):
+        if len(pos.shape) != 2:
+            raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2")
+        idx = torch.randint(0, pos.shape[0], (self._get_num_to_sample(pos.shape[0]),))
+        return idx
+
+
+class DenseRandomSampler(BaseSampler):
+    """If num_to_sample is provided, sample exactly num_to_sample points.
+    Otherwise sample floor(pos.shape[1] * ratio) points.
+    Arguments:
+        pos -- [B, N, 3]
+    """
+
+    def sample(self, pos, **kwargs):
+        if len(pos.shape) != 3:
+            raise ValueError("This class is for dense data and expects the pos tensor to be of dimension 3")
+        idx = torch.randint(0, pos.shape[1], (self._get_num_to_sample(pos.shape[1]),))
+        return idx
diff --git a/torch_points3d/data/batch.py b/torch_points3d/data/batch.py
new file mode 100644
index 0000000..51a1536
--- /dev/null
+++ b/torch_points3d/data/batch.py
@@ -0,0 +1,58 @@
+import torch
+from torch_geometric.data import Data
+
+
+class SimpleBatch(Data):
+    r"""A classic batch object wrapper with :class:`torch_geometric.data.Data` as its
+    base class, so all of the Data methods can also be used here.
+    """
+
+    def __init__(self, batch=None, **kwargs):
+        super(SimpleBatch, self).__init__(**kwargs)
+
+        self.batch = batch
+        self.__data_class__ = Data
+
+    @staticmethod
+    def from_data_list(data_list):
+        r"""Constructs a batch object from a python list holding
+        :class:`torch_geometric.data.Data` objects.
+        """
+        keys = [set(data.keys) for data in data_list]
+        keys = list(set.union(*keys))
+
+        # Check if all dimensions match so the data can be concatenated
+        # if len(data_list) > 0:
+        #     for data in data_list[1:]:
+        #         for key in keys:
+        #             assert data_list[0][key].shape == data[key].shape
+
+        batch = SimpleBatch()
+        batch.__data_class__ = data_list[0].__class__
+
+        for key in keys:
+            batch[key] = []
+
+        for _, data in enumerate(data_list):
+            for key in data.keys:
+                item = data[key]
+                batch[key].append(item)
+
+        for key in batch.keys:
+            item = batch[key][0]
+            if (
+                torch.is_tensor(item)
+                or isinstance(item, int)
+                or isinstance(item, float)
+            ):
+                batch[key] = torch.stack(batch[key])
+            else:
+                raise ValueError("Unsupported attribute type")
+
+        return batch.contiguous()
+        # return [batch.x.transpose(1, 2).contiguous(), batch.pos, batch.y.view(-1)]
+
+    @property
+    def num_graphs(self):
+        """Returns the number of graphs in the batch."""
+        return self.batch[-1].item() + 1
diff --git a/torch_points3d/data/multiscale_data.py b/torch_points3d/data/multiscale_data.py
new file mode 100644
index 0000000..c66a746
--- /dev/null
+++ b/torch_points3d/data/multiscale_data.py
@@ -0,0 +1,165 @@
+from typing import List, Optional
+import torch
+import copy
+import torch_geometric
+from torch_geometric.data import Data
+from torch_geometric.data import Batch
+
+
+class MultiScaleData(Data):
+    def __init__(
+        self,
+        x=None,
+        y=None,
+        pos=None,
+        multiscale: Optional[List[Data]] = None,
+        upsample: Optional[List[Data]] = None,
+        **kwargs,
+    ):
+        super().__init__(x=x, y=y, pos=pos, multiscale=multiscale, upsample=upsample, **kwargs)
+
+    def apply(self, func, *keys):
+        r"""Applies the function :obj:`func` to all tensor and Data attributes
+        :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to
+        all present attributes.
+ """ + for key, item in self(*keys): + if torch.is_tensor(item): + self[key] = func(item) + for scale in range(self.num_scales): + self.multiscale[scale] = self.multiscale[scale].apply(func) + + for up in range(self.num_upsample): + self.upsample[up] = self.upsample[up].apply(func) + return self + + @property + def num_scales(self): + """ Number of scales in the multiscale array + """ + return len(self.multiscale) if hasattr(self, "multiscale") and self.multiscale else 0 + + @property + def num_upsample(self): + """ Number of upsample operations + """ + return len(self.upsample) if hasattr(self, "upsample") and self.upsample else 0 + + @classmethod + def from_data(cls, data): + ms_data = cls() + for k, item in data: + ms_data[k] = item + return ms_data + + +class MultiScaleBatch(MultiScaleData): + @staticmethod + def from_data_list(data_list, follow_batch=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. + Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`.""" + for data in data_list: + assert isinstance(data, MultiScaleData) + num_scales = data_list[0].num_scales + for data_entry in data_list: + assert data_entry.num_scales == num_scales, "All data objects should contain the same number of scales" + num_upsample = data_list[0].num_upsample + for data_entry in data_list: + assert data_entry.num_upsample == num_upsample, "All data objects should contain the same number of scales" + + # Build multiscale batches + multiscale = [] + for scale in range(num_scales): + ms_scale = [] + for data_entry in data_list: + ms_scale.append(data_entry.multiscale[scale]) + multiscale.append(from_data_list_token(ms_scale)) + + # Build upsample batches + upsample = [] + for scale in range(num_upsample): + upsample_scale = [] + for data_entry in data_list: + upsample_scale.append(data_entry.upsample[scale]) + upsample.append(from_data_list_token(upsample_scale)) + + # Create batch from non multiscale data + for data_entry in data_list: + del data_entry.multiscale + del data_entry.upsample + batch = Batch.from_data_list(data_list) + batch = MultiScaleBatch.from_data(batch) + batch.multiscale = multiscale + batch.upsample = upsample + + if torch_geometric.is_debug_enabled(): + batch.debug() + + return batch + + +def from_data_list_token(data_list, follow_batch=[]): + """ This is pretty a copy paste of the from data list of pytorch geometric + batch object with the difference that indexes that are negative are not incremented + """ + + keys = [set(data.keys) for data in data_list] + keys = list(set.union(*keys)) + assert "batch" not in keys + + batch = Batch() + batch.__data_class__ = data_list[0].__class__ + batch.__slices__ = {key: [0] for key in keys} + + for key in keys: + batch[key] = [] + + for key in follow_batch: + batch["{}_batch".format(key)] = [] + + cumsum = {key: 0 for key in keys} + batch.batch = [] + for i, data in enumerate(data_list): + for key in data.keys: + item = data[key] + if torch.is_tensor(item) and item.dtype != torch.bool and cumsum[key] > 0: + mask = item >= 0 + item[mask] = item[mask] + cumsum[key] + if torch.is_tensor(item): + size = item.size(data.__cat_dim__(key, data[key])) + else: + size = 1 + batch.__slices__[key].append(size + batch.__slices__[key][-1]) + cumsum[key] += data.__inc__(key, item) + batch[key].append(item) + + if key in follow_batch: + item = torch.full((size,), i, dtype=torch.long) + 
batch["{}_batch".format(key)].append(item) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long) + batch.batch.append(item) + + if num_nodes is None: + batch.batch = None + + for key in batch.keys: + item = batch[key][0] + if torch.is_tensor(item): + batch[key] = torch.cat( + batch[key], dim=data_list[0].__cat_dim__(key, item)) + elif isinstance(item, int) or isinstance(item, float): + batch[key] = torch.tensor(batch[key]) + else: + raise ValueError( + "Unsupported attribute type {} : {}".format(type(item), item)) + + if torch_geometric.is_debug_enabled(): + batch.debug() + + return batch.contiguous() diff --git a/torch_points3d/data/pair.py b/torch_points3d/data/pair.py new file mode 100644 index 0000000..42c6b62 --- /dev/null +++ b/torch_points3d/data/pair.py @@ -0,0 +1,260 @@ +import torch +from typing import List, Optional, Tuple +from torch_geometric.data import Data +from torch_geometric.data import Batch +from torch_points3d.data.multiscale_data import MultiScaleBatch, MultiScaleData +import re + +class Pair(Data): + + def __init__( + self, + x=None, + y=None, + pos=None, + x_target=None, + pos_target=None, + **kwargs, + ): + self.__data_class__ = Data + super(Pair, self).__init__(x=x, pos=pos, + x_target=x_target, pos_target=pos_target, **kwargs) + + + @classmethod + def make_pair(cls, data_source: Data, data_target: Data): + """ + add in a Data object the source elem, the target elem. + """ + # add concatenation of the point cloud + batch = cls() + for key in data_source.keys: + batch[key] = data_source[key] + for key_target in data_target.keys: + batch[key_target+"_target"] = data_target[key_target] + if(batch.x is None): + batch["x_target"] = None + return batch.contiguous() + + def to_data(self) -> Tuple[Data, Data]: + data_source = self.__data_class__() + data_target = self.__data_class__() + for key in self.keys: + match = re.search(r"(.+)_target$", key) + if match is None: + data_source[key] = self[key] + else: + new_key = match.groups()[0] + data_target[new_key] = self[key] + return data_source, data_target + + @property + def num_nodes_target(self): + for key, item in self('x_target', 'pos_target', 'norm_target', 'batch_target'): + return item.size(self.__cat_dim__(key, item)) + return None + + +class MultiScalePair(Pair): + def __init__( + self, + x=None, + y=None, + pos=None, + multiscale: Optional[List[Data]] = None, + upsample: Optional[List[Data]] = None, + x_target=None, + pos_target=None, + multiscale_target: Optional[List[Data]] = None, + upsample_target: Optional[List[Data]] = None, + **kwargs, + ): + super(MultiScalePair, self).__init__(x=x, pos=pos, + multiscale=multiscale, + upsample=upsample, + x_target=x_target, pos_target=pos_target, + multiscale_target=multiscale_target, + upsample_target=upsample_target, + **kwargs) + self.__data_class__ = MultiScaleData + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor and Data attributes + :obj:`*keys`. If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. 
+ """ + for key, item in self(*keys): + if torch.is_tensor(item): + self[key] = func(item) + for scale in range(self.num_scales): + self.multiscale[scale] = self.multiscale[scale].apply(func) + self.multiscale_target[scale] = self.multiscale_target[scale].apply(func) + + for up in range(self.num_upsample): + self.upsample[up] = self.upsample[up].apply(func) + self.upsample_target[up] = self.upsample_target[up].apply(func) + return self + + @property + def num_scales(self): + """ Number of scales in the multiscale array + """ + return len(self.multiscale) if self.multiscale else 0 + + @property + def num_upsample(self): + """ Number of upsample operations + """ + return len(self.upsample) if self.upsample else 0 + + @classmethod + def from_data(cls, data): + ms_data = cls() + for k, item in data: + ms_data[k] = item + return ms_data + + +class PairBatch(Pair): + + def __init__(self, batch=None, batch_target=None, **kwargs): + r""" + Pair batch for message passing + """ + self.batch_target = batch_target + self.batch = None + super(PairBatch, self).__init__(**kwargs) + self.__data_class__ = Batch + + @staticmethod + def from_data_list(data_list): + r""" + from a list of torch_points3d.data.pair.Pair objects, create + a batch + Warning : follow_batch is not here yet... + """ + assert isinstance(data_list[0], Pair) + data_list_s, data_list_t = list(map(list, zip(*[data.to_data() for data in data_list]))) + if hasattr(data_list_s[0], 'pair_ind'): + pair_ind = concatenate_pair_ind(data_list_s, data_list_t) + else: + pair_ind = None + batch_s = Batch.from_data_list(data_list_s) + batch_t = Batch.from_data_list(data_list_t) + pair = PairBatch.make_pair(batch_s, batch_t) + pair.pair_ind = pair_ind + return pair.contiguous() + +class PairMultiScaleBatch(MultiScalePair): + + def __init__(self, batch=None, batch_target=None, **kwargs): + self.batch = batch + self.batch_target = batch_target + super(PairMultiScaleBatch, self).__init__(**kwargs) + self.__data_class__ = MultiScaleBatch + + @staticmethod + def from_data_list(data_list): + r""" + from a list of torch_points3d.datasets.registation.pair.Pair objects, create + a batch + Warning : follow_batch is not here yet... + """ + data_list_s, data_list_t = list(map(list, zip(*[data.to_data() for data in data_list]))) + if hasattr(data_list_s[0], 'pair_ind'): + pair_ind = concatenate_pair_ind(data_list_s, data_list_t).to(torch.long) + else: + pair_ind = None + batch_s = MultiScaleBatch.from_data_list(data_list_s) + batch_t = MultiScaleBatch.from_data_list(data_list_t) + pair = PairMultiScaleBatch.make_pair(batch_s, batch_t) + pair.pair_ind = pair_ind + return pair.contiguous() + + +class DensePairBatch(Pair): + r""" A classic batch object wrapper with :class:`Pair`. Used for Dense Pair Batch (ie pointcloud with fixed size). + """ + + def __init__(self, batch=None, **kwargs): + super(DensePairBatch, self).__init__(**kwargs) + + self.batch = batch + self.__data_class__ = Data + + @staticmethod + def from_data_list(data_list): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. 
+ """ + keys = [set(data.keys) for data in data_list] + keys = list(set.union(*keys)) + + # Check if all dimensions matches and we can concatenate data + # if len(data_list) > 0: + # for data in data_list[1:]: + # for key in keys: + # assert data_list[0][key].shape == data[key].shape + + batch = DensePairBatch() + batch.__data_class__ = data_list[0].__class__ + + for key in keys: + batch[key] = [] + + for _, data in enumerate(data_list): + for key in data.keys: + item = data[key] + batch[key].append(item) + + for key in batch.keys: + item = batch[key][0] + if ( + torch.is_tensor(item) + or isinstance(item, int) + or isinstance(item, float) + ): + if key != "pair_ind": + batch[key] = torch.stack(batch[key]) + else: + raise ValueError("Unsupported attribute type") + # add pair_ind for dense data too + if getattr(data_list[0], 'pair_ind', None) is not None: + pair_ind = concatenate_pair_ind(data_list, data_list).to(torch.long) + else: + pair_ind = None + batch.pair_ind = pair_ind + return batch.contiguous() + # return [batch.x.transpose(1, 2).contiguous(), batch.pos, batch.y.view(-1)] + + @property + def num_graphs(self): + """Returns the number of graphs in the batch.""" + return self.batch[-1].item() + 1 + + +def concatenate_pair_ind(list_data_source, list_data_target): + """ + for a list of pair of indices batched, change the index it refers to wrt the batch index + Parameters + ---------- + list_data_source: list[Data] + list_data_target: list[Data] + Returns + ------- + torch.Tensor + indices of y corrected wrt batch indices + + + """ + + assert len(list_data_source) == len(list_data_target) + assert hasattr(list_data_source[0], "pair_ind") + list_pair_ind = [] + cum_size = torch.zeros(2) + for i in range(len(list_data_source)): + size = torch.tensor([len(list_data_source[i].pos), + len(list_data_target[i].pos)]) + list_pair_ind.append(list_data_source[i].pair_ind + cum_size) + cum_size = cum_size + size + return torch.cat(list_pair_ind, 0) diff --git a/torch_points3d/datasets/base_dataset.py b/torch_points3d/datasets/base_dataset.py index 1bb5f1c..f62c62b 100644 --- a/torch_points3d/datasets/base_dataset.py +++ b/torch_points3d/datasets/base_dataset.py @@ -1,11 +1,20 @@ from typing import Any, Callable, Dict, Optional, Sequence from dataclasses import dataclass +from functools import partial +import numpy as np import hydra import torch_geometric import pytorch_lightning as pl +from torch.utils.data import Dataset from torch.utils.data import DataLoader from torch_points3d.core.config import BaseDataConfig +from torch_geometric.data import Data +from torch_points3d.data.multiscale_data import MultiScaleBatch +from torch_points3d.data.batch import SimpleBatch + +from torch_points3d.utils.enums import ConvolutionFormat +from torch_points3d.utils.config import ConvolutionFormatFactory @dataclass @@ -13,6 +22,7 @@ class PointCloudDataConfig(BaseDataConfig): batch_size: int = 32 num_workers: int = 0 dataroot: str = "data" + conv_type: str = "dense" pre_transform: Sequence[Any] = None train_transform: Sequence[Any] = None test_transform: Sequence[Any] = None @@ -22,39 +32,74 @@ class PointCloudDataModule(pl.LightningDataModule): def __init__(self, cfg: PointCloudDataConfig = PointCloudDataConfig()) -> None: super().__init__() self.cfg = cfg - self.ds = None - + self.ds: Optional[Dict[str, Dataset]] = None self.cfg.dataroot = hydra.utils.to_absolute_path(self.cfg.dataroot) def train_dataloader(self) -> DataLoader: - return DataLoader( - self.ds["train"], batch_size=self.batch_size, 
num_workers=self.cfg.num_workers, collate_fn=self.collate_fn, + return self._dataloader( + self.ds["train"], conv_type=self.cfg.conv_type ) def val_dataloader(self) -> DataLoader: - return DataLoader( - self.ds["validation"], - batch_size=self.batch_size, - num_workers=self.cfg.num_workers, - collate_fn=self.collate_fn, + return self._dataloader( + self.ds["validation"], conv_type=self.cfg.conv_type ) def test_dataloader(self) -> Optional[DataLoader]: - if "test" in self.ds: - return DataLoader( - self.ds["test"], - batch_size=self.batch_size, - num_workers=self.cfg.num_workers, - collate_fn=self.collate_fn, + if "test" in self.ds.keys(): + return self._dataloader( + self.ds["test"], conv_type=self.cfg.conv_type ) @property def batch_size(self) -> int: return self.cfg.batch_size - @property - def collate_fn(self) -> Optional[Callable]: - return torch_geometric.data.batch.Batch.from_data_list + @staticmethod + def get_num_samples(batch, conv_type): + is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type) + if is_dense: + return batch.pos.shape[0] + else: + return (batch.batch.max() + 1).item() + + @staticmethod + def _collate_fn(batch: Data, collate_fn: Callable, pre_collate_transform: Optional[Callable] = None): + if pre_collate_transform: + batch = pre_collate_transform(batch) + return collate_fn(batch) + + @staticmethod + def _get_collate_function(conv_type: str, is_multiscale: bool, pre_collate_transform: Optional[Callable] = None): + is_dense: bool = ConvolutionFormatFactory.check_is_dense_format(conv_type) + if is_multiscale: + if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower(): + fn = MultiScaleBatch.from_data_list + else: + raise NotImplementedError( + "MultiscaleTransform is activated and supported only for partial_dense format" + ) + else: + if is_dense: + fn = SimpleBatch.from_data_list + else: + fn = torch_geometric.data.batch.Batch.from_data_list + return partial(PointCloudDataModule._collate_fn, collate_fn=fn, pre_collate_transform=pre_collate_transform) + + def _dataloader(self, dataset: Dataset, pre_batch_collate_transform: Optional[Callable] = None, conv_type: str = "partial_dense", precompute_multi_scale: bool = False, **kwargs): + batch_collate_function = self.__class__._get_collate_function( + conv_type, precompute_multi_scale, pre_batch_collate_transform + ) + num_workers = self.cfg.num_workers + persistent_workers = (num_workers > 0) + + dataloader = partial( + DataLoader, collate_fn=batch_collate_function, worker_init_fn=np.random.seed, + persistent_workers=persistent_workers, + batch_size=self.batch_size, + num_workers=num_workers + ) + return dataloader(dataset, **kwargs) @property def model_data_kwargs(self) -> Dict: @@ -63,4 +108,4 @@ def model_data_kwargs(self) -> Dict: This is useful to provide the number of classes/pixels to the model or any other data specific args Returns: Dict of args """ - return {} \ No newline at end of file + return {} diff --git a/torch_points3d/models/segmentation/base_model.py b/torch_points3d/models/segmentation/base_model.py index 42ac312..452b491 100644 --- a/torch_points3d/models/segmentation/base_model.py +++ b/torch_points3d/models/segmentation/base_model.py @@ -11,7 +11,14 @@ class SegmentationBaseModel(PointCloudBaseModel): - def __init__(self, instantiator: Instantiator, num_classes: int, backbone: DictConfig, criterion: DictConfig): + def __init__( + self, + instantiator: Instantiator, + num_classes: int, + backbone: DictConfig, + criterion: DictConfig, + conv_type: Optional[str] = None, + ): 
         super().__init__(instantiator)
 
         print(backbone)
diff --git a/torch_points3d/modules/KPConv/__init__.py b/torch_points3d/modules/KPConv/__init__.py
new file mode 100644
index 0000000..f7ce612
--- /dev/null
+++ b/torch_points3d/modules/KPConv/__init__.py
@@ -0,0 +1,2 @@
+from .blocks import *
+from .kernels import *
diff --git a/torch_points3d/modules/KPConv/blocks.py b/torch_points3d/modules/KPConv/blocks.py
new file mode 100644
index 0000000..faadf5c
--- /dev/null
+++ b/torch_points3d/modules/KPConv/blocks.py
@@ -0,0 +1,297 @@
+import torch
+import sys
+from torch.nn import Linear as Lin
+
+from .kernels import KPConvLayer, KPConvDeformableLayer
+from torch_points3d.core.common_modules.base_modules import BaseModule, FastBatchNorm1d
+from torch_points3d.core.spatial_ops import RadiusNeighbourFinder
+from torch_points3d.core.data_transform import GridSampling3D
+from torch_points3d.utils.enums import ConvolutionFormat
+from torch_points3d.core.base_conv.message_passing import GlobalBaseModule
+from torch_points3d.core.common_modules.base_modules import Identity
+from torch_points3d.utils.config import is_list
+
+
+class SimpleBlock(BaseModule):
+    """
+    Simple layer with KPConv convolution -> BN -> activation.
+    A strided version is obtained by sampling the query points on a coarser grid than the neighbours.
+    """
+
+    CONV_TYPE = ConvolutionFormat.PARTIAL_DENSE.value
+    DEFORMABLE_DENSITY = 5.0
+    RIGID_DENSITY = 2.5
+
+    def __init__(
+        self,
+        down_conv_nn=None,
+        grid_size=None,
+        prev_grid_size=None,
+        sigma=1.0,
+        max_num_neighbors=16,
+        activation=torch.nn.LeakyReLU(negative_slope=0.1),
+        bn_momentum=0.02,
+        bn=FastBatchNorm1d,
+        deformable=False,
+        add_one=False,
+        **kwargs,
+    ):
+        super(SimpleBlock, self).__init__()
+        assert len(down_conv_nn) == 2
+        num_inputs, num_outputs = down_conv_nn
+        if deformable:
+            density_parameter = self.DEFORMABLE_DENSITY
+            self.kp_conv = KPConvDeformableLayer(
+                num_inputs, num_outputs, point_influence=prev_grid_size * sigma, add_one=add_one, **kwargs
+            )
+        else:
+            density_parameter = self.RIGID_DENSITY
+            self.kp_conv = KPConvLayer(
+                num_inputs, num_outputs, point_influence=prev_grid_size * sigma, add_one=add_one, **kwargs
+            )
+        search_radius = density_parameter * sigma * prev_grid_size
+        self.neighbour_finder = RadiusNeighbourFinder(search_radius, max_num_neighbors, conv_type=self.CONV_TYPE)
+
+        if bn:
+            self.bn = bn(num_outputs, momentum=bn_momentum)
+        else:
+            self.bn = None
+        self.activation = activation
+
+        is_strided = prev_grid_size != grid_size
+        if is_strided:
+            self.sampler = GridSampling3D(grid_size)
+        else:
+            self.sampler = None
+
+    def forward(self, data, precomputed=None, **kwargs):
+        if not hasattr(data, "block_idx"):
+            setattr(data, "block_idx", 0)
+
+        if precomputed:
+            query_data = precomputed[data.block_idx]
+        else:
+            if self.sampler:
+                query_data = self.sampler(data.clone())
+            else:
+                query_data = data.clone()
+
+        if precomputed:
+            idx_neighboors = query_data.idx_neighboors
+            q_pos = query_data.pos
+        else:
+            q_pos, q_batch = query_data.pos, query_data.batch
+            idx_neighboors = self.neighbour_finder(data.pos, q_pos, batch_x=data.batch, batch_y=q_batch)
+            query_data.idx_neighboors = idx_neighboors
+
+        x = self.kp_conv(
+            q_pos,
+            data.pos,
+            idx_neighboors,
+            data.x,
+        )
+        if self.bn:
+            x = self.bn(x)
+        x = self.activation(x)
+
+        query_data.x = x
+        query_data.block_idx = data.block_idx + 1
+        return query_data
+
+    def extra_repr(self):
+        return "Nb parameters: {}; {}; {}".format(self.nb_params, self.sampler, self.neighbour_finder)
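A minimal sketch of a SimpleBlock in isolation. It assumes `GridSampling3D` (imported above but not part of this diff) behaves as in torch-points3d, and that the partial-dense data carries pos, x and batch:

import torch
from torch_geometric.data import Data
from torch_points3d.modules.KPConv.blocks import SimpleBlock

# one rigid KPConv block going from 1 feature to 32 channels, strided from a 0.02 to a 0.04 grid
block = SimpleBlock(down_conv_nn=[1, 32], grid_size=0.04, prev_grid_size=0.02, max_num_neighbors=25)

data = Data(pos=torch.rand(1000, 3), x=torch.ones(1000, 1), batch=torch.zeros(1000, dtype=torch.long))
out = block(data)  # out.pos is subsampled on the 0.04 grid, out.x has 32 channels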
"""Resnet block with optional bottleneck activated by default + Arguments: + down_conv_nn (len of 2 or 3) : + sizes of input, intermediate, output. + If length == 2 then intermediate = num_outputs // 4 + radius : radius of the conv kernel + sigma : + density_parameter : density parameter for the kernel + max_num_neighbors : maximum number of neighboors for the neighboor search + activation : activation function + has_bottleneck: wether to use the bottleneck or not + bn_momentum + bn : batch norm (can be None -> no batch norm) + grid_size : size of the grid, + prev_grid_size : size of the grid at previous step. + In case of a strided block, this is different than grid_size + """ + + CONV_TYPE = ConvolutionFormat.PARTIAL_DENSE.value + + def __init__( + self, + down_conv_nn=None, + grid_size=None, + prev_grid_size=None, + sigma=1, + max_num_neighbors=16, + activation=torch.nn.LeakyReLU(negative_slope=0.1), + has_bottleneck=True, + bn_momentum=0.02, + bn=FastBatchNorm1d, + deformable=False, + add_one=False, + **kwargs, + ): + super(ResnetBBlock, self).__init__() + assert len(down_conv_nn) == 2 or len(down_conv_nn) == 3, "down_conv_nn should be of size 2 or 3" + if len(down_conv_nn) == 2: + num_inputs, num_outputs = down_conv_nn + d_2 = num_outputs // 4 + else: + num_inputs, d_2, num_outputs = down_conv_nn + self.is_strided = prev_grid_size != grid_size + self.has_bottleneck = has_bottleneck + + # Main branch + if self.has_bottleneck: + kp_size = [d_2, d_2] + else: + kp_size = [num_inputs, num_outputs] + + self.kp_conv = SimpleBlock( + down_conv_nn=kp_size, + grid_size=grid_size, + prev_grid_size=prev_grid_size, + sigma=sigma, + max_num_neighbors=max_num_neighbors, + activation=activation, + bn_momentum=bn_momentum, + bn=bn, + deformable=deformable, + add_one=add_one, + **kwargs, + ) + + if self.has_bottleneck: + if bn: + self.unary_1 = torch.nn.Sequential( + Lin(num_inputs, d_2, bias=False), bn(d_2, momentum=bn_momentum), activation + ) + self.unary_2 = torch.nn.Sequential( + Lin(d_2, num_outputs, bias=False), bn(num_outputs, momentum=bn_momentum), activation + ) + else: + self.unary_1 = torch.nn.Sequential(Lin(num_inputs, d_2, bias=False), activation) + self.unary_2 = torch.nn.Sequential(Lin(d_2, num_outputs, bias=False), activation) + + # Shortcut + if num_inputs != num_outputs: + if bn: + self.shortcut_op = torch.nn.Sequential( + Lin(num_inputs, num_outputs, bias=False), bn(num_outputs, momentum=bn_momentum) + ) + else: + self.shortcut_op = Lin(num_inputs, num_outputs, bias=False) + else: + self.shortcut_op = torch.nn.Identity() + + # Final activation + self.activation = activation + + def forward(self, data, precomputed=None, **kwargs): + """ + data: x, pos, batch_idx and idx_neighbour when the neighboors of each point in pos have already been computed + """ + # Main branch + output = data.clone() + shortcut_x = data.x + if self.has_bottleneck: + output.x = self.unary_1(output.x) + output = self.kp_conv(output, precomputed=precomputed) + if self.has_bottleneck: + output.x = self.unary_2(output.x) + + # Shortcut + if self.is_strided: + idx_neighboors = output.idx_neighboors + shortcut_x = torch.cat([shortcut_x, torch.zeros_like(shortcut_x[:1, :])], axis=0) # Shadow feature + neighborhood_features = shortcut_x[idx_neighboors] + shortcut_x = torch.max(neighborhood_features, dim=1, keepdim=False)[0] + + shortcut = self.shortcut_op(shortcut_x) + output.x += shortcut + return output + + @property + def sampler(self): + return self.kp_conv.sampler + + @property + def neighbour_finder(self): + 
+class KPDualBlock(BaseModule):
+    """Dual KPConv block (usually strided + non strided)
+
+    Arguments (accepted kwargs):
+        block_names: Name of the blocks to be used as part of this dual block
+        down_conv_nn: Size of the convs e.g. [64, 128],
+        grid_size: Size of the grid for each block,
+        prev_grid_size: Size of the grid in the previous KPConv
+        has_bottleneck: Whether a block should implement the bottleneck
+        max_num_neighbors: Max number of neighbours for the radius search,
+        deformable: Is deformable,
+        add_one: Add one as a feature,
+    """
+
+    def __init__(
+        self,
+        block_names=None,
+        down_conv_nn=None,
+        grid_size=None,
+        prev_grid_size=None,
+        has_bottleneck=None,
+        max_num_neighbors=None,
+        deformable=False,
+        add_one=False,
+        **kwargs,
+    ):
+        super(KPDualBlock, self).__init__()
+
+        assert len(block_names) == len(down_conv_nn)
+        self.blocks = torch.nn.ModuleList()
+        for i, class_name in enumerate(block_names):
+            # Constructing extra keyword arguments
+            block_kwargs = {}
+            for key, arg in kwargs.items():
+                block_kwargs[key] = arg[i] if is_list(arg) else arg
+
+            # Building the block
+            kpcls = getattr(sys.modules[__name__], class_name)
+            block = kpcls(
+                down_conv_nn=down_conv_nn[i],
+                grid_size=grid_size[i],
+                prev_grid_size=prev_grid_size[i],
+                has_bottleneck=has_bottleneck[i],
+                max_num_neighbors=max_num_neighbors[i],
+                deformable=deformable[i] if is_list(deformable) else deformable,
+                add_one=add_one[i] if is_list(add_one) else add_one,
+                **block_kwargs,
+            )
+            self.blocks.append(block)
+
+    def forward(self, data, precomputed=None, **kwargs):
+        for block in self.blocks:
+            data = block(data, precomputed=precomputed)
+        return data
+
+    @property
+    def sampler(self):
+        return [b.sampler for b in self.blocks]
+
+    @property
+    def neighbour_finder(self):
+        return [b.neighbour_finder for b in self.blocks]
+
+    def extra_repr(self):
+        return "Nb parameters: %i" % self.nb_params
diff --git a/torch_points3d/modules/KPConv/convolution_ops.py b/torch_points3d/modules/KPConv/convolution_ops.py
new file mode 100644
index 0000000..cc053c5
--- /dev/null
+++ b/torch_points3d/modules/KPConv/convolution_ops.py
@@ -0,0 +1,235 @@
+# defining KPConv using torch ops
+# Adaptation of https://github.com/HuguesTHOMAS/KPConv/
+# Adaptation of https://github.com/humanpose1/KPConvTorch/
+
+import torch
+from torch_points3d.core.common_modules.gathering import gather
+
+
+def radius_gaussian(sq_r, sig, eps=1e-9):
+    """
+    Compute a radius gaussian (gaussian of distance)
+    :param sq_r: input squared radii [dn, ..., d1, d0]
+    :param sig: extents of gaussians [d1, d0] or [d0] or float
+    :param eps: epsilon added to the denominator for numerical stability
+    :return: gaussian of sq_r [dn, ..., d1, d0]
+    """
+    return torch.exp(-sq_r / (2 * sig ** 2 + eps))
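The "linear" influence used below, evaluated by hand for a kernel extent of 0.06:

import torch

KP_extent = 0.06
sq_d = torch.tensor([0.0, 0.0009, 0.0036, 0.01])  # squared distances to one kernel point
w = torch.clamp(1 - torch.sqrt(sq_d) / KP_extent, min=0.0)
print(w)  # tensor([1.0000, 0.5000, 0.0000, 0.0000]) -- decays linearly, zero at d = KP_extent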
+def KPConv_ops(
+    query_points,
+    support_points,
+    neighbors_indices,
+    features,
+    K_points,
+    K_values,
+    KP_extent,
+    KP_influence,
+    aggregation_mode,
+):
+    """
+    Kernel Point Convolution using torch ops (adapted from the original TensorFlow implementation).
+    See KPConvLayer for how these parameters are built.
+    :param query_points: float32[n_points, dim] - input query points (center of neighborhoods)
+    :param support_points: float32[n0_points, dim] - input support points (from which neighbors are taken)
+    :param neighbors_indices: int32[n_points, n_neighbors] - indices of neighbors of each point
+    :param features: float32[n0_points, in_fdim] - input features
+    :param K_points: float32[n_kpoints, dim] - positions of the kernel points
+    :param K_values: float32[n_kpoints, in_fdim, out_fdim] - weights of the kernel
+    :param KP_extent: float32 - influence radius of each kernel point
+    :param KP_influence: string in ('constant', 'linear', 'gaussian') - influence function of the kernel points
+    :param aggregation_mode: string in ('closest', 'sum') - whether to sum influences, or only keep the closest
+    :return: [n_points, out_fdim]
+    """
+
+    # Add a fake point in the last row for shadow neighbors
+    shadow_point = torch.ones_like(support_points[:1, :]) * 1e6
+    support_points = torch.cat([support_points, shadow_point], dim=0)
+
+    # Get neighbor points [n_points, n_neighbors, dim]
+    neighbors = gather(support_points, neighbors_indices)
+
+    # Center every neighborhood
+    neighbors = neighbors - query_points.unsqueeze(1)
+
+    # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
+    neighbors.unsqueeze_(2)
+    differences = neighbors - K_points
+
+    # Get the square distances [n_points, n_neighbors, n_kpoints]
+    sq_distances = torch.sum(differences ** 2, dim=3)
+
+    # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
+    if KP_influence == "constant":
+        # Every point gets an influence of 1.
+        all_weights = torch.ones_like(sq_distances)
+        all_weights = all_weights.transpose(2, 1)
+
+    elif KP_influence == "linear":
+        # Influence decreases linearly with the distance and goes to zero when d = KP_extent.
+        all_weights = torch.clamp(1 - torch.sqrt(sq_distances) / KP_extent, min=0.0)
+        all_weights = all_weights.transpose(2, 1)
+
+    elif KP_influence == "gaussian":
+        # Gaussian influence of the distance.
+        sigma = KP_extent * 0.3
+        all_weights = radius_gaussian(sq_distances, sigma)
+        all_weights = all_weights.transpose(2, 1)
+    else:
+        raise ValueError("Unknown influence function type (config.KP_influence)")
+
+    # In case of closest mode, only the closest KP can influence each point
+    if aggregation_mode == "closest":
+        neighbors_1nn = torch.argmin(sq_distances, dim=-1)
+        all_weights *= torch.transpose(torch.nn.functional.one_hot(neighbors_1nn, K_points.shape[0]), 1, 2)
+
+    elif aggregation_mode != "sum":
+        raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")
+
+    # Add a zero feature row for the shadow point
+    features = torch.cat([features, torch.zeros_like(features[:1, :])], dim=0)
+
+    # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
+    neighborhood_features = gather(features, neighbors_indices)
+
+    # Apply distance weights [n_points, n_kpoints, in_fdim]
+    weighted_features = torch.matmul(all_weights, neighborhood_features)
+
+    # Apply network weights [n_kpoints, n_points, out_fdim]
+    weighted_features = weighted_features.permute(1, 0, 2)
+    kernel_outputs = torch.matmul(weighted_features, K_values)
+
+    # Convolution sum to get [n_points, out_fdim]
+    output_features = torch.sum(kernel_outputs, dim=0)
+
+    return output_features
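A hypothetical shape-only smoke test for the op above (random tensors, illustrative sizes; not part of the diff):

import torch

n_points, n_neighbors, n_kp, in_fdim, out_fdim, dim = 128, 16, 15, 32, 64, 3
query = torch.rand(n_points, dim)
support = torch.rand(n_points, dim)
neighbors = torch.randint(0, n_points, (n_points, n_neighbors))
feats = torch.rand(n_points, in_fdim)
K_points = torch.rand(n_kp, dim)
K_values = torch.rand(n_kp, in_fdim, out_fdim)

out = KPConv_ops(query, support, neighbors, feats, K_points, K_values,
                 KP_extent=0.05, KP_influence="linear", aggregation_mode="sum")
assert out.shape == (n_points, out_fdim)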
+
+
+def KPConv_deform_ops(
+    query_points,
+    support_points,
+    neighbors_indices,
+    features,
+    K_points,
+    offsets,
+    modulations,
+    K_values,
+    KP_extent,
+    KP_influence,
+    aggregation_mode,
+):
+    """
+    Deformable Kernel Point Convolution using torch ops (adapted from the original TensorFlow implementation).
+    See KPConvDeformableLayer for how these parameters are built.
+    :param query_points: [n_points, dim]
+    :param support_points: [n0_points, dim]
+    :param neighbors_indices: [n_points, n_neighbors]
+    :param features: [n0_points, in_fdim]
+    :param K_points: [n_kpoints, dim]
+    :param offsets: [n_points, n_kpoints, dim]
+    :param modulations: [n_points, n_kpoints] or None
+    :param K_values: [n_kpoints, in_fdim, out_fdim]
+    :param KP_extent: float32
+    :param KP_influence: string
+    :param aggregation_mode: string in ('closest', 'sum') - whether to sum influences, or only keep the closest
+
+    :return features, square_distances, deformed_K_points
+    """
+
+    # Get variables
+    n_kp = int(K_points.shape[0])
+    shadow_ind = support_points.shape[0]
+
+    # Add a fake point in the last row for shadow neighbors
+    shadow_point = torch.ones_like(support_points[:1, :]) * 1e6
+    support_points = torch.cat([support_points, shadow_point], dim=0)
+
+    # Get neighbor points [n_points, n_neighbors, dim]
+    neighbors = support_points[neighbors_indices]
+
+    # Center every neighborhood
+    neighbors = neighbors - query_points.unsqueeze(1)
+
+    # Apply offsets to kernel points [n_points, n_kpoints, dim]
+    deformed_K_points = torch.add(offsets, K_points)
+
+    # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
+    neighbors = neighbors.unsqueeze(2)
+    neighbors = neighbors.repeat([1, 1, n_kp, 1])
+    differences = neighbors - deformed_K_points.unsqueeze(1)
+
+    # Get the square distances [n_points, n_neighbors, n_kpoints]
+    sq_distances = torch.sum(differences ** 2, dim=3)
+
+    # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]
+    in_range = (sq_distances < KP_extent ** 2).any(2).to(torch.long)
+
+    # New value of max neighbors
+    new_max_neighb = torch.max(torch.sum(in_range, dim=1))
+
+    # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]
+    new_neighb_bool, new_neighb_inds = torch.topk(in_range, k=new_max_neighb)
+
+    # Gather new neighbor indices [n_points, new_max_neighb]
+    new_neighbors_indices = neighbors_indices.gather(1, new_neighb_inds)
+
+    # Gather new distances to KP [n_points, new_max_neighb, n_kpoints]
+    new_neighb_inds_sq = new_neighb_inds.unsqueeze(-1)
+    new_sq_distances = sq_distances.gather(1, new_neighb_inds_sq.repeat((1, 1, sq_distances.shape[-1])))
+
+    # New shadow neighbors have to point to the last shadow point
+    new_neighbors_indices *= new_neighb_bool
+    new_neighbors_indices += (1 - new_neighb_bool) * shadow_ind
+
+    # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
+    if KP_influence == "constant":
+        # Every point gets an influence of 1.
+        all_weights = (new_sq_distances < KP_extent ** 2).to(torch.float32)
+        all_weights = all_weights.permute(0, 2, 1)
+
+    elif KP_influence == "linear":
+        # Influence decreases linearly with the distance and goes to zero when d = KP_extent.
+        all_weights = torch.relu(1 - torch.sqrt(new_sq_distances) / KP_extent)
+        all_weights = all_weights.permute(0, 2, 1)
+
+    elif KP_influence == "gaussian":
+        # Gaussian influence of the distance.
+        sigma = KP_extent * 0.3
+        all_weights = radius_gaussian(new_sq_distances, sigma)
+        all_weights = all_weights.permute(0, 2, 1)
+    else:
+        raise ValueError("Unknown influence function type (config.KP_influence)")
+
+    # In case of closest mode, only the closest KP can influence each point
+    if aggregation_mode == "closest":
+        neighbors_1nn = torch.argmin(new_sq_distances, dim=2)
+        all_weights *= torch.transpose(torch.nn.functional.one_hot(neighbors_1nn, n_kp), 1, 2)
+
+    elif aggregation_mode != "sum":
+        raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")
+
+    # Add a zero feature row for the shadow point
+    features = torch.cat([features, torch.zeros_like(features[:1, :])], dim=0)
+
+    # Get the features of each neighborhood [n_points, new_max_neighb, in_fdim]
+    neighborhood_features = features[new_neighbors_indices]
+
+    # Apply distance weights [n_points, n_kpoints, in_fdim]
+    weighted_features = torch.matmul(all_weights, neighborhood_features)
+
+    # Apply modulations
+    if modulations is not None:
+        weighted_features *= modulations.unsqueeze(2)
+
+    # Apply network weights [n_kpoints, n_points, out_fdim]
+    weighted_features = weighted_features.permute(1, 0, 2)
+    kernel_outputs = torch.matmul(weighted_features, K_values)
+
+    # Convolution sum [n_points, out_fdim]
+    output_features = torch.sum(kernel_outputs, dim=0)
+
+    # The caller derives regularization losses from sq_distances and deformed_K_points
+    return output_features, sq_distances, deformed_K_points
diff --git a/torch_points3d/modules/KPConv/kernel_utils.py b/torch_points3d/modules/KPConv/kernel_utils.py
new file mode 100644
index 0000000..0ed4214
--- /dev/null
+++ b/torch_points3d/modules/KPConv/kernel_utils.py
@@ -0,0 +1,286 @@
+#
+#
+#      0=================================0
+#      |    Kernel Point Convolutions    |
+#      0=================================0
+#
+#
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#      Functions handling the disposition of kernel points.
+#
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#      Hugues THOMAS - 11/06/2018
+#
+
+
+# ------------------------------------------------------------------------------------------
+#
+#          Imports and global variables
+#      \**********************************/
+#
+
+
+# Import numpy package and name it "np"
+import numpy as np
+import matplotlib.pyplot as plt
+from os import makedirs
+from os.path import join, exists
+import os
+import logging
+
+from .plyutils import read_ply, write_ply
+
+
+# ------------------------------------------------------------------------------------------
+#
+#           Functions
+#       \***************/
+#
+#
+log = logging.getLogger(__name__)
+DIR = os.path.dirname(os.path.realpath(__file__))
+
+
+def kernel_point_optimization_debug(
+    radius, num_points, num_kernels=1, dimension=3, fixed="center", ratio=1.0, verbose=0
+):
+    """
+    Creation of kernel points via optimization of potentials.
+    :param radius: Radius of the kernels
+    :param num_points: points composing kernels
+    :param num_kernels: number of wanted kernels
+    :param dimension: dimension of the space
+    :param fixed: fix position of certain kernel points ('none', 'center' or 'verticals')
+    :param ratio: ratio of the radius where you want the kernel points to be placed
+    :param verbose: display option
+    :return: points [num_kernels, num_points, dimension]
+    """
+
+    #######################
+    # Parameters definition
+    #######################
+
+    # Radius used for optimization (points are rescaled afterwards)
+    radius0 = 1
+    diameter0 = 2
+
+    # Factor multiplying gradients for moving points (~learning rate)
+    moving_factor = 1e-2
+    continuous_moving_decay = 0.9995
+
+    # Gradient threshold to stop optimization
+    thresh = 1e-5
+
+    # Gradient clipping value
+    clip = 0.05 * radius0
+
+    #######################
+    # Kernel initialization
+    #######################
+
+    # Random kernel points
+    kernel_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0
+    while kernel_points.shape[0] < num_kernels * num_points:
+        new_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0
+        kernel_points = np.vstack((kernel_points, new_points))
+        d2 = np.sum(np.power(kernel_points, 2), axis=1)
+        kernel_points = kernel_points[d2 < 0.5 * radius0 * radius0, :]
+    kernel_points = kernel_points[: num_kernels * num_points, :].reshape((num_kernels, num_points, -1))
+
+    # Optional fixing
+    if fixed == "center":
+        kernel_points[:, 0, :] *= 0
+    if fixed == "verticals":
+        kernel_points[:, :3, :] *= 0
+        kernel_points[:, 1, -1] += 2 * radius0 / 3
+        kernel_points[:, 2, -1] -= 2 * radius0 / 3
+
+    #####################
+    # Kernel optimization
+    #####################
+
+    # Initiate figure
+    if verbose > 1:
+        fig = plt.figure()
+
+    saved_gradient_norms = np.zeros((10000, num_kernels))
+    old_gradient_norms = np.zeros((num_kernels, num_points))
+    for iter in range(10000):
+
+        # Compute gradients
+        # *****************
+
+        # Derivative of the sum of potentials of all points
+        A = np.expand_dims(kernel_points, axis=2)
+        B = np.expand_dims(kernel_points, axis=1)
+        interd2 = np.sum(np.power(A - B, 2), axis=-1)
+        inter_grads = (A - B) / (np.power(np.expand_dims(interd2, -1), 3 / 2) + 1e-6)
+        inter_grads = np.sum(inter_grads, axis=1)
+
+        # Derivative of the radius potential
+        circle_grads = 10 * kernel_points
+
+        # All gradients
+        gradients = inter_grads + circle_grads
+
+        if fixed == "verticals":
+            gradients[:, 1:3, :-1] = 0
+
+        # Stop condition
+        # **************
+
+        # Compute norm of gradients
+        gradients_norms = np.sqrt(np.sum(np.power(gradients, 2), axis=-1))
+        saved_gradient_norms[iter, :] = np.max(gradients_norms, axis=1)
+
+        # Stop if all moving points have converged (low change in gradient norms)
+        if fixed == "center" and np.max(np.abs(old_gradient_norms[:, 1:] - gradients_norms[:, 1:])) < thresh:
+            break
+        elif fixed == "verticals" and np.max(np.abs(old_gradient_norms[:, 3:] - gradients_norms[:, 3:])) < thresh:
+            break
+        elif np.max(np.abs(old_gradient_norms - gradients_norms)) < thresh:
+            break
+        old_gradient_norms = gradients_norms
+
+        # Move points
+        # ***********
+
+        # Clip gradient to get moving dists
+        moving_dists = np.minimum(moving_factor * gradients_norms, clip)
+
+        # Fix central point
+        if fixed == "center":
+            moving_dists[:, 0] = 0
+        if fixed == "verticals":
+            moving_dists[:, 0] = 0
+
+        # Move points
+        kernel_points -= np.expand_dims(moving_dists, -1) * gradients / np.expand_dims(gradients_norms + 1e-6, -1)
+
+        if verbose:
+            log.info("iter {:5d} / max grad = {:f}".format(iter, np.max(gradients_norms[:, 3:])))
+        if verbose > 1:
+            plt.clf()
+            plt.plot(kernel_points[0, :, 0], kernel_points[0, :, 1], ".")
+            circle = plt.Circle((0, 0), radius, color="r", fill=False)
+            fig.axes[0].add_artist(circle)
+            fig.axes[0].set_xlim((-radius * 1.1, radius * 1.1))
+            fig.axes[0].set_ylim((-radius * 1.1, radius * 1.1))
+            fig.axes[0].set_aspect("equal")
+            plt.draw()
+            plt.pause(0.001)
+            plt.show(block=False)
+            log.info(moving_factor)
+
+        # moving factor decay
+        moving_factor *= continuous_moving_decay
+
+    # Rescale so the mean kernel point radius matches the wanted ratio
+    r = np.sqrt(np.sum(np.power(kernel_points, 2), axis=-1))
+    kernel_points *= ratio / np.mean(r[:, 1:])
+
+    # Rescale kernels with real radius
+    return kernel_points * radius, saved_gradient_norms
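An illustrative call of the optimizer above (not part of the diff); the parameters mirror the defaults used elsewhere in this file:

# Optimize 100 candidate dispositions of 15 kernel points in 3D.
points, grad_norms = kernel_point_optimization_debug(
    radius=1.0, num_points=15, num_kernels=100, dimension=3, fixed="center"
)
print(points.shape)       # (100, 15, 3)
print(grad_norms.shape)   # (10000, 100); load_kernels picks the candidate with the lowest final norm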
+
+
+def load_kernels(radius, num_kpoints, num_kernels, dimension, fixed):
+
+    # Number of tries in the optimization process, to ensure we get the most stable disposition
+    num_tries = 100
+
+    # Kernel directory
+    kernel_dir = join(DIR, "kernels/dispositions")
+    if not exists(kernel_dir):
+        makedirs(kernel_dir)
+
+    # Kernel file
+    if dimension == 3:
+        kernel_file = join(kernel_dir, "k_{:03d}_{:s}.ply".format(num_kpoints, fixed))
+    elif dimension == 2:
+        kernel_file = join(kernel_dir, "k_{:03d}_{:s}_2D.ply".format(num_kpoints, fixed))
+    else:
+        raise ValueError("Unsupported kernel dimension: " + str(dimension))
+
+    # Check if already done
+    if not exists(kernel_file):
+
+        # Create kernels
+        kernel_points, grad_norms = kernel_point_optimization_debug(
+            1.0,
+            num_kpoints,
+            num_kernels=num_tries,
+            dimension=dimension,
+            fixed=fixed,
+            verbose=0,
+        )
+
+        # Find best candidate
+        best_k = np.argmin(grad_norms[-1, :])
+
+        # Save points
+        original_kernel = kernel_points[best_k, :, :]
+        write_ply(kernel_file, original_kernel, ["x", "y", "z"])
+
+    else:
+        data = read_ply(kernel_file)
+        original_kernel = np.vstack((data["x"], data["y"], data["z"])).T
+
+    # N.B. random rotations and rescaling are not supported for 2D kernels yet
+    if dimension == 2:
+        return original_kernel
+
+    # Random rotations depending on the fixed points
+    if fixed == "verticals":
+
+        # Create random rotations
+        thetas = np.random.rand(num_kernels) * 2 * np.pi
+        c, s = np.cos(thetas), np.sin(thetas)
+        R = np.zeros((num_kernels, 3, 3), dtype=np.float32)
+        R[:, 0, 0] = c
+        R[:, 1, 1] = c
+        R[:, 2, 2] = 1
+        R[:, 0, 1] = s
+        R[:, 1, 0] = -s
+
+        # Scale kernels
+        original_kernel = radius * np.expand_dims(original_kernel, 0)
+
+        # Rotate kernels
+        kernels = np.matmul(original_kernel, R)
+
+    else:
+
+        # Create random rotations
+        u = np.ones((num_kernels, 3))
+        v = np.ones((num_kernels, 3))
+        wrongs = np.abs(np.sum(u * v, axis=1)) > 0.99
+        while np.any(wrongs):
+            new_u = np.random.rand(num_kernels, 3) * 2 - 1
+            new_u = new_u / np.expand_dims(np.linalg.norm(new_u, axis=1) + 1e-9, -1)
+            u[wrongs, :] = new_u[wrongs, :]
+            new_v = np.random.rand(num_kernels, 3) * 2 - 1
+            new_v = new_v / np.expand_dims(np.linalg.norm(new_v, axis=1) + 1e-9, -1)
+            v[wrongs, :] = new_v[wrongs, :]
+            wrongs = np.abs(np.sum(u * v, axis=1)) > 0.99
+
+        # Make v perpendicular to u
+        v -= np.expand_dims(np.sum(u * v, axis=1), -1) * u
+        v = v / np.expand_dims(np.linalg.norm(v, axis=1) + 1e-9, -1)
+
+        # Last rotation vector
+        w = np.cross(u, v)
+        R = np.stack((u, v, w), axis=-1)
+
+        # Scale kernels
+        original_kernel = radius * np.expand_dims(original_kernel, 0)
+
+        # Rotate kernels
+        kernels = np.matmul(original_kernel, R)
+
+    # Add a small noise
+    kernels = kernels + np.random.normal(scale=radius * 0.01, size=kernels.shape)
+
+    return kernels
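An illustrative call of load_kernels (not part of the diff); the radius value is an assumption matching kernel_radius = 1.5 * point_influence in the layers below:

kp = load_kernels(radius=0.06, num_kpoints=15, num_kernels=1, dimension=3, fixed="center")
print(kp.shape)  # (1, 15, 3): the cached disposition, randomly rotated and rescaled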
diff --git a/torch_points3d/modules/KPConv/kernels.py b/torch_points3d/modules/KPConv/kernels.py
new file mode 100644
index 0000000..ceaea0b
--- /dev/null
+++ b/torch_points3d/modules/KPConv/kernels.py
@@ -0,0 +1,266 @@
+import torch
+from torch.nn.parameter import Parameter
+
+from .kernel_utils import kernel_point_optimization_debug, load_kernels
+from .losses import fitting_loss, repulsion_loss, permissive_loss
+from .convolution_ops import *
+from torch_points3d.modules.base_modules import BaseInternalLossModule
+
+
+def add_ones(query_points, x, add_one):
+    if add_one:
+        ones = torch.ones(query_points.shape[0], dtype=torch.float).unsqueeze(-1).to(query_points.device)
+        if x is not None:
+            x = torch.cat([ones.to(x.dtype), x], dim=-1)
+        else:
+            x = ones
+    return x
+
+
+class KPConvLayer(torch.nn.Module):
+    """
+    Apply the kernel point convolution on a point cloud.
+    NB: this is the original version of KPConv, not the message passing version.
+    attributes:
+    num_inputs : dimension of the input feature
+    num_outputs : dimension of the output feature
+    point_influence: influence distance of a single point (sigma * grid_size)
+    n_kernel_points=15
+    fixed="center"
+    KP_influence="linear"
+    aggregation_mode="sum"
+    dimension=3
+    """
+
+    _INFLUENCE_TO_RADIUS = 1.5
+
+    def __init__(
+        self,
+        num_inputs,
+        num_outputs,
+        point_influence,
+        n_kernel_points=15,
+        fixed="center",
+        KP_influence="linear",
+        aggregation_mode="sum",
+        dimension=3,
+        add_one=False,
+        **kwargs
+    ):
+        super(KPConvLayer, self).__init__()
+        self.kernel_radius = self._INFLUENCE_TO_RADIUS * point_influence
+        self.point_influence = point_influence
+        self.add_one = add_one
+        self.num_inputs = num_inputs + self.add_one * 1
+        self.num_outputs = num_outputs
+
+        self.KP_influence = KP_influence
+        self.n_kernel_points = n_kernel_points
+        self.aggregation_mode = aggregation_mode
+
+        # Initial kernel extent for this layer
+        K_points_numpy = load_kernels(
+            self.kernel_radius,
+            n_kernel_points,
+            num_kernels=1,
+            dimension=dimension,
+            fixed=fixed,
+        )
+
+        self.K_points = Parameter(
+            torch.from_numpy(K_points_numpy.reshape((n_kernel_points, dimension))).to(torch.float),
+            requires_grad=False,
+        )
+
+        weights = torch.empty([n_kernel_points, self.num_inputs, num_outputs], dtype=torch.float)
+        torch.nn.init.xavier_normal_(weights)
+        self.weight = Parameter(weights)
+
+    def forward(self, query_points, support_points, neighbors, x):
+        """
+        - query_points (torch Tensor): query points of size N x 3
+        - support_points (torch Tensor): support points of size N0 x 3
+        - neighbors (torch Tensor): neighbor indices of size N x M
+        - x (torch Tensor): features of size N0 x d (d is the number of inputs)
+        """
+        x = add_ones(support_points, x, self.add_one)
+
+        new_feat = KPConv_ops(
+            query_points,
+            support_points,
+            neighbors,
+            x,
+            self.K_points,
+            self.weight,
+            self.point_influence,
+            self.KP_influence,
+            self.aggregation_mode,
+        )
+        return new_feat
+
+    def __repr__(self):
+        return "KPConvLayer(InF: %i, OutF: %i, kernel_pts: %i, radius: %.2f, KP_influence: %s, Add_one: %s)" % (
+            self.num_inputs,
+            self.num_outputs,
+            self.n_kernel_points,
+            self.kernel_radius,
+            self.KP_influence,
+            self.add_one,
+        )
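A hypothetical shape-only usage of KPConvLayer (not part of the diff); point_influence here assumes sigma = 1 and grid_size = 0.02:

import torch

layer = KPConvLayer(num_inputs=32, num_outputs=64, point_influence=0.02)
query = torch.rand(512, 3)
support = torch.rand(2048, 3)
neighbors = torch.randint(0, 2048, (512, 25))  # e.g. produced by a radius search
x = torch.rand(2048, 32)
out = layer(query, support, neighbors, x)      # -> [512, 64]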
+
+
+class KPConvDeformableLayer(BaseInternalLossModule):
+    """
+    Apply the deformable kernel point convolution on a point cloud.
+    NB: this is the original version of KPConv, not the message passing version.
+    attributes:
+    num_inputs : dimension of the input feature
+    num_outputs : dimension of the output feature
+    point_influence: influence distance of a single point (sigma * grid_size)
+    n_kernel_points=15
+    fixed="center"
+    KP_influence="linear"
+    aggregation_mode="sum"
+    dimension=3
+    modulated = False : whether the deformable conv should be modulated
+    """
+
+    PERMISSIVE_LOSS_KEY = "permissive_loss"
+    FITTING_LOSS_KEY = "fitting_loss"
+    REPULSION_LOSS_KEY = "repulsion_loss"
+
+    _INFLUENCE_TO_RADIUS = 1.5
+
+    def __init__(
+        self,
+        num_inputs,
+        num_outputs,
+        point_influence,
+        n_kernel_points=15,
+        fixed="center",
+        KP_influence="linear",
+        aggregation_mode="sum",
+        dimension=3,
+        modulated=False,
+        loss_mode="fitting",
+        add_one=False,
+        **kwargs
+    ):
+        super(KPConvDeformableLayer, self).__init__()
+        self.kernel_radius = self._INFLUENCE_TO_RADIUS * point_influence
+        self.point_influence = point_influence
+        self.add_one = add_one
+        self.num_inputs = num_inputs + self.add_one * 1
+        self.num_outputs = num_outputs
+
+        self.KP_influence = KP_influence
+        self.n_kernel_points = n_kernel_points
+        self.aggregation_mode = aggregation_mode
+        self.modulated = modulated
+        self.internal_losses = {self.PERMISSIVE_LOSS_KEY: 0.0, self.FITTING_LOSS_KEY: 0.0, self.REPULSION_LOSS_KEY: 0.0}
+        self.loss_mode = loss_mode
+
+        # Initial kernel extent for this layer
+        K_points_numpy = load_kernels(
+            self.kernel_radius,
+            n_kernel_points,
+            num_kernels=1,
+            dimension=dimension,
+            fixed=fixed,
+        )
+        self.K_points = Parameter(
+            torch.from_numpy(K_points_numpy.reshape((n_kernel_points, dimension))).to(torch.float),
+            requires_grad=False,
+        )
+
+        # Create independent weights for the offset convolution, plus a bias term since no batch normalization is applied
+        if modulated:
+            offset_dim = (dimension + 1) * self.n_kernel_points
+        else:
+            offset_dim = dimension * self.n_kernel_points
+        offset_weights = torch.empty([n_kernel_points, self.num_inputs, offset_dim], dtype=torch.float)
+        torch.nn.init.xavier_normal_(offset_weights)
+        self.offset_weights = Parameter(offset_weights)
+        self.offset_bias = Parameter(torch.zeros(offset_dim, dtype=torch.float))
+
+        # Main deformable weights
+        weights = torch.empty([n_kernel_points, self.num_inputs, num_outputs], dtype=torch.float)
+        torch.nn.init.xavier_normal_(weights)
+        self.weight = Parameter(weights)
+
+    def forward(self, query_points, support_points, neighbors, x):
+        """
+        - query_points (torch Tensor): query points of size N x 3
+        - support_points (torch Tensor): support points of size N0 x 3
+        - neighbors (torch Tensor): neighbor indices of size N x M
+        - x (torch Tensor): features of size N0 x d (d is the number of inputs)
+        """
+
+        x = add_ones(support_points, x, self.add_one)
+
+        offset_feat = (
+            KPConv_ops(
+                query_points,
+                support_points,
+                neighbors,
+                x,
+                self.K_points,
+                self.offset_weights,
+                self.point_influence,
+                self.KP_influence,
+                self.aggregation_mode,
+            )
+            + self.offset_bias
+        )
+        points_dim = query_points.shape[-1]
+        if self.modulated:
+            # Get offset (in normalized scale) from features
+            offsets = offset_feat[:, : points_dim * self.n_kernel_points]
+            offsets = offsets.reshape((-1, self.n_kernel_points, points_dim))
+
+            # Get modulations
+            modulations = 2 * torch.sigmoid(offset_feat[:, points_dim * self.n_kernel_points :])
+        else:
+            # Get offset (in normalized scale) from features
+            offsets = offset_feat.reshape((-1, self.n_kernel_points, points_dim))
+            # No modulations
+            modulations = None
+        offsets *= self.point_influence
+
+        # Apply deformable kernel
+        new_feat, sq_distances, K_points_deformed = KPConv_deform_ops(
+            query_points,
+            support_points,
+            neighbors,
+            x,
+            self.K_points,
+            offsets,
+            modulations,
+            self.weight,
+            self.point_influence,
+            self.KP_influence,
+            self.aggregation_mode,
+        )
+
+        if self.loss_mode == "fitting":
+            self.internal_losses[self.FITTING_LOSS_KEY] = fitting_loss(sq_distances, self.kernel_radius)
+            self.internal_losses[self.REPULSION_LOSS_KEY] = repulsion_loss(K_points_deformed, self.point_influence)
+        elif self.loss_mode == "permissive":
+            self.internal_losses[self.PERMISSIVE_LOSS_KEY] = permissive_loss(K_points_deformed, self.kernel_radius)
+        else:
+            raise NotImplementedError(
+                "Loss mode %s not recognised. Only permissive and fitting are valid" % self.loss_mode
+            )
+        return new_feat
+
+    def get_internal_losses(self):
+        return self.internal_losses
+
+    def __repr__(self):
+        return "KPConvDeformableLayer(InF: %i, OutF: %i, kernel_pts: %i, radius: %.2f, KP_influence: %s)" % (
+            self.num_inputs,
+            self.num_outputs,
+            self.n_kernel_points,
+            self.kernel_radius,
+            self.KP_influence,
+        )
diff --git a/torch_points3d/modules/KPConv/kernels/dispositions/k_015_center.ply b/torch_points3d/modules/KPConv/kernels/dispositions/k_015_center.ply
new file mode 100644
index 0000000..c85135c
Binary files /dev/null and b/torch_points3d/modules/KPConv/kernels/dispositions/k_015_center.ply differ
diff --git a/torch_points3d/modules/KPConv/losses.py b/torch_points3d/modules/KPConv/losses.py
new file mode 100644
index 0000000..528b87f
--- /dev/null
+++ b/torch_points3d/modules/KPConv/losses.py
@@ -0,0 +1,42 @@
+import torch
+
+
+def fitting_loss(sq_distance, radius):
+    """KPConv fitting loss. For each query point it ensures that at least one neighbor is
+    close to each kernel point
+
+    Arguments:
+        sq_distance - For each query point, from all neighbors to all KP points [N_query, N_neighbors, N_KPoints]
+        radius - Radius of the convolution
+    """
+    kpmin = sq_distance.min(dim=1)[0]
+    normalised_kpmin = kpmin / (radius ** 2)
+    return torch.mean(normalised_kpmin)
+
+
+def repulsion_loss(deformed_kpoints, radius):
+    """Ensures that the deformed points within the kernel remain equidistant
+
+    Arguments:
+        deformed_kpoints - deformed points for each query point
+        radius - Radius of the kernel
+    """
+    deformed_kpoints = deformed_kpoints / float(radius)  # normalize into the unit kernel sphere
+    n_points = deformed_kpoints.shape[1]
+    repulsive_loss = 0
+    for i in range(n_points):
+        with torch.no_grad():
+            other_points = torch.cat([deformed_kpoints[:, :i, :], deformed_kpoints[:, i + 1 :, :]], dim=1)
+        distances = torch.sqrt(torch.sum((other_points - deformed_kpoints[:, i : i + 1, :]) ** 2, dim=-1))
+        repulsion_force = torch.sum(torch.pow(torch.relu(1.5 - distances), 2), dim=1)
+        repulsive_loss += torch.mean(repulsion_force)
+    return repulsive_loss
+
+
+def permissive_loss(deformed_kpoints, radius):
+    """Penalizes deformed_kpoints that move outside
+    the radius defined for the convolution
+    """
+    norm_deformed_normalized = torch.norm(deformed_kpoints, p=2, dim=-1) / float(radius)
+    permissive_loss = torch.mean(norm_deformed_normalized[norm_deformed_normalized > 1.0])
+    return permissive_loss
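A toy numeric check of fitting_loss (not part of the diff): with two kernel points and two neighbors, the per-kernel-point closest squared distances are 0.04 and 1.0, so the loss is their mean:

import torch

sq_distance = torch.tensor([[[0.04, 1.0], [0.25, 1.0]]])  # [1 query, 2 neighbors, 2 KP]
loss = fitting_loss(sq_distance, radius=1.0)
print(loss)  # tensor(0.5200) = mean(0.04, 1.0)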
diff --git a/torch_points3d/modules/KPConv/plyutils.py b/torch_points3d/modules/KPConv/plyutils.py
new file mode 100644
index 0000000..5fe647a
--- /dev/null
+++ b/torch_points3d/modules/KPConv/plyutils.py
@@ -0,0 +1,342 @@
+#
+#      0===============================0
+#      |    PLY files reader/writer    |
+#      0===============================0
+#
+#
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#      function to read/write .ply files
+#
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#      Hugues THOMAS - 10/02/2017
+#
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#          Imports and global variables
+#      \**********************************/
+#
+
+
+# Basic libs
+import numpy as np
+import sys
+import logging
+
+log = logging.getLogger(__name__)
+
+
+# Define PLY types
+ply_dtypes = dict(
+    [
+        (b"int8", "i1"),
+        (b"char", "i1"),
+        (b"uint8", "u1"),
+        (b"uchar", "u1"),
+        (b"int16", "i2"),
+        (b"short", "i2"),
+        (b"uint16", "u2"),
+        (b"ushort", "u2"),
+        (b"int32", "i4"),
+        (b"int", "i4"),
+        (b"uint32", "u4"),
+        (b"uint", "u4"),
+        (b"float32", "f4"),
+        (b"float", "f4"),
+        (b"float64", "f8"),
+        (b"double", "f8"),
+    ]
+)
+
+# Numpy reader format
+valid_formats = {"ascii": "", "binary_big_endian": ">", "binary_little_endian": "<"}
+
+
+# ----------------------------------------------------------------------------------------------------------------------
+#
+#           Functions
+#       \***************/
+#
+
+
+def parse_header(plyfile, ext):
+    # Variables
+    line = []
+    properties = []
+    num_points = None
+
+    while b"end_header" not in line and line != b"":
+        line = plyfile.readline()
+
+        if b"element" in line:
+            line = line.split()
+            num_points = int(line[2])
+
+        elif b"property" in line:
+            line = line.split()
+            properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
+
+    return num_points, properties
+
+
+def parse_mesh_header(plyfile, ext):
+    # Variables
+    line = []
+    vertex_properties = []
+    num_points = None
+    num_faces = None
+    current_element = None
+
+    while b"end_header" not in line and line != b"":
+        line = plyfile.readline()
+
+        # Find point element
+        if b"element vertex" in line:
+            current_element = "vertex"
+            line = line.split()
+            num_points = int(line[2])
+
+        elif b"element face" in line:
+            current_element = "face"
+            line = line.split()
+            num_faces = int(line[2])
+
+        elif b"property" in line:
+            if current_element == "vertex":
+                line = line.split()
+                vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]]))
+            elif current_element == "face":
+                if not line.startswith(b"property list uchar int"):
+                    raise ValueError("Unsupported faces property : " + str(line))
+
+    return num_points, num_faces, vertex_properties
+
+
+def read_ply(filename, triangular_mesh=False):
+    """
+    Read ".ply" files
+
+    Parameters
+    ----------
+    filename : string
+        the name of the file to read.
+
+    Returns
+    -------
+    result : array
+        data stored in the file
+
+    Examples
+    --------
+    Store data in file
+
+    >>> points = np.random.rand(5, 3)
+    >>> values = np.random.randint(2, size=5)
+    >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values'])
+
+    Read the file
+
+    >>> data = read_ply('example.ply')
+    >>> values = data['values']
+    array([0, 0, 1, 1, 0])
+
+    >>> points = np.vstack((data['x'], data['y'], data['z'])).T
+    array([[ 0.466  0.595  0.324]
+           [ 0.538  0.407  0.654]
+           [ 0.850  0.018  0.988]
+           [ 0.395  0.394  0.363]
+           [ 0.873  0.996  0.092]])
+    """
+
+    with open(filename, "rb") as plyfile:
+
+        # Check if the file starts with ply
+        if b"ply" not in plyfile.readline():
+            raise ValueError("The file does not start with the word ply")
+
+        # get binary_little/big or ascii
+        fmt = plyfile.readline().split()[1].decode()
+        if fmt == "ascii":
+            raise ValueError("The file is not binary")
+
+        # get extension for building the numpy dtypes
+        ext = valid_formats[fmt]
+
+        # PointCloud reader vs mesh reader
+        if triangular_mesh:
+
+            # Parse header
+            num_points, num_faces, properties = parse_mesh_header(plyfile, ext)
+
+            # Get point data
+            vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points)
+
+            # Get face data
+            face_properties = [
+                ("k", ext + "u1"),
+                ("v1", ext + "i4"),
+                ("v2", ext + "i4"),
+                ("v3", ext + "i4"),
+            ]
+            faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces)
+
+            # Return vertex data and concatenated faces
+            faces = np.vstack((faces_data["v1"], faces_data["v2"], faces_data["v3"])).T
+            data = [vertex_data, faces]
+
+        else:
+
+            # Parse header
+            num_points, properties = parse_header(plyfile, ext)
+
+            # Get data
+            data = np.fromfile(plyfile, dtype=properties, count=num_points)
+
+    return data
+
+
+def header_properties(field_list, field_names):
+
+    # List of lines to write
+    lines = []
+
+    # First line describing element vertex
+    lines.append("element vertex %d" % field_list[0].shape[0])
+
+    # Properties lines
+    i = 0
+    for fields in field_list:
+        for field in fields.T:
+            lines.append("property %s %s" % (field.dtype.name, field_names[i]))
+            i += 1
+
+    return lines
+
+
+def write_ply(filename, field_list, field_names, triangular_faces=None):
+    """
+    Write ".ply" files
+
+    Parameters
+    ----------
+    filename : string
+        the name of the file to which the data is saved. A '.ply' extension will be appended to the
+        file name if it does not already have one.
+    field_list : list, tuple, numpy array
+        the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a
+        tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered
+        as one field.
+    field_names : list
+        the name of each field as a list of strings. Has to be the same length as the number of
+        fields.
+
+    Examples
+    --------
+    >>> points = np.random.rand(10, 3)
+    >>> write_ply('example1.ply', points, ['x', 'y', 'z'])
+
+    >>> values = np.random.randint(2, size=10)
+    >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values'])
+
+    >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8)
+    >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values']
+    >>> write_ply('example3.ply', [points, colors, values], field_names)
+    """
+
+    # Format list input to the right form
+    field_list = list(field_list) if isinstance(field_list, (list, tuple)) else [field_list]
+    for i, field in enumerate(field_list):
+        if field.ndim < 2:
+            field_list[i] = field.reshape(-1, 1)
+        if field.ndim > 2:
+            log.info("fields have more than 2 dimensions")
+            return False
+
+    # Check all fields have the same number of data
+    n_points = [field.shape[0] for field in field_list]
+    if not np.all(np.equal(n_points, n_points[0])):
+        log.info("wrong field dimensions")
+        return False
+
+    # Check if field_names and field_list have the same number of columns
+    n_fields = np.sum([field.shape[1] for field in field_list])
+    if n_fields != len(field_names):
+        log.info("wrong number of field names")
+        return False
+
+    # Add extension if not there
+    if not filename.endswith(".ply"):
+        filename += ".ply"
+
+    # Open in text mode to write the header
+    with open(filename, "w") as plyfile:
+
+        # First magical word
+        header = ["ply"]
+
+        # Encoding format
+        header.append("format binary_" + sys.byteorder + "_endian 1.0")
+
+        # Points properties description
+        header.extend(header_properties(field_list, field_names))
+
+        # Add faces if needed
+        if triangular_faces is not None:
+            header.append("element face {:d}".format(triangular_faces.shape[0]))
+            header.append("property list uchar int vertex_indices")
+
+        # End of header
+        header.append("end_header")
+
+        # Write all lines
+        for line in header:
+            plyfile.write("%s\n" % line)
+
+    # Open in binary/append mode to use tofile
+    with open(filename, "ab") as plyfile:
+
+        # Create a structured array
+        i = 0
+        type_list = []
+        for fields in field_list:
+            for field in fields.T:
+                type_list += [(field_names[i], field.dtype.str)]
+                i += 1
+        data = np.empty(field_list[0].shape[0], dtype=type_list)
+        i = 0
+        for fields in field_list:
+            for field in fields.T:
+                data[field_names[i]] = field
+                i += 1
+
+        data.tofile(plyfile)
+
+        if triangular_faces is not None:
+            triangular_faces = triangular_faces.astype(np.int32)
+            type_list = [("k", "uint8")] + [(str(ind), "int32") for ind in range(3)]
+            data = np.empty(triangular_faces.shape[0], dtype=type_list)
+            data["k"] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8)
+            data["0"] = triangular_faces[:, 0]
+            data["1"] = triangular_faces[:, 1]
+            data["2"] = triangular_faces[:, 2]
+            data.tofile(plyfile)
+
+    return True
+
+
+def describe_element(name, df):
+    """Takes the columns of the dataframe and builds a ply-like description
+
+    Parameters
+    ----------
+    name: str
+    df: pandas DataFrame
+
+    Returns
+    -------
+    element: list[str]
+    """
+    property_formats = {"f": "float", "u": "uchar", "i": "int"}
+    element = ["element " + name + " " + str(len(df))]
+
+    if name == "face":
+        element.append("property list uchar int points_indices")
+
+    else:
+        for i in range(len(df.columns)):
+            # get first letter of dtype to infer format
+            f = property_formats[str(df.dtypes[i])[0]]
+            element.append("property " + f + " " + df.columns.values[i])
+
+    return element
diff --git a/torch_points3d/applications/modules/SparseConv3d/__init__.py b/torch_points3d/modules/SparseConv3d/__init__.py
similarity index 100%
rename from torch_points3d/applications/modules/SparseConv3d/__init__.py
rename to torch_points3d/modules/SparseConv3d/__init__.py
diff --git a/torch_points3d/applications/modules/SparseConv3d/modules.py b/torch_points3d/modules/SparseConv3d/modules.py
similarity index 55%
rename from torch_points3d/applications/modules/SparseConv3d/modules.py
rename to torch_points3d/modules/SparseConv3d/modules.py
index 3555c55..ba57af1 100644
--- a/torch_points3d/applications/modules/SparseConv3d/modules.py
+++ b/torch_points3d/modules/SparseConv3d/modules.py
@@ -1,9 +1,12 @@
+from typing import Any, List, Optional
 import torch
 import sys
 
 from torch_points3d.core.common_modules import Seq, Identity
-import torch_points3d.applications.modules.SparseConv3d.nn as snn
+import torch_points3d.modules.SparseConv3d.nn as snn
+from omegaconf import DictConfig
+from omegaconf import OmegaConf
 
 
 class ResBlock(torch.nn.Module):
@@ -22,26 +25,35 @@ class ResBlock(torch.nn.Module):
         Dimension of the spatial grid
     """
 
-    def __init__(self, input_nc, output_nc, convolution):
+    def __init__(
+        self,
+        input_nc: int,
+        output_nc: int,
+        convolution: Any,
+        bn: Any,
+        bn_args: DictConfig = OmegaConf.create(),
+        activation: torch.nn.Module = torch.nn.ReLU(),
+    ):
         super().__init__()
+        self.activation = snn.create_activation_function(activation)
         self.block = (
             Seq()
             .append(convolution(input_nc, output_nc, kernel_size=3, stride=1))
-            .append(snn.BatchNorm(output_nc))
-            .append(snn.ReLU())
+            .append(bn(output_nc, **bn_args))
+            .append(self.activation)
            .append(convolution(output_nc, output_nc, kernel_size=3, stride=1))
-            .append(snn.BatchNorm(output_nc))
-            .append(snn.ReLU())
+            .append(bn(output_nc, **bn_args))
+            .append(self.activation)
         )
 
         if input_nc != output_nc:
             self.downsample = (
-                Seq().append(snn.Conv3d(input_nc, output_nc, kernel_size=1, stride=1)).append(snn.BatchNorm(output_nc))
+                Seq().append(convolution(input_nc, output_nc, kernel_size=1, stride=1)).append(bn(output_nc, **bn_args))
             )
         else:
             self.downsample = None
 
-    def forward(self, x):
+    def forward(self, x: snn.SparseTensor):
         out = self.block(x)
         if self.downsample:
             out += self.downsample(x)
@@ -55,14 +67,23 @@ class BottleneckBlock(torch.nn.Module):
     Bottleneck block with residual
     """
 
-    def __init__(self, input_nc, output_nc, convolution, reduction=4):
+    def __init__(
+        self,
+        input_nc: int,
+        output_nc: int,
+        convolution: Any,
+        bn: Any,
+        bn_args: DictConfig = OmegaConf.create(),
+        reduction: int = 4,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+    ):
         super().__init__()
-
+        self.activation = snn.create_activation_function(activation)
         self.block = (
             Seq()
-            .append(snn.Conv3d(input_nc, output_nc // reduction, kernel_size=1, stride=1))
-            .append(snn.BatchNorm(output_nc // reduction))
-            .append(snn.ReLU())
+            .append(convolution(input_nc, output_nc // reduction, kernel_size=1, stride=1))
+            .append(bn(output_nc // reduction, **bn_args))
+            .append(self.activation)
             .append(
                 convolution(
                     output_nc // reduction,
@@ -71,27 +92,27 @@ def __init__(self, input_nc, output_nc, convolution, reduction=4):
                     stride=1,
                 )
             )
-            .append(snn.BatchNorm(output_nc // reduction))
-            .append(snn.ReLU())
+            .append(bn(output_nc // reduction, **bn_args))
+            .append(self.activation)
             .append(
-                snn.Conv3d(
+                convolution(
                     output_nc // reduction,
                     output_nc,
                     kernel_size=1,
                 )
             )
-            .append(snn.BatchNorm(output_nc))
-            .append(snn.ReLU())
+            .append(bn(output_nc, **bn_args))
+            .append(self.activation)
         )
 
         if input_nc != output_nc:
             self.downsample = (
-                Seq().append(convolution(input_nc, output_nc, kernel_size=1, stride=1)).append(snn.BatchNorm(output_nc))
+                Seq().append(convolution(input_nc, output_nc, kernel_size=1, stride=1)).append(bn(output_nc, **bn_args))
             )
         else:
             self.downsample = None
 
-    def forward(self, x):
+    def forward(self, x: snn.SparseTensor):
         out = self.block(x)
         if self.downsample:
             out += self.downsample(x)
@@ -113,17 +134,25 @@ class ResNetDown(torch.nn.Module):
     """
 
     CONVOLUTION = "Conv3d"
+    BATCHNORM = "BatchNorm"
 
     def __init__(
         self,
-        down_conv_nn=[],
-        kernel_size=2,
-        dilation=1,
-        stride=2,
-        N=1,
-        block="ResBlock",
+        down_conv_nn: List[int] = [],
+        kernel_size: int = 2,
+        dilation: int = 1,
+        stride: int = 2,
+        N: int = 1,
+        block: str = "ResBlock",
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        bn_args: Optional[DictConfig] = None,
         **kwargs,
     ):
+        assert len(down_conv_nn) == 2
+        if bn_args is None:
+            bn_args = OmegaConf.create()
+        else:
+            bn_args = OmegaConf.to_container(bn_args)
         block = getattr(_res_blocks, block)
         super().__init__()
         if stride > 1:
@@ -132,6 +161,7 @@ def __init__(
             conv1_output = down_conv_nn[1]
 
         conv = getattr(snn, self.CONVOLUTION)
+        bn = getattr(snn, self.BATCHNORM)
         self.conv_in = (
             Seq()
             .append(
@@ -143,19 +173,21 @@ def __init__(
                     dilation=dilation,
                 )
            )
-            .append(snn.BatchNorm(conv1_output))
-            .append(snn.ReLU())
+            .append(bn(conv1_output, **bn_args))
+            .append(snn.create_activation_function(activation))
         )
 
         if N > 0:
             self.blocks = Seq()
             for _ in range(N):
-                self.blocks.append(block(conv1_output, down_conv_nn[1], conv))
+                self.blocks.append(
+                    block(conv1_output, down_conv_nn[1], conv, bn=bn, bn_args=bn_args, activation=activation)
+                )
                 conv1_output = down_conv_nn[1]
         else:
             self.blocks = None
 
-    def forward(self, x):
+    def forward(self, x: snn.SparseTensor):
         out = self.conv_in(x)
         if self.blocks:
             out = self.blocks(out)
@@ -169,17 +201,31 @@ class ResNetUp(ResNetDown):
 
     CONVOLUTION = "Conv3dTranspose"
 
-    def __init__(self, up_conv_nn=[], kernel_size=2, dilation=1, stride=2, N=1, **kwargs):
+    def __init__(
+        self,
+        up_conv_nn: List = [],
+        kernel_size: int = 2,
+        dilation: int = 1,
+        stride: int = 2,
+        N: int = 1,
+        block: str = "ResBlock",
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        bn_args: Optional[DictConfig] = None,
+        **kwargs,
+    ):
         super().__init__(
             down_conv_nn=up_conv_nn,
             kernel_size=kernel_size,
             dilation=dilation,
             stride=stride,
             N=N,
-            **kwargs,
+            activation=activation,
+            bn_args=bn_args,
+            block=block,
+            **kwargs,
         )
 
-    def forward(self, x, skip):
+    def forward(self, x: snn.SparseTensor, skip: Optional[snn.SparseTensor]):
         if skip is not None:
             inp = snn.cat(x, skip)
         else:
diff --git a/torch_points3d/modules/SparseConv3d/nn/.#__init__.py b/torch_points3d/modules/SparseConv3d/nn/.#__init__.py
new file mode 120000
index 0000000..08941df
--- /dev/null
+++ b/torch_points3d/modules/SparseConv3d/nn/.#__init__.py
@@ -0,0 +1 @@
+admincaor@admincaor.4088:1624465837
\ No newline at end of file
diff --git a/torch_points3d/applications/modules/SparseConv3d/nn/__init__.py b/torch_points3d/modules/SparseConv3d/nn/__init__.py
similarity index 96%
rename from torch_points3d/applications/modules/SparseConv3d/nn/__init__.py
rename to torch_points3d/modules/SparseConv3d/nn/__init__.py
index 79081d7..368b4d0 100644
--- a/torch_points3d/applications/modules/SparseConv3d/nn/__init__.py
+++ b/torch_points3d/modules/SparseConv3d/nn/__init__.py
@@ -18,7 +18,7 @@
     pass
 
 
-__all__ = ["cat", "Conv3d", "Conv3dTranspose", "ReLU", "SparseTensor", "BatchNorm"]
+__all__ = ["cat", "Conv3d", "Conv3dTranspose", "ReLU", "SparseTensor", "BatchNorm", "create_activation_function"]
 
 for val in __all__:
     exec(val + "=None")
diff --git a/torch_points3d/applications/modules/SparseConv3d/nn/minkowski.py b/torch_points3d/modules/SparseConv3d/nn/minkowski.py
similarity index 56%
rename from torch_points3d/applications/modules/SparseConv3d/nn/minkowski.py
rename to torch_points3d/modules/SparseConv3d/nn/minkowski.py
index 32251ad..682cac3 100644
--- a/torch_points3d/applications/modules/SparseConv3d/nn/minkowski.py
+++ b/torch_points3d/modules/SparseConv3d/nn/minkowski.py
@@ -63,3 +63,39 @@ def SparseTensor(feats, coordinates, batch, device=torch.device("cpu")):
         batch = batch.unsqueeze(-1)
     coords = torch.cat([batch.int(), coordinates.int()], -1)
     return ME.SparseTensor(features=feats, coordinates=coords, device=device)
+
+
+class MinkowskiNonlinearityBase(MinkowskiModuleBase):
+    """
+    taken from https://github.com/NVIDIA/MinkowskiEngine/blob/master/MinkowskiEngine/MinkowskiNonlinearity.py
+    """
+
+    def __init__(self, module):
+        super(MinkowskiNonlinearityBase, self).__init__()
+        self.module = module
+
+    def forward(self, input):
+        output = self.module(input.F)
+        if isinstance(input, TensorField):
+            return TensorField(
+                output,
+                coordinate_field_map_key=input.coordinate_field_map_key,
+                coordinate_manager=input.coordinate_manager,
+                quantization_mode=input.quantization_mode,
+            )
+        else:
+            return ME.SparseTensor(
+                output,
+                coordinate_map_key=input.coordinate_map_key,
+                coordinate_manager=input.coordinate_manager,
+            )
+
+    def __repr__(self):
+        return self.module.__class__.__name__ + "()"
+
+
+def create_activation_function(activation: torch.nn.Module = torch.nn.ReLU()):
+    """
+    create an ME activation function from a torch.nn activation function
+    """
+    return MinkowskiNonlinearityBase(module=activation)
diff --git a/torch_points3d/applications/modules/SparseConv3d/nn/torchsparse.py b/torch_points3d/modules/SparseConv3d/nn/torchsparse.py
similarity index 78%
rename from torch_points3d/applications/modules/SparseConv3d/nn/torchsparse.py
rename to torch_points3d/modules/SparseConv3d/nn/torchsparse.py
index 155829c..e295b32 100644
--- a/torch_points3d/applications/modules/SparseConv3d/nn/torchsparse.py
+++ b/torch_points3d/modules/SparseConv3d/nn/torchsparse.py
@@ -71,3 +71,19 @@ def SparseTensor(feats, coordinates, batch, device=torch.device("cpu")):
         batch = batch.unsqueeze(-1)
     coords = torch.cat([coordinates.int(), batch.int()], -1)
     return TS.SparseTensor(feats, coords).to(device)
+
+
+class TorchSparseNonLinearityBase(torch.nn.Module):
+    def __init__(self, module):
+        super(TorchSparseNonLinearityBase, self).__init__()
+        self.module = module
+
+    def forward(self, input: TS.SparseTensor):
+        return TS.nn.utils.fapply(input, self.module)
+
+
+def create_activation_function(activation: torch.nn.Module = torch.nn.ReLU()):
+    """
+    create a TorchSparse activation function from a torch.nn activation function
+    """
+    return TorchSparseNonLinearityBase(module=activation)
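Illustrative usage of the new wrapper (not part of the diff), assuming one sparse backend (e.g. torchsparse) is installed and picked up by the nn package above:

import torch
import torch_points3d.modules.SparseConv3d.nn as snn

act = snn.create_activation_function(torch.nn.LeakyReLU(0.1))
# act(sparse_tensor) applies LeakyReLU to the tensor's features while
# preserving its coordinates and metadata, whichever backend is active.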
diff --git a/torch_points3d/modules/base_modules.py b/torch_points3d/modules/base_modules.py
new file mode 100644
index 0000000..709b02e
--- /dev/null
+++ b/torch_points3d/modules/base_modules.py
@@ -0,0 +1,11 @@
+from typing import Dict, Any
+import torch
+from abc import abstractmethod
+
+
+class BaseInternalLossModule(torch.nn.Module):
+    """ABC for modules which have internal loss(es)"""
+
+    @abstractmethod
+    def get_internal_losses(self) -> Dict[str, Any]:
+        pass
diff --git a/torch_points3d/modules/pointnet2/__init__.py b/torch_points3d/modules/pointnet2/__init__.py
new file mode 100644
index 0000000..44b7ac4
--- /dev/null
+++ b/torch_points3d/modules/pointnet2/__init__.py
@@ -0,0 +1,2 @@
+from .dense import *
+from .message_passing import *
diff --git a/torch_points3d/modules/pointnet2/dense.py b/torch_points3d/modules/pointnet2/dense.py
new file mode 100644
index 0000000..5d04d9a
--- /dev/null
+++ b/torch_points3d/modules/pointnet2/dense.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_points_kernels as tp
+
+from torch_points3d.core.base_conv.dense import *
+from torch_points3d.core.spatial_ops import DenseRadiusNeighbourFinder, DenseFPSSampler
+
+# from torch_points3d.utils.model_building_utils.activation_resolver import get_activation
+
+
+class PointNetMSGDown(BaseDenseConvolutionDown):
+    def __init__(
+        self,
+        npoint=None,
+        radii=None,
+        nsample=None,
+        down_conv_nn=None,
+        bn=True,
+        activation=torch.nn.LeakyReLU(negative_slope=0.01),
+        use_xyz=True,
+        normalize_xyz=False,
+        **kwargs
+    ):
+        assert len(radii) == len(nsample) == len(down_conv_nn)
+        super(PointNetMSGDown, self).__init__(
+            DenseFPSSampler(num_to_sample=npoint), DenseRadiusNeighbourFinder(radii, nsample), **kwargs
+        )
+        self.use_xyz = use_xyz
+        self.npoint = npoint
+        self.mlps = nn.ModuleList()
+        for i in range(len(radii)):
+            self.mlps.append(MLP2D(down_conv_nn[i], bn=bn, activation=activation, bias=False))
+        self.radii = radii
+        self.normalize_xyz = normalize_xyz
+
+    def _prepare_features(self, x, pos, new_pos, idx, scale_idx):
+        new_pos_trans = pos.transpose(1, 2).contiguous()
+        grouped_pos = tp.grouping_operation(new_pos_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_pos -= new_pos.transpose(1, 2).unsqueeze(-1)
+
+        if self.normalize_xyz:
+            grouped_pos /= self.radii[scale_idx]
+
+        if x is not None:
+            grouped_features = tp.grouping_operation(x, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_pos, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "Cannot have no features and not use xyz as a feature!"
+            new_features = grouped_pos
+
+        return new_features
+
+    def conv(self, x, pos, new_pos, radius_idx, scale_idx):
+        """Implements a Dense convolution, where radius_idx represents
+        the indexes of the points in x and pos to be aggregated into the new feature
+        for each point in new_pos
+
+        Arguments:
+            x -- Previous features [B, N, C]
+            pos -- Previous positions [B, N, 3]
+            new_pos -- Sampled positions [B, npoints, 3]
+            radius_idx -- Indexes to group [B, npoints, nsample]
+            scale_idx -- Scale index in multiscale convolutional layers
+        Returns:
+            new_x -- Features after passing through the MLP [B, mlp[-1], npoints]
+        """
+        assert scale_idx < len(self.mlps)
+        new_features = self._prepare_features(x, pos, new_pos, radius_idx, scale_idx)
+        new_features = self.mlps[scale_idx](new_features)  # (B, mlp[-1], npoint, nsample)
+        new_features = F.max_pool2d(new_features, kernel_size=[1, new_features.size(3)])  # (B, mlp[-1], npoint, 1)
+        new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
+        return new_features
diff --git a/torch_points3d/modules/pointnet2/message_passing.py b/torch_points3d/modules/pointnet2/message_passing.py
new file mode 100644
index 0000000..7a37e82
--- /dev/null
+++ b/torch_points3d/modules/pointnet2/message_passing.py
@@ -0,0 +1,31 @@
+from torch_geometric.nn import PointConv
+
+from torch_points3d.core.base_conv.base_conv import *
+from torch_points3d.core.base_conv.message_passing import *
+from torch_points3d.core.common_modules.base_modules import *
+from torch_points3d.core.spatial_ops import FPSSampler, RandomSampler, MultiscaleRadiusNeighbourFinder
+
+
+class SAModule(BaseMSConvolutionDown):
+    def __init__(self, ratio=None, radius=None, radius_num_point=None, down_conv_nn=None, *args, **kwargs):
+        super(SAModule, self).__init__(
+            FPSSampler(ratio=ratio),
+            MultiscaleRadiusNeighbourFinder(radius, max_num_neighbors=radius_num_point),
+            *args,
+            **kwargs
+        )
+
+        local_nn = MLP(down_conv_nn) if down_conv_nn is not None else None
+
+        self._conv = PointConv(local_nn=local_nn, global_nn=None)
+        self._radius = radius
+        self._ratio = ratio
+        self._num_points = radius_num_point
+
+    def conv(self, x, pos, edge_index, batch):
+        return self._conv(x, pos, edge_index)
+
+    def extra_repr(self):
+        return "{}(ratio {}, radius {}, radius_points {})".format(
+            self.__class__.__name__, self._ratio, self._radius, self._num_points
+        )
diff --git a/torch_points3d/utils/config.py b/torch_points3d/utils/config.py
new file mode 100644
index 0000000..d99fe10
--- /dev/null
+++ b/torch_points3d/utils/config.py
@@ -0,0 +1,40 @@
+import numpy as np
+from typing import List
+import shutil
+import matplotlib.pyplot as plt
+import os
+from os import path as osp
+import torch
+import logging
+from collections import namedtuple
+from omegaconf import OmegaConf
+from omegaconf.listconfig import ListConfig
+from omegaconf.dictconfig import DictConfig
+from .enums import ConvolutionFormat
+from torch_points3d.utils.debugging_vars import DEBUGGING_VARS
+import subprocess
+
+log = logging.getLogger(__name__)
+
+
+# Type helpers required by the KPConv blocks and the activation resolver in this
+# PR; definitions assumed to match upstream torch_points3d.utils.config.
+def is_list(entity):
+    return isinstance(entity, (list, tuple, ListConfig))
+
+
+def is_dict(entity):
+    return isinstance(entity, (dict, DictConfig))
+
+
+class ConvolutionFormatFactory:
+    @staticmethod
+    def check_is_dense_format(conv_type):
+        if (
+            conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower()
+            or conv_type.lower() == ConvolutionFormat.MESSAGE_PASSING.value.lower()
+            or conv_type.lower() == ConvolutionFormat.SPARSE.value.lower()
+        ):
+            return False
+        elif conv_type.lower() == ConvolutionFormat.DENSE.value.lower():
+            return True
+        else:
+            raise NotImplementedError("Conv type {} not supported".format(conv_type))
diff --git a/torch_points3d/utils/debugging_vars.py b/torch_points3d/utils/debugging_vars.py
new file mode 100644
index 0000000..41c582c
--- /dev/null
+++ b/torch_points3d/utils/debugging_vars.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+DEBUGGING_VARS = {"FIND_NEIGHBOUR_DIST": False}
+
+
+def extract_histogram(spatial_ops, normalize=True):
+    out = []
+    for idx, nf in enumerate(spatial_ops["neighbour_finder"]):
+        dist_meters = nf.dist_meters
+        temp = {}
+        for dist_meter in dist_meters:
+            hist = dist_meter.histogram.copy()
+            if normalize:
+                hist /= hist.sum()
+            temp[str(dist_meter.radius)] = hist.tolist()
+            dist_meter.reset()
+        out.append(temp)
+    return out
+
+
+class DistributionNeighbour(object):
+    def __init__(self, radius, bins=1000):
+        self._radius = radius
+        self._bins = bins
+        self._histogram = np.zeros(self._bins)
+
+    def reset(self):
+        self._histogram = np.zeros(self._bins)
+
+    @property
+    def radius(self):
+        return self._radius
+
+    @property
+    def histogram(self):
+        return self._histogram
+
+    @property
+    def histogram_non_zero(self):
+        idx = len(self._histogram) - np.cumsum(self._histogram[::-1]).nonzero()[0][0]
+        return self._histogram[:idx]
+
+    def add_valid_neighbours(self, points):
+        for num_valid in points:
+            self._histogram[num_valid] += 1
+
+    def __repr__(self):
+        return "{}(radius={}, bins={})".format(self.__class__.__name__, self._radius, self._bins)
diff --git a/torch_points3d/utils/enums.py b/torch_points3d/utils/enums.py
new file mode 100644
index 0000000..f9de27f
--- /dev/null
+++ b/torch_points3d/utils/enums.py
@@ -0,0 +1,14 @@
+import enum
+
+
+class SchedulerUpdateOn(enum.Enum):
+    ON_EPOCH = "on_epoch"
+    ON_NUM_BATCH = "on_num_batch"
+    ON_NUM_SAMPLE = "on_num_sample"
+
+
+class ConvolutionFormat(enum.Enum):
+    DENSE = "dense"
+    PARTIAL_DENSE = "partial_dense"
+    MESSAGE_PASSING = "message_passing"
+    SPARSE = "sparse"
diff --git a/torch_points3d/utils/model_building_utils/activation_resolver.py b/torch_points3d/utils/model_building_utils/activation_resolver.py
new file mode 100644
index 0000000..e89b994
--- /dev/null
+++ b/torch_points3d/utils/model_building_utils/activation_resolver.py
@@ -0,0 +1,19 @@
+import torch.nn
+
+from torch_points3d.utils.config import is_dict
+
+
+def get_activation(act_opt, create_cls=True):
+    if is_dict(act_opt):
+        act_opt = dict(act_opt)
+        act = getattr(torch.nn, act_opt["name"])
+        del act_opt["name"]
+        args = dict(act_opt)
+    else:
+        act = getattr(torch.nn, act_opt)
+        args = {}
+
+    if create_cls:
+        return act(**args)
+    else:
+        return act
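To close the loop on these utilities, a hypothetical sketch of how config values drive them (the conv_type strings come from the model configs added in this PR; the act_opt mapping mirrors what get_activation expects):

from omegaconf import OmegaConf

from torch_points3d.utils.config import ConvolutionFormatFactory
from torch_points3d.utils.model_building_utils.activation_resolver import get_activation

# Resolve an activation class from a config node.
act = get_activation(OmegaConf.create({"name": "LeakyReLU", "negative_slope": 0.2}))
# -> torch.nn.LeakyReLU(negative_slope=0.2)

# Map a conv_type string to the dataloader's collate behavior.
ConvolutionFormatFactory.check_is_dense_format("DENSE")          # True: batches shaped [B, N, ...]
ConvolutionFormatFactory.check_is_dense_format("PARTIAL_DENSE")  # False: samples are concatenated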