From 454789a7af892e18e797b8fb450caf1d11b6eb9a Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Sat, 21 Aug 2021 17:43:09 -0400
Subject: [PATCH 1/2] add support for floating point

---
 tests/test_ops.py          | 19 +++++++++++++++++--
 torchsort/isotonic_cpu.cpp |  8 ++++----
 torchsort/isotonic_cuda.cu |  8 ++++----
 3 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/tests/test_ops.py b/tests/test_ops.py
index 021a04b..a079ccf 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -15,8 +15,8 @@
 REGULARIZATION = ["l2", "kl"]
 REGULARIZATION_STRENGTH = [1e-1, 1e0, 1e1]
 
-DEVICES = (
-    [torch.device("cpu")] + ([torch.device("cuda")] if torch.cuda.is_available() else [])
+DEVICES = [torch.device("cpu")] + (
+    [torch.device("cuda")] if torch.cuda.is_available() else []
 )
 
 torch.manual_seed(0)
@@ -59,3 +59,18 @@ def test_vs_original(funcs, regularization, regularization_strength, device):
         funcs[0](x, **kwargs).cpu(),
         funcs[1](x.cpu(), **kwargs),
     )
+
+
+@pytest.mark.parametrize("function", [soft_rank, soft_sort])
+@pytest.mark.parametrize("regularization", REGULARIZATION)
+@pytest.mark.parametrize("regularization_strength", REGULARIZATION_STRENGTH)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to test fp16")
+def test_half(function, regularization, regularization_strength, device):
+    x = torch.randn(BATCH_SIZE, SEQ_LEN, dtype=torch.float16, requires_grad=True).cuda()
+    f = partial(
+        function,
+        regularization=regularization,
+        regularization_strength=regularization_strength,
+    )
+    f(x)
diff --git a/torchsort/isotonic_cpu.cpp b/torchsort/isotonic_cpu.cpp
index f2fb445..4efeddb 100644
--- a/torchsort/isotonic_cpu.cpp
+++ b/torchsort/isotonic_cpu.cpp
@@ -288,7 +288,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
   auto target = torch::zeros_like(y);
   auto c = torch::zeros_like(y);
 
-  AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
     isotonic_l2_kernel<scalar_t>(
         y.accessor<scalar_t, 2>(),
         sol.accessor<scalar_t, 2>(),
@@ -311,7 +311,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
   auto lse_w_ = torch::zeros_like(y);
   auto target = torch::zeros_like(y);
 
-  AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
     isotonic_kl_kernel<scalar_t>(
         y.accessor<scalar_t, 2>(),
         w.accessor<scalar_t, 2>(),
@@ -330,7 +330,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
   auto n = sol.size(1);
   auto ret = torch::zeros_like(sol);
 
-  AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
     isotonic_l2_backward_kernel<scalar_t>(
         s.accessor<scalar_t, 2>(),
         sol.accessor<scalar_t, 2>(),
@@ -347,7 +347,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
   auto n = sol.size(1);
   auto ret = torch::zeros_like(sol);
 
-  AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
     isotonic_kl_backward_kernel<scalar_t>(
         s.accessor<scalar_t, 2>(),
         sol.accessor<scalar_t, 2>(),
diff --git a/torchsort/isotonic_cuda.cu b/torchsort/isotonic_cuda.cu
index 2f26799..2098fad 100644
--- a/torchsort/isotonic_cuda.cu
+++ b/torchsort/isotonic_cuda.cu
@@ -316,7 +316,7 @@ torch::Tensor isotonic_l2(torch::Tensor y) {
   const int threads = 1024;
   const int blocks = (batch + threads - 1) / threads;
 
-  AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_l2", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_l2", ([&] {
     isotonic_l2_kernel<scalar_t><<<blocks, threads>>>(
         y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
         sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -342,7 +342,7 @@ torch::Tensor isotonic_kl(torch::Tensor y, torch::Tensor w) {
   const int threads = 1024;
   const int blocks = (batch + threads - 1) / threads;
 
-  AT_DISPATCH_FLOATING_TYPES(y.scalar_type(), "isotonic_kl", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(y.scalar_type(), "isotonic_kl", ([&] {
     isotonic_kl_kernel<scalar_t><<<blocks, threads>>>(
         y.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
         w.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -365,7 +365,7 @@ torch::Tensor isotonic_l2_backward(torch::Tensor s, torch::Tensor sol, torch::Te
   const int threads = 1024;
   const int blocks = (batch + threads - 1) / threads;
 
-  AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_l2_backward", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_l2_backward", ([&] {
     isotonic_l2_backward_kernel<scalar_t><<<blocks, threads>>>(
         s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
         sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
@@ -387,7 +387,7 @@ torch::Tensor isotonic_kl_backward(torch::Tensor s, torch::Tensor sol, torch::Te
   const int threads = 1024;
   const int blocks = (batch + threads - 1) / threads;
 
-  AT_DISPATCH_FLOATING_TYPES(sol.scalar_type(), "isotonic_kl_backward", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(sol.scalar_type(), "isotonic_kl_backward", ([&] {
     isotonic_kl_backward_kernel<scalar_t><<<blocks, threads>>>(
         s.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
         sol.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),

From 82a9759af7b207f8109e0fa83da2b6bb6c665cff Mon Sep 17 00:00:00 2001
From: Teddy Koker
Date: Sat, 21 Aug 2021 18:04:18 -0400
Subject: [PATCH 2/2] better test, version num

---
 setup.py          | 2 +-
 tests/test_ops.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 43aa366..7634714 100644
--- a/setup.py
+++ b/setup.py
@@ -51,7 +51,7 @@ def ext_modules():
 
 setup(
     name="torchsort",
-    version="0.1.6",
+    version="0.1.7",
     description="Differentiable sorting and ranking in PyTorch",
     author="Teddy Koker",
     url="https://github.com/teddykoker/torchsort",
diff --git a/tests/test_ops.py b/tests/test_ops.py
index a079ccf..e856f2f 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -67,10 +67,11 @@ def test_vs_original(funcs, regularization, regularization_strength, device):
 @pytest.mark.parametrize("device", DEVICES)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA to test fp16")
 def test_half(function, regularization, regularization_strength, device):
-    x = torch.randn(BATCH_SIZE, SEQ_LEN, dtype=torch.float16, requires_grad=True).cuda()
+    x = torch.randn(BATCH_SIZE, SEQ_LEN, requires_grad=True).cuda().half()
     f = partial(
         function,
         regularization=regularization,
         regularization_strength=regularization_strength,
     )
-    f(x)
+    # don't think theres a better way of testing, tolerance must be pretty high
+    assert torch.allclose(f(x), f(x.float()).half(), atol=1e-1)
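
Usage note (not part of the patches above): a minimal sketch of exercising the half-precision path these commits enable. It assumes torchsort 0.1.7 built with its CUDA extension and a GPU present; the tensor shape, regularization choice, and strength below are arbitrary, and the added test only covers the CUDA path.

    # Minimal sketch, assuming torchsort >= 0.1.7 built with the CUDA extension.
    import torch
    import torchsort

    # float16 inputs are dispatched through the *_AND_HALF kernels added above;
    # the bundled test only exercises the CUDA path, so keep the tensor on a GPU.
    x = torch.randn(8, 16, device="cuda", dtype=torch.float16, requires_grad=True)

    ranks = torchsort.soft_rank(x, regularization="l2", regularization_strength=1.0)
    ranks.sum().backward()

    # Output and gradients stay in float16; numerics are looser than float32,
    # which is why the new test compares against the float32 result at atol=1e-1.
    print(ranks.dtype, x.grad.dtype)

The same applies to torchsort.soft_sort; fp16 accuracy should be expected to be noticeably coarser than fp32, as reflected in the test's tolerance.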