Merge pull request #31 from teddykoker/half
Add fp16 support
teddykoker authored Aug 21, 2021
2 parents 27b4ce5 + 82a9759 commit 5b45bfa
Showing 4 changed files with 27 additions and 11 deletions.
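
In user code, the change amounts to half-precision tensors being accepted by soft_sort and soft_rank. A minimal sketch, assuming a CUDA device and a build of torchsort that includes this commit (shapes are arbitrary):

import torch
import torchsort

# Half-precision input on the GPU; before this change the float-only dispatch
# macros did not cover torch.float16, so this dtype now routes to the fp16 path.
x = torch.randn(8, 128, device="cuda").half()

sorted_x = torchsort.soft_sort(x)   # differentiable sort, output stays fp16
ranks = torchsort.soft_rank(x)      # differentiable rank, output stays fp16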
setup.py: 2 changes (1 addition, 1 deletion)
@@ -51,7 +51,7 @@ def ext_modules():

 setup(
     name="torchsort",
-    version="0.1.6",
+    version="0.1.7",
     description="Differentiable sorting and ranking in PyTorch",
     author="Teddy Koker",
     url="https://github.com/teddykoker/torchsort",
tests/test_ops.py: 20 changes (18 additions, 2 deletions)
@@ -15,8 +15,8 @@
 REGULARIZATION = ["l2", "kl"]
 REGULARIZATION_STRENGTH = [1e-1, 1e0, 1e1]

-DEVICES = (
-    [torch.device("cpu")] + ([torch.device("cuda")] if torch.cuda.is_available() else [])
+DEVICES = [torch.device("cpu")] + (
+    [torch.device("cuda")] if torch.cuda.is_available() else []
 )

 torch.manual_seed(0)
@@ -59,3 +59,19 @@ def test_vs_original(funcs, regularization, regularization_strength, device):
         funcs[0](x, **kwargs).cpu(),
         funcs[1](x.cpu(), **kwargs),
     )
+
+
+@pytest.mark.parametrize("function", [soft_rank, soft_sort])
+@pytest.mark.parametrize("regularization", REGULARIZATION)
+@pytest.mark.parametrize("regularization_strength", REGULARIZATION_STRENGTH)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to test fp16")
+def test_half(function, regularization, regularization_strength, device):
+    x = torch.randn(BATCH_SIZE, SEQ_LEN, requires_grad=True).cuda().half()
+    f = partial(
+        function,
+        regularization=regularization,
+        regularization_strength=regularization_strength,
+    )
+    # don't think there's a better way of testing; tolerance must be pretty high
+    assert torch.allclose(f(x), f(x.float()).half(), atol=1e-1)
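
Outside of pytest, the check in test_half amounts to comparing the fp16 result against the fp32 path cast back to half, with a loose tolerance since half precision accumulates rounding error. A standalone sketch (CUDA assumed; shapes and tolerance mirror the test rather than any hard requirement):

import torch
import torchsort

x = torch.randn(8, 128, device="cuda").half()
out_half = torchsort.soft_sort(x)                     # fp16 kernel path
out_ref = torchsort.soft_sort(x.float()).half()       # fp32 path, cast back
print(torch.allclose(out_half, out_ref, atol=1e-1))   # expected: True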
torchsort/isotonic_cpu.cpp: 8 changes (4 additions, 4 deletions)
@@ -288,7 +288,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
     auto target = torch::zeros_like(y);
     auto c = torch::zeros_like(y);

-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
         isotonic_l2_kernel<scalar_t>(
             y.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),
@@ -311,7 +311,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
     auto lse_w_ = torch::zeros_like(y);
     auto target = torch::zeros_like(y);

-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
         isotonic_kl_kernel<scalar_t>(
             y.accessor<scalar_t, 2>(),
             w.accessor<scalar_t, 2>(),
@@ -330,7 +330,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     auto n = sol.size(1);
     auto ret = torch::zeros_like(sol);

-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
         isotonic_l2_backward_kernel<scalar_t>(
             s.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),
@@ -347,7 +347,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     auto n = sol.size(1);
     auto ret = torch::zeros_like(sol);

-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
         isotonic_kl_backward_kernel<scalar_t>(
             s.accessor<scalar_t, 2>(),
             sol.accessor<scalar_t, 2>(),
torchsort/isotonic_cuda.cu: 8 changes (4 additions, 4 deletions)
@@ -316,7 +316,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;

-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
         isotonic_l2_kernel<scalar_t><<<blocks, threads>>>(
             y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -342,7 +342,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;

-    AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
         isotonic_kl_kernel<scalar_t><<<blocks, threads>>>(
             y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             w.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -365,7 +365,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;

-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
         isotonic_l2_backward_kernel<scalar_t><<<blocks, threads>>>(
             s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -387,7 +387,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
     const int threads = 1024;
     const int blocks = (batch + threads - 1) / threads;

-    AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
         isotonic_kl_backward_kernel<scalar_t><<<blocks, threads>>>(
             s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
             sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
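
Because the *_backward dispatch macros above are widened as well, gradients flow for half-precision inputs too. A minimal sketch, assuming a CUDA device (shapes arbitrary):

import torch
import torchsort

x = torch.randn(8, 128, device="cuda", dtype=torch.float16, requires_grad=True)
y = torchsort.soft_rank(x)

# The backward pass now dispatches the half-precision isotonic kernels as well.
y.sum().backward()
print(x.grad.dtype)  # torch.float16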
