diff --git a/Src/ILGPU.Algorithms/Random/RNG.cs b/Src/ILGPU.Algorithms/Random/RNG.cs index d536c779f..77e78a822 100644 --- a/Src/ILGPU.Algorithms/Random/RNG.cs +++ b/Src/ILGPU.Algorithms/Random/RNG.cs @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU Algorithms -// Copyright (c) 2021 ILGPU Project +// Copyright (c) 2021-2022 ILGPU Project // www.ilgpu.net // // File: RNG.cs @@ -158,17 +158,17 @@ public abstract void FillUniform( /// /// The maximum number of parallel groups. /// - private readonly int groupSize; + private readonly int maxNumParallelWarps; /// /// Initializes the RNG view. /// /// The random providers. - /// The maximum number of parallel groups. - internal RNGView(ArrayView providers, int numParallelGroups) + /// The maximum number of parallel warps. + internal RNGView(ArrayView providers, int numParallelWarps) { randomProviders = providers; - groupSize = numParallelGroups; + maxNumParallelWarps = numParallelWarps; } #endregion @@ -183,17 +183,17 @@ internal RNGView(ArrayView providers, int numParallelGroups) private readonly ref TRandomProvider GetRandomProvider() { // Compute the global warp index - int groupOffset = Stride3D.DenseXY.ComputeElementIndex( - Grid.Index, - Grid.Dimension) % groupSize; - int warpOffset = Group.LinearIndex; - int warpIdx = groupOffset * Warp.WarpSize + warpOffset / Warp.WarpSize; + int groupIndex = Group.LinearIndex; + int warpIndex = Warp.ComputeWarpIdx(groupIndex); + int groupStride = XMath.DivRoundUp(Group.Dimension.Size, Warp.WarpSize); + int groupOffset = Grid.LinearIndex * groupStride; + int providerIndex = groupOffset + warpIndex; // Access the underlying provider Trace.Assert( - warpIdx < randomProviders.Length, + providerIndex < randomProviders.Length, "Current warp does not have a valid RNG provider"); - return ref randomProviders[warpIdx]; + return ref randomProviders[providerIndex]; } /// @@ -403,14 +403,11 @@ public RNGView GetViewViaThreads(int numThreads) => [MethodImpl(MethodImplOptions.AggressiveInlining)] public RNGView GetView(int numWarps) { - // Ensure that the number of warps is a multiple of the warp size. - int numGroups = XMath.DivRoundUp(numWarps, Accelerator.WarpSize); - numWarps = numGroups * Accelerator.WarpSize; Trace.Assert( numWarps > 0 && numWarps <= randomProvidersPerWarp.Length, "Invalid number of warps"); var subView = randomProvidersPerWarp.View.SubView(0, numWarps); - return new RNGView(subView, numGroups); + return new RNGView(subView, numWarps); } /// diff --git a/Src/ILGPU/Grid.cs b/Src/ILGPU/Grid.cs index 5272040bd..a47c7887b 100644 --- a/Src/ILGPU/Grid.cs +++ b/Src/ILGPU/Grid.cs @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU -// Copyright (c) 2017-2021 ILGPU Project +// Copyright (c) 2017-2022 ILGPU Project // www.ilgpu.net // // File: Grid.cs @@ -101,6 +101,13 @@ public static int DimZ /// The grid dimension. public static Index3D Dimension => new Index3D(DimX, DimY, DimZ); + /// + /// Returns the linear grid index of the current group within the current + /// thread grid. + /// + public static int LinearIndex => + Stride3D.DenseXY.ComputeElementIndex(Index, Dimension); + /// /// Returns the global index. /// @@ -115,6 +122,13 @@ public static int DimZ Index, Group.Index); + /// + /// Returns the linear thread index of the current thread within the current + /// thread grid. + /// + public static int GlobalLinearIndex => + LinearIndex * Group.Dimension.Size + Group.LinearIndex; + #endregion #region Methods