Commit 81e8a5f

Merge pull request #16 from teddykoker/fix15
Fix CUDA Leak + Input Validation
teddykoker authored Apr 17, 2021
2 parents d480043 + 7fdfa94 commit 81e8a5f
Showing 3 changed files with 11 additions and 7 deletions.
README.md (1 addition, 1 deletion)

@@ -20,7 +20,7 @@ pip install torchsort
 To build the CUDA extension you will need the CUDA toolchain installed. If you
 want to build in an environment without a CUDA runtime (e.g. docker), you will
 need to export the environment variable
-`TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing"` before installing.
+`TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing;Ampere"` before installing.

 ## Usage
setup.py (1 addition, 1 deletion)

@@ -51,7 +51,7 @@ def ext_modules():

 setup(
     name="torchsort",
-    version="0.1.2",
+    version="0.1.3",
     description="Differentiable sorting and ranking in PyTorch",
     author="Teddy Koker",
     url="https://github.com/teddykoker/torchsort",
torchsort/ops.py (9 additions, 5 deletions)

@@ -35,12 +35,16 @@
 def soft_rank(values, regularization="l2", regularization_strength=1.0):
     if len(values.shape) != 2:
         raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
+    if regularization not in ["l2", "kl"]:
+        raise ValueError(f"'regularization' should be a 'l2' or 'kl'")
     return SoftRank.apply(values, regularization, regularization_strength)


 def soft_sort(values, regularization="l2", regularization_strength=1.0):
     if len(values.shape) != 2:
         raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
+    if regularization not in ["l2", "kl"]:
+        raise ValueError(f"'regularization' should be a 'l2' or 'kl'")
     return SoftSort.apply(values, regularization, regularization_strength)

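The new checks make both public entry points fail fast on bad input. A small caller-side sketch of the behavior added here (the tensor values are illustrative):

```python
import torch
import torchsort

values = torch.tensor([[3.0, 1.0, 2.0]], requires_grad=True)

# Supported regularization: returns a differentiable soft sort.
print(torchsort.soft_sort(values, regularization="l2"))

# Anything other than "l2" or "kl" is now rejected before it can
# reach SoftSort.apply with an unsupported setting.
try:
    torchsort.soft_sort(values, regularization="l1")
except ValueError as err:
    print(err)  # 'regularization' should be a 'l2' or 'kl'

# The pre-existing shape check still rejects non-2d input.
try:
    torchsort.soft_sort(torch.tensor([3.0, 1.0, 2.0]))
except ValueError as err:
    print(err)  # 'values' should be a 2d-tensor but got torch.Size([3])
```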
@@ -90,19 +94,19 @@ def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
         if ctx.regularization == "l2":
             dual_sol = isotonic_l2[s.device.type](s - w)
             ret = (s - dual_sol).gather(1, inv_permutation)
-            ctx.factor = 1.0
+            factor = torch.tensor(1.0, device=s.device)
         else:
             dual_sol = isotonic_kl[s.device.type](s, torch.log(w))
             ret = torch.exp((s - dual_sol).gather(1, inv_permutation))
-            ctx.factor = ret
+            factor = ret

-        ctx.save_for_backward(s, dual_sol, permutation, inv_permutation)
+        ctx.save_for_backward(factor, s, dual_sol, permutation, inv_permutation)
         return ret

     @staticmethod
     def backward(ctx, grad_output):
-        grad = (grad_output * ctx.factor).clone()
-        s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
+        factor, s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
+        grad = (grad_output * factor).clone()
         if ctx.regularization == "l2":
             grad -= isotonic_l2_backward[s.device.type](
                 s, dual_sol, grad.gather(1, permutation)
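This hunk is the leak fix. Previously the KL branch stored the output tensor on the autograd context as a plain attribute (`ctx.factor = ret`), which creates a reference cycle — the output holds its `grad_fn`, the `grad_fn` holds `ctx`, and `ctx` holds the output — so the CUDA buffers could stay alive well past the backward pass. Routing `factor` through `ctx.save_for_backward` (with the scalar case promoted to a tensor on the same device) lets autograd manage the lifetime of everything `backward` needs. A minimal sketch of the pattern, not torchsort's actual code:

```python
import torch

class Scale(torch.autograd.Function):
    """Toy example contrasting the two ways of stashing state."""

    @staticmethod
    def forward(ctx, x, factor):
        out = x * factor
        # Leaky pattern (what the old code did with its output):
        #     ctx.out = out
        # Safe pattern: tensors needed in backward go through
        # save_for_backward so autograd owns their lifetime.
        ctx.save_for_backward(torch.as_tensor(factor, device=x.device))
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (factor,) = ctx.saved_tensors
        # One gradient per forward input; `factor` was a plain float,
        # so it gets None.
        return grad_output * factor, None

x = torch.randn(4, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
```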
