From 695bc9f6f54a6ce76553fc4b695a08bd1db007d1 Mon Sep 17 00:00:00 2001
From: Sai Praveen Bangaru
Date: Tue, 3 Sep 2024 12:43:29 -0400
Subject: [PATCH] v1.2.5 + Add missing test file

---
Notes (this area is ignored by `git am` and is not part of the commit message):

  * Bumps the `slangtorch` package version in pyproject.toml from 1.2.4 to 1.2.5.
  * Adds tests/multiply_half.slang: `mul_kernel` writes A[loc] * kFactor
    (kFactor = 2.h) into `result` for each thread's 2-D location, and the
    `multiply` entry point launches it with one block whose group size is
    (A.size(0), A.size(1), 1).
  * NOTE(review): the author e-mail is missing from the `From:` header of this
    copy; restore it from the original commit before using `git am`.
  * NOTE(review): leading whitespace inside the hunks was reconstructed
    (4-space indentation assumed) -- confirm against the repository if the
    patch fails to apply on context.

 pyproject.toml            |  2 +-
 tests/multiply_half.slang | 24 ++++++++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)
 create mode 100644 tests/multiply_half.slang

diff --git a/pyproject.toml b/pyproject.toml
index f68c30d..339e312 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "slangtorch"
-version = "1.2.4"
+version = "1.2.5"
 dependencies = [
     "torch>=1.1.0",
     "hatchling>=1.11.0",
diff --git a/tests/multiply_half.slang b/tests/multiply_half.slang
new file mode 100644
index 0000000..f65c08f
--- /dev/null
+++ b/tests/multiply_half.slang
@@ -0,0 +1,24 @@
+static const half kFactor = 2.h;
+
+half computeOutputValue(TensorView<half> A, uint2 loc)
+{
+    return A[loc] * kFactor;
+}
+
+[CudaKernel]
+void mul_kernel(TensorView<half> A, TensorView<half> result)
+{
+    uint2 location = (cudaBlockDim() * cudaBlockIdx() + cudaThreadIdx()).xy;
+    result[location] = computeOutputValue(A, location);
+}
+
+[TorchEntryPoint]
+TorchTensor<half> multiply(TorchTensor<half> A)
+{
+    var result = TorchTensor<half>.zerosLike(A);
+    let blockCount = uint3(1);
+    let groupSize = uint3(A.size(0), A.size(1), 1);
+
+    __dispatch_kernel(mul_kernel, blockCount, groupSize)(A, result);
+    return result;
+}
\ No newline at end of file