
ivan-chai/torch-linear-assignment


Batch linear assignment for PyTorch


Batch computation of the linear assignment problem on GPU.
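For reference, here is the problem being solved. Given an N x M cost matrix, choose min(N, M) row/column pairs so that no row or column is used twice and the summed cost is minimal. The brute-force sketch below spells out that definition in pure Python. It uses a hypothetical helper, brute_force_assignment, which is not part of this package, and returns the same row -> column format as the Example section, with -1 for unmatched rows. It enumerates every matching, so it is only suitable for tiny matrices or for checking results.

```python
from itertools import permutations
from typing import List, Optional, Sequence


def brute_force_assignment(cost: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each row, the column it is matched to (or -1).

    Hypothetical reference helper: exhaustive search over all matchings of
    size min(N, M). Exponential time, so use it only on tiny matrices.
    """
    n_rows, n_cols = len(cost), len(cost[0])
    best_total: Optional[float] = None
    best: List[int] = []
    if n_rows <= n_cols:
        # Every row gets a distinct column.
        for cols in permutations(range(n_cols), n_rows):
            total = sum(cost[i][c] for i, c in enumerate(cols))
            if best_total is None or total < best_total:
                best_total, best = total, list(cols)
    else:
        # Every column gets a distinct row; the remaining rows stay unassigned.
        for rows in permutations(range(n_rows), n_cols):
            total = sum(cost[r][j] for j, r in enumerate(rows))
            if best_total is None or total < best_total:
                best_total = total
                best = [-1] * n_rows
                for j, r in enumerate(rows):
                    best[r] = j
    return best


cost = [
    [8, 4, 7],
    [5, 2, 3],
    [9, 6, 7],
    [9, 4, 8],
]
print(brute_force_assignment(cost))  # [0, 2, -1, 1]
```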

Install

Build and install via PyPI (source distribution):

pip install torch-linear-assignment

Build and install from a clone of the Git repository (run in the repository root):

pip install .

When building with CUDA, make sure that NVCC has the same CUDA version as the one PyTorch was built with. You can select the CUDA version by putting the corresponding toolkit first in PATH:

export PATH=/usr/local/cuda-<version>/bin:"$PATH"
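One way to check the versions before building is sketched below. The helper names (parse_nvcc_release, versions_match, check_local_toolchain) are hypothetical and not part of this package. The sketch assumes that "same version" means the same major.minor release, and it relies on the standard `nvcc --version` output and on `torch.version.cuda`, which is None for CPU-only PyTorch builds.

```python
import re
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the "major.minor" release from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def versions_match(nvcc_output: str, torch_cuda: Optional[str]) -> bool:
    """Compare NVCC's release with PyTorch's CUDA version (major.minor only)."""
    nvcc_release = parse_nvcc_release(nvcc_output)
    if nvcc_release is None or torch_cuda is None:
        # torch.version.cuda is None for CPU-only PyTorch builds.
        return False
    return nvcc_release == ".".join(torch_cuda.split(".")[:2])


def check_local_toolchain() -> bool:
    """Run the nvcc found first in PATH and compare it with the installed torch."""
    import torch  # imported lazily so the helpers above work without torch

    out = subprocess.run(
        ["nvcc", "--version"], capture_output=True, text=True, check=True
    ).stdout
    print("nvcc release:", parse_nvcc_release(out))
    print("torch CUDA:  ", torch.version.cuda)
    return versions_match(out, torch.version.cuda)
```

Calling check_local_toolchain() after the export above shows whether the nvcc now first in PATH matches the installed PyTorch.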

If you need a custom C/C++ compiler, set it with the following command:

CXX=<c++-compiler> CC=<gcc-compiler> pip install .

If the build fails with an error saying that torch cannot be found, upgrade the packaging tools and retry:

pip install --upgrade pip wheel setuptools
python -m pip install .

Example

import torch
from torch_linear_assignment import batch_linear_assignment

cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3).cuda()

assignment = batch_linear_assignment(cost)
print(assignment)

The output is:

tensor([[ 0,  2, -1,  1]], device='cuda:0')
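Each entry of the output is the column assigned to the corresponding row, and -1 marks a row left unassigned because there are more rows than columns. A small sketch for computing the total cost of such an assignment is shown below. It uses a hypothetical helper, assignment_cost, which is not part of this package, and only standard PyTorch ops (`gather`, `clamp`, `masked_fill`), so it runs on CPU or GPU tensors alike.

```python
import torch


def assignment_cost(cost: torch.Tensor, assignment: torch.Tensor) -> torch.Tensor:
    """Total cost of each assignment in a batch.

    cost:       (B, N, M) cost matrices.
    assignment: (B, N) column index per row, -1 for unassigned rows.
    Returns a (B,) tensor of costs summed over the assigned rows.
    """
    assigned = assignment >= 0
    # gather needs valid indices, so temporarily point unassigned rows at column 0.
    safe = assignment.clamp(min=0)
    picked = cost.gather(2, safe.unsqueeze(-1)).squeeze(-1)  # (B, N)
    # Zero out the placeholder entries before summing.
    return picked.masked_fill(~assigned, 0).sum(dim=1)


cost = torch.tensor([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3)
assignment = torch.tensor([[0, 2, -1, 1]])
print(assignment_cost(cost, assignment))  # tensor([15]), i.e. 8 + 3 + 4
```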

To get the indices in SciPy's format (the row_ind, col_ind pair returned by scipy.optimize.linear_sum_assignment):

from torch_linear_assignment import assignment_to_indices

row_ind, col_ind = assignment_to_indices(assignment)
print(row_ind)
print(col_ind)

The output is:

tensor([[0, 1, 3]], device='cuda:0')
tensor([[0, 2, 1]], device='cuda:0')
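To cross-check results on CPU, you can solve each matrix of the batch with SciPy and convert the result to the row -> column format above. The sketch below assumes NumPy and SciPy are installed. scipy_batch_assignment is a hypothetical helper, not part of this package. When several optimal assignments exist, different solvers may return different but equally cheap ones, so in that case compare total costs rather than indices.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def scipy_batch_assignment(cost: np.ndarray) -> np.ndarray:
    """Solve each matrix of a (B, N, M) batch on CPU with SciPy.

    Returns a (B, N) int64 array in the row -> column format,
    with -1 for rows that receive no column.
    """
    batch, n_rows, _ = cost.shape
    result = np.full((batch, n_rows), -1, dtype=np.int64)
    for b in range(batch):
        row_ind, col_ind = linear_sum_assignment(cost[b])
        result[b, row_ind] = col_ind
    return result


cost = np.array([
    8, 4, 7,
    5, 2, 3,
    9, 6, 7,
    9, 4, 8
]).reshape(1, 4, 3)

row_ind, col_ind = linear_sum_assignment(cost[0])
print(row_ind.tolist(), col_ind.tolist())   # [0, 1, 3] [0, 2, 1]
print(scipy_batch_assignment(cost).tolist())  # [[0, 2, -1, 1]]
```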

Citation

The code was originally developed for the HoTPP Benchmark. If you use this code in your research project, please cite one of the following papers:

@article{karpukhin2024hotppbenchmark,
  title={HoTPP Benchmark: Are We Good at the Long Horizon Events Forecasting?},
  author={Karpukhin, Ivan and Shipilov, Foma and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2406.14341},
  year={2024},
  url={https://arxiv.org/abs/2406.14341}
}

@article{karpukhin2024detpp,
  title={DeTPP: Leveraging Object Detection for Robust Long-Horizon Event Prediction},
  author={Karpukhin, Ivan and Savchenko, Andrey},
  journal={arXiv preprint arXiv:2408.13131},
  year={2024},
  url={https://arxiv.org/abs/2408.13131}
}