Merge pull request #13 from saipraveenb25/main
Support `half` types when compiling CUDA generated from Slang
saipraveenb25 authored Jul 10, 2024
2 parents 3e528b4 + 2375b88 commit 17b5d37
Showing 3 changed files with 41 additions and 3 deletions.
14 changes: 11 additions & 3 deletions slangtorch/slangtorch.py
@@ -15,6 +15,12 @@
 packageDir = pkg_resources.resource_filename(__name__, '')
 versionCode = my_version = pkg_resources.get_distribution('slangtorch').version
 
+DEFAULT_CUDA_CFLAGS = ["-U__CUDA_NO_HALF_OPERATORS__",
+                       "-U__CUDA_NO_HALF_CONVERSIONS__",
+                       "-U__CUDA_NO_HALF2_OPERATORS__",
+                       "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+                       "-DSLANG_CUDA_ENABLE_HALF=1",]
+
 if sys.platform == "win32":
     # Windows
     executable_extension = ".exe"
@@ -446,8 +452,8 @@ def _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDi
         # make sure to add cl.exe to PATH on windows so ninja can find it.
         _add_msvc_to_env_var()
 
-    extra_cflags = None
-    extra_cuda_cflags = None
+    extra_cflags = []
+    extra_cuda_cflags = []
     # If windows, add /std:c++17 to extra_cflags
     if sys.platform == "win32":
         extra_cflags = ["/std:c++17"]
@@ -463,11 +469,13 @@ def _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDi
     else:
         extra_include_paths = None
 
+    extra_cuda_cflags = extra_cuda_cflags + DEFAULT_CUDA_CFLAGS
+
     return jit_compile(
         moduleName,
         sources,
         extra_cflags=extra_cflags,
-        extra_cuda_cflags=extra_cuda_cflags,
+        extra_cuda_cflags=extra_cuda_cflags if extra_cuda_cflags else None,
         extra_ldflags=None,
         extra_include_paths=extra_include_paths,
         build_directory=os.path.realpath(buildDir),
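Context for the flags above: PyTorch's JIT extension builder passes -D__CUDA_NO_HALF_OPERATORS__ (and the related __CUDA_NO_* defines) to nvcc by default, which strips the __half operator overloads out of cuda_fp16.h; the -U flags cancel those defines, and -DSLANG_CUDA_ENABLE_HALF=1 presumably gates the half-precision paths in Slang's generated CUDA prelude. A minimal sketch of passing the same flags by hand through torch.utils.cpp_extension (the module name and .cu source file are hypothetical):

from torch.utils import cpp_extension

ext = cpp_extension.load(
    name="half_kernels",          # hypothetical module name
    sources=["half_kernels.cu"],  # hypothetical CUDA source file
    extra_cuda_cflags=[
        "-U__CUDA_NO_HALF_OPERATORS__",        # restore __half arithmetic operators
        "-U__CUDA_NO_HALF_CONVERSIONS__",      # restore __half <-> float conversions
        "-U__CUDA_NO_HALF2_OPERATORS__",       # restore __half2 operators
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",  # restore bfloat16 conversions
        "-DSLANG_CUDA_ENABLE_HALF=1",          # enable half support in Slang's CUDA prelude
    ],
)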
13 changes: 13 additions & 0 deletions tests/autobind-square-half.slang
@@ -0,0 +1,13 @@
+[AutoPyBindCUDA]
+[CUDAKernel]
+void square(TensorView<half> input, TensorView<half> output)
+{
+    // Get the 'global' index of this thread.
+    uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim();
+
+    // If the thread index is beyond the input size, exit early.
+    if (dispatchIdx.x >= input.size(0))
+        return;
+
+    output[dispatchIdx.x] = input[dispatchIdx.x] * input[dispatchIdx.x];
+}
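The kernel writes one output element per thread, so a launch must cover at least input.size(0) threads; the test below gets away with a single 32-thread block only because its input has 4 elements. A minimal sketch (tensor size and launch parameters chosen for illustration) of sizing the grid for an arbitrary-length tensor, following the same loadModule/launchRaw pattern as the test:

import torch
import slangtorch

module = slangtorch.loadModule('autobind-square-half.slang')

x = torch.rand(1000).cuda().half()
y = torch.zeros_like(x)

block = 32
grid = (x.numel() + block - 1) // block  # ceil-divide: enough blocks to cover every element

module.square(input=x, output=y).launchRaw(
    blockSize=(block, 1, 1), gridSize=(grid, 1, 1))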
17 changes: 17 additions & 0 deletions tests/test.py
@@ -640,3 +640,20 @@ def test_empty_tensor(self):
 
         # Should not crash.
 
+class TestHalfDType(unittest.TestCase):
+    def setUp(self) -> None:
+        test_dir = os.path.dirname(os.path.abspath(__file__))
+        slangModuleSourceFile = os.path.join(test_dir, 'autobind-square-half.slang')
+
+        module = slangtorch.loadModule(slangModuleSourceFile)
+        self.module = module
+
+    def test_half_multiply(self):
+        X = torch.tensor([1., 2., 3., 4.]).cuda().half()
+        Z = torch.zeros_like(X).cuda().half()
+
+        self.module.square(input=X, output=Z).launchRaw(blockSize=(32, 1, 1), gridSize=(1, 1, 1))
+
+        expected = torch.tensor([1., 4., 9., 16.]).cuda().half()
+
+        assert(torch.all(torch.eq(Z, expected)))
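One aside on the test's exact-equality check: it is safe here only because 1, 4, 9, and 16 are exactly representable in float16. A test over arbitrary values would want a tolerance instead; a hedged variant (not part of this change):

assert torch.allclose(Z.float(), expected.float(), atol=1e-3)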
