diff --git a/Src/ILGPU.Algorithms/Random/RNG.cs b/Src/ILGPU.Algorithms/Random/RNG.cs
index d536c779f..77e78a822 100644
--- a/Src/ILGPU.Algorithms/Random/RNG.cs
+++ b/Src/ILGPU.Algorithms/Random/RNG.cs
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU Algorithms
-// Copyright (c) 2021 ILGPU Project
+// Copyright (c) 2021-2022 ILGPU Project
// www.ilgpu.net
//
// File: RNG.cs
@@ -158,17 +158,17 @@ public abstract void FillUniform(
///
/// The maximum number of parallel groups.
///
- private readonly int groupSize;
+ private readonly int maxNumParallelWarps;
///
/// Initializes the RNG view.
///
/// The random providers.
- /// The maximum number of parallel groups.
- internal RNGView(ArrayView providers, int numParallelGroups)
+ /// The maximum number of parallel warps.
+ internal RNGView(ArrayView providers, int numParallelWarps)
{
randomProviders = providers;
- groupSize = numParallelGroups;
+ maxNumParallelWarps = numParallelWarps;
}
#endregion
@@ -183,17 +183,17 @@ internal RNGView(ArrayView providers, int numParallelGroups)
private readonly ref TRandomProvider GetRandomProvider()
{
// Compute the global warp index
- int groupOffset = Stride3D.DenseXY.ComputeElementIndex(
- Grid.Index,
- Grid.Dimension) % groupSize;
- int warpOffset = Group.LinearIndex;
- int warpIdx = groupOffset * Warp.WarpSize + warpOffset / Warp.WarpSize;
+ int groupIndex = Group.LinearIndex;
+ int warpIndex = Warp.ComputeWarpIdx(groupIndex);
+ int groupStride = XMath.DivRoundUp(Group.Dimension.Size, Warp.WarpSize);
+ int groupOffset = Grid.LinearIndex * groupStride;
+ int providerIndex = groupOffset + warpIndex;
// Access the underlying provider
Trace.Assert(
- warpIdx < randomProviders.Length,
+ providerIndex < randomProviders.Length,
"Current warp does not have a valid RNG provider");
- return ref randomProviders[warpIdx];
+ return ref randomProviders[providerIndex];
}
///
@@ -403,14 +403,11 @@ public RNGView GetViewViaThreads(int numThreads) =>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public RNGView GetView(int numWarps)
{
- // Ensure that the number of warps is a multiple of the warp size.
- int numGroups = XMath.DivRoundUp(numWarps, Accelerator.WarpSize);
- numWarps = numGroups * Accelerator.WarpSize;
Trace.Assert(
numWarps > 0 && numWarps <= randomProvidersPerWarp.Length,
"Invalid number of warps");
var subView = randomProvidersPerWarp.View.SubView(0, numWarps);
- return new RNGView(subView, numGroups);
+ return new RNGView(subView, numWarps);
}
///
diff --git a/Src/ILGPU/Grid.cs b/Src/ILGPU/Grid.cs
index 5272040bd..a47c7887b 100644
--- a/Src/ILGPU/Grid.cs
+++ b/Src/ILGPU/Grid.cs
@@ -1,6 +1,6 @@
// ---------------------------------------------------------------------------------------
// ILGPU
-// Copyright (c) 2017-2021 ILGPU Project
+// Copyright (c) 2017-2022 ILGPU Project
// www.ilgpu.net
//
// File: Grid.cs
@@ -101,6 +101,13 @@ public static int DimZ
/// The grid dimension.
public static Index3D Dimension => new Index3D(DimX, DimY, DimZ);
+ ///
+ /// Returns the linear grid index of the current group within the current
+ /// thread grid.
+ ///
+ public static int LinearIndex =>
+ Stride3D.DenseXY.ComputeElementIndex(Index, Dimension);
+
///
/// Returns the global index.
///
@@ -115,6 +122,13 @@ public static int DimZ
Index,
Group.Index);
+ ///
+ /// Returns the linear thread index of the current thread within the current
+ /// thread grid.
+ ///
+ public static int GlobalLinearIndex =>
+ LinearIndex * Group.Dimension.Size + Group.LinearIndex;
+
#endregion
#region Methods