Skip to content

Commit

Permalink
Add dry run to get RNG seed and offset
Browse files Browse the repository at this point in the history
  • Loading branch information
rdspring1 committed Dec 5, 2023
1 parent 3c54867 commit 035f766
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
20 changes: 20 additions & 0 deletions csrc/device_lower/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,26 @@ struct LowerGuard {

} // namespace

kir::Kernel* GpuLower::dryRun() {
FusionGuard fg(fusion_);
LowerGuard lower_guard(this);
// Reorder expressions for loop-nest generation respecting computeAt
// relationships
auto exprs_lowered = reorderExprsForComputeAt();
dumpExprsIfEnabled(exprs_lowered, "reorderExprsForComputeAt");

commonScalarMap().initialize(exprs_lowered);

// For RNG ops whose seed and offset are not yet set, grab the seed and offset
// from the host and assign them to the ops.
// This must be after expr sort, because we do not want the generated
// computation of offset and seed to be considered as part of fusion
// definition
assignRNGOffset(fusion_);

return kernel_.get();
}

kir::Kernel* GpuLower::run() {
FusionGuard fg(fusion_);
LowerGuard lower_guard(this);
Expand Down
3 changes: 3 additions & 0 deletions csrc/device_lower/lower2device.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class GpuLower : public NonCopyable {
//! Query if lowering is in progress
static bool hasCurrent();

//! Partially lower without executing any passes
kir::Kernel* dryRun();

//! Actually run the lowering by executing the passes in the order given by
//! passes_
kir::Kernel* run();
Expand Down
4 changes: 3 additions & 1 deletion csrc/executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2242,8 +2242,10 @@ void FusionExecutor::deserialize(
compile_params.index_type = serde::mapToNvfuserDtype(buffer->index_type());
compile_params.maxrregcount = maxrregcount_high_water_mark_;

// Get lowered fusion
// Get lowered fusion and then dry run to get RNG seed and offset
lowered_ = std::make_unique<GpuLower>(fusion, compile_params);
// TODO only dry run if there are RNG operations in fusion.
lowered_->dryRun();

// Replace integers that are tensor sizes by named scalars like "T0.size[0]"
fusion_ = lowered_->kernel()->as<Fusion>();
Expand Down

0 comments on commit 035f766

Please sign in to comment.