Skip to content

Commit

Permalink
Adjust caching logic to prevent re-using build directories that have …
Browse files Browse the repository at this point in the history
…already been loaded into the current python session
  • Loading branch information
saipraveenb25 committed Jul 10, 2024
1 parent 2375b88 commit 8c8c7e5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
2 changes: 0 additions & 2 deletions examples/hard-rasterizer-example/rasterizer2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ def hook(grad):
var.grad = grad
return hook

rasterizer2d_core = slangtorch.loadModule("hard-rasterizer2d.slang")

# Run our training loop.
def optimize(i):
print("Iteration %d" % i)
Expand Down
29 changes: 29 additions & 0 deletions slangtorch/slangtorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@
if not os.path.exists(slangcPath):
raise RuntimeError(f"Could not find slangc executable at {slangcPath}")

# Mapping from module key to latest version number. Used to create unique build directories.
MODULE_VERSIONS = {}

# Mapping from module key to set of loaded build directories. Used to avoid re-using build directories
# whose binaries are in use (compilation will fail)
#
LOADED_BUILD_DIRS = {}

def getUniqueSessionVersion(moduleKey):
if moduleKey not in MODULE_VERSIONS:
MODULE_VERSIONS[moduleKey] = 0
Expand All @@ -55,6 +61,23 @@ def getUniqueSessionVersion(moduleKey):

return MODULE_VERSIONS[moduleKey]

def getCurrentSessionVersion(moduleKey):
if moduleKey not in MODULE_VERSIONS:
MODULE_VERSIONS[moduleKey] = 0

return MODULE_VERSIONS[moduleKey]

def addLoadedDirectoryEntry(moduleKey, version):
if moduleKey not in LOADED_BUILD_DIRS:
LOADED_BUILD_DIRS[moduleKey] = set()

LOADED_BUILD_DIRS[moduleKey].add(version)

def isDirectoryInUse(moduleKey, version):
if moduleKey not in LOADED_BUILD_DIRS:
return False

return version in LOADED_BUILD_DIRS[moduleKey]

def _replaceFileExt(fileName, newExt, suffix=None):
baseName, old_extension = os.path.splitext(fileName)
Expand Down Expand Up @@ -166,6 +189,10 @@ def getOrCreateUniqueDir(moduleKey, baseDir):
latestBuildID = None

targetBuildID = getUniqueSessionVersion(moduleKey)

while (isDirectoryInUse(moduleKey, makeBuildDirPath(baseDir, targetBuildID))):
# If the build directory is in use, we need to create a new build directory.
targetBuildID = getUniqueSessionVersion(moduleKey)

targetDir = None
if (latestBuildID is None) or targetBuildID == latestBuildID:
Expand Down Expand Up @@ -623,6 +650,8 @@ def loadModule(fileName, skipSlang=None, verbose=False, defines={}, includePaths
print(f"Working folder: {buildDir}", file=sys.stderr)

rawModule = _loadModule(fileName, moduleName, buildDir, options, sourceDir=outputFolder, verbose=verbose, includePaths=includePaths, dryRun=False)
addLoadedDirectoryEntry(outputFolder, buildDir)

return wrapModule(rawModule)


Expand Down

0 comments on commit 8c8c7e5

Please sign in to comment.