Skip to content

Commit

Permalink
Add options for cuda flags, slang flags, skipping ninja check & gener…
Browse files Browse the repository at this point in the history
…ating line info
  • Loading branch information
saipraveenb25 committed Oct 29, 2024
1 parent 695bc9f commit 1d8ee70
Showing 1 changed file with 56 additions and 23 deletions.
79 changes: 56 additions & 23 deletions slangtorch/slangtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-DSLANG_CUDA_ENABLE_HALF=1",]
"-DSLANG_CUDA_ENABLE_HALF=1"]

if sys.platform == "win32":
# Windows
Expand Down Expand Up @@ -137,9 +137,9 @@ def tryGetSlangDynamicLibraryPath():
return None


def getDictionaryHash(dictionary, truncate_at=16):
# Convert dictionary to JSON string
jsonString = json.dumps(dictionary, sort_keys=True)
def getHash(obj, truncate_at=16):
# Convert obj to JSON string
jsonString = json.dumps(obj, sort_keys=True)

# Compute SHA-256 hash of the JSON string
hashObject = hashlib.sha256(jsonString.encode())
Expand Down Expand Up @@ -238,7 +238,7 @@ def getOrCreateUniqueDir(moduleKey, baseDir):
return targetDir


def compileSlang(metadata, fileName, targetMode, options, outputFile, verbose=False, includePaths=[], dryRun=False):
def compileSlang(metadata, fileName, targetMode, options, outputFile, verbose=False, includePaths=[], dryRun=False, extraSlangFlags=[]):
needsRecompile = False

# If version either doesn't exist or is different, we need to recompile.
Expand Down Expand Up @@ -310,21 +310,21 @@ def compileSlang(metadata, fileName, targetMode, options, outputFile, verbose=Fa
needsRecompile = True

if needsRecompile:
return True, (_compileSlang(metadata, fileName, targetMode, options, outputFile, includePaths, verbose) if not dryRun else None)
return True, (_compileSlang(metadata, fileName, targetMode, options, outputFile, includePaths, verbose, extraSlangFlags) if not dryRun else None)
else:
return False, (metadata if not dryRun else None)


def _compileSlang(metadata, fileName, targetMode, options, outputFile, includePaths=[], verbose=False):
def _compileSlang(metadata, fileName, targetMode, options, outputFile, includePaths=[], verbose=False, extraSlangFlags=[]):
# Create a temporary depfile path.
depFile = f"{outputFile}.d.out"

compileCommand = [slangcPath, fileName, *options,
'-target', targetMode,
'-line-directive-mode', 'none',
'-o', outputFile,
'-depfile', depFile,
'-ignore-capabilities']
compileCommand.extend(extraSlangFlags)

if includePaths is not None:
for includePath in includePaths:
Expand Down Expand Up @@ -357,7 +357,7 @@ def _compileSlang(metadata, fileName, targetMode, options, outputFile, includePa
return {"options": options, "deps": deps, "version": versionCode, "includePaths": includePaths}


def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir=None, verbose=False, dryRun=False):
def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir=None, verbose=False, dryRun=False, skipNinjaCheck=False, extraCudaFlags=[]):
needsRebuild = False
needsReload = False

Expand All @@ -384,7 +384,7 @@ def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir
else:
needsRebuild = True

if not needsRebuild:
if not needsRebuild and not skipNinjaCheck:
# One more check: we will run ninja on the build directory to see if there is anything to do.
# This check catches the case where the Slang products are up-to-date, but any downstream
# dependencies such as prelude header files, or user-defined header files have changed.
Expand Down Expand Up @@ -415,6 +415,9 @@ def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir
needsReload = False
else:
raise RuntimeError(f"Unknown ninja result: {ninja_result}")
else:
if verbose:
print(f"Skipping additional ninja check (WARNING: this may ignore changes to non-slang files)", file=sys.stderr)

cacheLookupKey = moduleName
if not needsRebuild:
Expand Down Expand Up @@ -457,7 +460,7 @@ def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir
return True, None

# Compile the module.
slangLib = _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir, verbose)
slangLib = _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir, extraCudaFlags, verbose)

newMetadata = metadata.copy()
newMetadata["moduleName"] = moduleName
Expand All @@ -475,7 +478,7 @@ def compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir
compileAndLoadModule._moduleCache = {}


def _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir, verbose=False):
def _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDir, extraCudaFlags=[], verbose=False):
# make sure to add cl.exe to PATH on windows so ninja can find it.
_add_msvc_to_env_var()

Expand All @@ -491,6 +494,9 @@ def _compileAndLoadModule(metadata, sources, moduleName, buildDir, slangSourceDi
extra_cflags = ["-std=c++17"]
extra_cuda_cflags = ["-std=c++17"]

if extraCudaFlags:
extra_cuda_cflags.extend(extraCudaFlags)

if slangSourceDir:
extra_include_paths = [slangSourceDir]
else:
Expand Down Expand Up @@ -531,7 +537,7 @@ def parseDepfile(depFile):
return allDepFiles


def _loadModule(fileName, moduleName, outputFolder, options, sourceDir=None, verbose=False, includePaths=[], dryRun=False):
def _loadModule(fileName, moduleName, outputFolder, options, sourceDir=None, verbose=False, includePaths=[], dryRun=False, skipNinjaCheck=False, extraCudaFlags=[], extraSlangFlags=[]):

# Try to find a metadata file "metadata.json" in outputFolder.
metadataFile = os.path.join(outputFolder, "metadata.json")
Expand All @@ -554,10 +560,10 @@ def _loadModule(fileName, moduleName, outputFolder, options, sourceDir=None, ver
# Compile slang files to intermediate host and kernel modules.
compileStartTime = time.perf_counter()

resultCpp, metadataCpp = compileSlang(metadata.get("cpp", None), fileName, "torch-binding", options, cppOutName, verbose, includePaths=includePaths, dryRun=dryRun)
resultCpp, metadataCpp = compileSlang(metadata.get("cpp", None), fileName, "torch-binding", options, cppOutName, verbose, includePaths=includePaths, dryRun=dryRun, extraSlangFlags=extraSlangFlags)
metadata["cpp"] = metadataCpp

resultCuda, metadataCuda = compileSlang(metadata.get("cuda", None), fileName, "cuda", options, cudaOutName, verbose, includePaths=includePaths, dryRun=dryRun)
resultCuda, metadataCuda = compileSlang(metadata.get("cuda", None), fileName, "cuda", options, cudaOutName, verbose, includePaths=includePaths, dryRun=dryRun, extraSlangFlags=extraSlangFlags)
metadata["cuda"] = metadataCuda

if dryRun and (resultCuda or resultCpp):
Expand All @@ -571,7 +577,9 @@ def _loadModule(fileName, moduleName, outputFolder, options, sourceDir=None, ver
slangLib, metadata = compileAndLoadModule(
metadata, [cppOutName, cudaOutName],
moduleName, outputFolder, slangSourceDir,
verbose, dryRun=dryRun)
verbose, dryRun=dryRun,
skipNinjaCheck=skipNinjaCheck,
extraCudaFlags=extraCudaFlags)

if dryRun:
if slangLib:
Expand All @@ -592,7 +600,7 @@ def _loadModule(fileName, moduleName, outputFolder, options, sourceDir=None, ver
return slangLib


def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths=[]):
def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths=[], skipNinjaCheck=False, slangGenLineInfo=True, cudaFastMath=True, cudaGenLineInfo=True, extraSlangFlags=[], extraCudaFlags=[]):
# Print warning
if skipSlang is not None:
print("Warning: skipSlang is deprecated in favor of a dependency-based cache.", file=sys.stderr)
Expand All @@ -601,10 +609,35 @@ def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths
print(f"Loading slang module: {fileName}", file=sys.stderr)
print(f"Using slangc location: {slangcPath}", file=sys.stderr)

if defines:
optionsHash = getDictionaryHash(defines, truncate_at=16)
else:
optionsHash = getDictionaryHash({}, truncate_at=16)
if not defines:
defines = {}

if not extraCudaFlags:
extraCudaFlags = []

if not extraSlangFlags:
extraSlangFlags = []

assert(isinstance(cudaFastMath, bool))
if cudaFastMath:
if verbose:
print("Using fast math (--use_fast_math)", file=sys.stderr)
extraCudaFlags.append("--use_fast_math")

assert(isinstance(cudaGenLineInfo, bool))
if cudaGenLineInfo:
if verbose:
print("Using line info (--generate-line-info)", file=sys.stderr)
extraCudaFlags.append("--generate-line-info")

assert(isinstance(slangGenLineInfo, bool))
if not slangGenLineInfo:
if verbose:
print("Disabling slang line info (-line-directive-mode none)", file=sys.stderr)
extraSlangFlags.append("-line-directive-mode")
extraSlangFlags.append("none")

optionsHash = getHash([defines, extraCudaFlags, extraSlangFlags], truncate_at=16)

parentFolder = os.path.dirname(fileName)
baseNameWoExt = os.path.splitext(os.path.basename(fileName))[0]
Expand All @@ -631,7 +664,7 @@ def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths
if verbose:
print(f"Dry-run using latest build directory: {buildDir}", file=sys.stderr)

needsRecompile = _loadModule(fileName, moduleName, buildDir, options, sourceDir=outputFolder, verbose=verbose, includePaths=includePaths, dryRun=True)
needsRecompile = _loadModule(fileName, moduleName, buildDir, options, sourceDir=outputFolder, verbose=verbose, includePaths=includePaths, dryRun=True, skipNinjaCheck=skipNinjaCheck, extraCudaFlags=extraCudaFlags, extraSlangFlags=extraSlangFlags)
else:
if verbose:
print(f"No latest build directory.", file=sys.stderr)
Expand All @@ -649,7 +682,7 @@ def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths
if verbose:
print(f"Working folder: {buildDir}", file=sys.stderr)

rawModule = _loadModule(fileName, moduleName, buildDir, options, sourceDir=outputFolder, verbose=verbose, includePaths=includePaths, dryRun=False)
rawModule = _loadModule(fileName, moduleName, buildDir, options, sourceDir=outputFolder, verbose=verbose, includePaths=includePaths, dryRun=False, skipNinjaCheck=skipNinjaCheck, extraCudaFlags=extraCudaFlags, extraSlangFlags=extraSlangFlags)
addLoadedDirectoryEntry(outputFolder, buildDir)

return wrapModule(rawModule)
Expand Down

0 comments on commit 1d8ee70

Please sign in to comment.