diff --git a/src/Orleans.Core/Messaging/CorrelationId.cs b/src/Orleans.Core/Messaging/CorrelationId.cs index 82b1f3adf3..8ed019f3dd 100644 --- a/src/Orleans.Core/Messaging/CorrelationId.cs +++ b/src/Orleans.Core/Messaging/CorrelationId.cs @@ -1,4 +1,5 @@ using System; +using System.Runtime.CompilerServices; #nullable enable namespace Orleans.Runtime @@ -28,13 +29,21 @@ namespace Orleans.Runtime public int CompareTo(CorrelationId other) => id.CompareTo(other.id); - public override string ToString() => id.ToString(); + public override string ToString() => id.ToString("X16"); - string IFormattable.ToString(string? format, IFormatProvider? formatProvider) => id.ToString(format, formatProvider); + string IFormattable.ToString(string? format, IFormatProvider? formatProvider) => id.ToString(format ?? "X16", formatProvider); bool ISpanFormattable.TryFormat(Span destination, out int charsWritten, ReadOnlySpan format, IFormatProvider? provider) - => id.TryFormat(destination, out charsWritten, format, provider); + { + if (format.IsEmpty) + { + format = "X16"; + } + return id.TryFormat(destination, out charsWritten, format, provider); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal long ToInt64() => id; } } diff --git a/src/Orleans.Core/Messaging/MessageFactory.cs b/src/Orleans.Core/Messaging/MessageFactory.cs index 249f0d4e3e..8bcd17f4c1 100644 --- a/src/Orleans.Core/Messaging/MessageFactory.cs +++ b/src/Orleans.Core/Messaging/MessageFactory.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.IO.Hashing; +using System.Numerics; using Microsoft.Extensions.Logging; using Orleans.CodeGeneration; using Orleans.Serialization; @@ -9,15 +11,25 @@ namespace Orleans.Runtime { internal class MessageFactory { - private readonly DeepCopier deepCopier; - private readonly ILogger logger; - private readonly MessagingTrace messagingTrace; + [ThreadStatic] + private static ulong _nextId; + + // The nonce reduces the chance of an id collision for a given grain to effectively zero. Id collisions are only relevant in scenarios + // where where the infinitesimally small chance of a collision is acceptable, such as call cancellation. + private readonly ulong _seed; + private readonly DeepCopier _deepCopier; + private readonly ILogger _logger; + private readonly MessagingTrace _messagingTrace; public MessageFactory(DeepCopier deepCopier, ILogger logger, MessagingTrace messagingTrace) { - this.deepCopier = deepCopier; - this.logger = logger; - this.messagingTrace = messagingTrace; + _deepCopier = deepCopier; + _logger = logger; + _messagingTrace = messagingTrace; + + // Generate a 64-bit nonce for the host, to be combined with per-message correlation ids to get a unique, per-host value. + // This avoids id collisions across different hosts for a given grain. + _seed = unchecked((ulong)Random.Shared.NextInt64()); } public Message CreateMessage(object body, InvokeMethodOptions options) @@ -25,18 +37,29 @@ public Message CreateMessage(object body, InvokeMethodOptions options) var message = new Message { Direction = (options & InvokeMethodOptions.OneWay) != 0 ? Message.Directions.OneWay : Message.Directions.Request, - Id = CorrelationId.GetNext(), + Id = GetNextCorrelationId(), IsReadOnly = (options & InvokeMethodOptions.ReadOnly) != 0, IsUnordered = (options & InvokeMethodOptions.Unordered) != 0, IsAlwaysInterleave = (options & InvokeMethodOptions.AlwaysInterleave) != 0, BodyObject = body, - RequestContextData = RequestContextExtensions.Export(this.deepCopier), + RequestContextData = RequestContextExtensions.Export(_deepCopier), }; - messagingTrace.OnCreateMessage(message); + _messagingTrace.OnCreateMessage(message); return message; } + private CorrelationId GetNextCorrelationId() + { + // To avoid cross-thread coordination, combine a thread-local counter with the managed thread id. The values are XOR'd together with a + // 64-bit nonce. Rotating the thread id reduces the chance of collision further by putting the significant bits at the high end, where + // they are less likely to collide with the per-thread counter, which could become relevant if the counter exceeded 2^32. + var managedThreadId = Environment.CurrentManagedThreadId; + var tid = (ulong)(managedThreadId << 16 | managedThreadId >> 16) << 32; + var id = _seed ^ tid ^ ++_nextId; + return new CorrelationId(unchecked((long)id)); + } + public Message CreateResponseMessage(Message request) { var response = new Message @@ -52,16 +75,16 @@ public Message CreateResponseMessage(Message request) SendingGrain = request.TargetGrain, CacheInvalidationHeader = request.CacheInvalidationHeader, TimeToLive = request.TimeToLive, - RequestContextData = RequestContextExtensions.Export(this.deepCopier), + RequestContextData = RequestContextExtensions.Export(_deepCopier), }; - messagingTrace.OnCreateMessage(response); + _messagingTrace.OnCreateMessage(response); return response; } public Message CreateRejectionResponse(Message request, Message.RejectionTypes type, string info, Exception ex = null) { - var response = this.CreateResponseMessage(request); + var response = CreateResponseMessage(request); response.Result = Message.ResponseTypes.Rejection; response.BodyObject = new RejectionResponse { @@ -69,8 +92,8 @@ public Message CreateRejectionResponse(Message request, Message.RejectionTypes t RejectionInfo = info, Exception = ex, }; - if (this.logger.IsEnabled(LogLevel.Debug)) - this.logger.LogDebug( + if (_logger.IsEnabled(LogLevel.Debug)) + _logger.LogDebug( ex, "Creating {RejectionType} rejection with info '{Info}' at:" + Environment.NewLine + "{StackTrace}", type, @@ -81,11 +104,11 @@ public Message CreateRejectionResponse(Message request, Message.RejectionTypes t internal Message CreateDiagnosticResponseMessage(Message request, bool isExecuting, bool isWaiting, List diagnostics) { - var response = this.CreateResponseMessage(request); + var response = CreateResponseMessage(request); response.Result = Message.ResponseTypes.Status; response.BodyObject = new StatusResponse(isExecuting, isWaiting, diagnostics); - if (this.logger.IsEnabled(LogLevel.Debug)) this.logger.LogDebug("Creating {RequestMessage} status update with diagnostics {Diagnostics}", request, diagnostics); + if (_logger.IsEnabled(LogLevel.Debug)) _logger.LogDebug("Creating {RequestMessage} status update with diagnostics {Diagnostics}", request, diagnostics); return response; }