From 5dec1c60f916b9b3d456287a6f8ea3f2029fd9b7 Mon Sep 17 00:00:00 2001 From: Lala Sabathil Date: Thu, 2 May 2024 23:10:36 +0200 Subject: [PATCH] fix: bleh --- .../AuthorizationCodeEventWaiter.cs | 2 +- .../ExtensionMethods.cs | 10 +++++ .../OAuth2WebConfiguration.cs | 38 +++++++++++-------- .../OAuth2WebExtension.cs | 27 +++++++++---- 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/DisCatSharp.Extensions.OAuth2Web/EventHandling/AuthorizationCodeEventWaiter.cs b/DisCatSharp.Extensions.OAuth2Web/EventHandling/AuthorizationCodeEventWaiter.cs index 074f63a..04fa7ec 100644 --- a/DisCatSharp.Extensions.OAuth2Web/EventHandling/AuthorizationCodeEventWaiter.cs +++ b/DisCatSharp.Extensions.OAuth2Web/EventHandling/AuthorizationCodeEventWaiter.cs @@ -102,7 +102,7 @@ public async Task> CollectMatc this._collectRequests.TryRemove(request); } - return request.Collected.ToArray(); + return [.. request.Collected]; } /// diff --git a/DisCatSharp.Extensions.OAuth2Web/ExtensionMethods.cs b/DisCatSharp.Extensions.OAuth2Web/ExtensionMethods.cs index edd8732..bf26629 100644 --- a/DisCatSharp.Extensions.OAuth2Web/ExtensionMethods.cs +++ b/DisCatSharp.Extensions.OAuth2Web/ExtensionMethods.cs @@ -137,6 +137,16 @@ public static async Task StopAsync(this IReadOnlyDictionary extensions, DiscordShardedClient client) => extensions.Values.Select(extension => extension.Configuration.RedirectUri).All(client.CurrentApplication.RedirectUris.Contains); + /// + /// Checks if the redirect uri is set for the application in the developer portal. + /// Use this function after you've executed . + /// + /// The extensions. + /// The . + /// Whether the required redirect uris is set. + public static bool HasRequiredRedirectUriSet(this OAuth2WebExtension extensions, DiscordClient client) + => client.CurrentApplication.RedirectUris.Contains(extensions.Configuration.RedirectUri); + /// /// Gets the required redirect uris for the developer portal. /// diff --git a/DisCatSharp.Extensions.OAuth2Web/OAuth2WebConfiguration.cs b/DisCatSharp.Extensions.OAuth2Web/OAuth2WebConfiguration.cs index 5e47ae2..e653af5 100644 --- a/DisCatSharp.Extensions.OAuth2Web/OAuth2WebConfiguration.cs +++ b/DisCatSharp.Extensions.OAuth2Web/OAuth2WebConfiguration.cs @@ -118,23 +118,30 @@ public sealed class OAuth2WebConfiguration /// public string? HtmlOutputException { internal get; init; } = null; - /// - /// Sets the minimum logging level for messages. - /// Defaults to . - /// - public LogLevel MinimumLogLevel { internal get; set; } = LogLevel.Information; + /// + /// Sets the minimum logging level for messages. + /// Defaults to . + /// + public LogLevel MinimumLogLevel { internal get; set; } = LogLevel.Information; - /// - /// Allows you to overwrite the time format used by the internal debug logger. - /// Only applicable when is set left at default value. Defaults to ISO 8601-like format. - /// - public string LogTimestampFormat { internal get; set; } = "yyyy-MM-dd HH:mm:ss zzz"; + /// + /// Allows you to overwrite the time format used by the internal debug logger. + /// Only applicable when is set left at default value. Defaults to ISO 8601-like format. + /// + public string LogTimestampFormat { internal get; set; } = "yyyy-MM-dd HH:mm:ss zzz"; - /// - /// Sets the proxy to use for HTTP connections to Discord. - /// Defaults to . - /// - public IWebProxy? Proxy { internal get; set; } = null; + /// + /// Sets the proxy to use for HTTP connections to Discord. + /// Defaults to . + /// + public IWebProxy? Proxy { internal get; set; } = null; + + /// + /// Sets the logger implementation to use. + /// To create your own logger, implement the instance. + /// Defaults to built-in implementation. + /// + public ILoggerFactory? LoggerFactory { internal get; set; } = null; /// /// Creates a new instance of . @@ -174,5 +181,6 @@ public OAuth2WebConfiguration(OAuth2WebConfiguration other) this.Proxy = other.Proxy; this.LogTimestampFormat = other.LogTimestampFormat; this.MinimumLogLevel = other.MinimumLogLevel; + this.LoggerFactory = other.LoggerFactory; } } diff --git a/DisCatSharp.Extensions.OAuth2Web/OAuth2WebExtension.cs b/DisCatSharp.Extensions.OAuth2Web/OAuth2WebExtension.cs index f59d9ee..34fed08 100644 --- a/DisCatSharp.Extensions.OAuth2Web/OAuth2WebExtension.cs +++ b/DisCatSharp.Extensions.OAuth2Web/OAuth2WebExtension.cs @@ -41,6 +41,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Extensions; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; namespace DisCatSharp.Extensions.OAuth2Web; @@ -49,6 +50,11 @@ namespace DisCatSharp.Extensions.OAuth2Web; /// public sealed class OAuth2WebExtension : BaseExtension { + /// + /// Gets the logger for this extension. + /// + public ILogger Logger { get; private set; } + /// /// Gets the OAuth2 Web configuration. /// @@ -67,8 +73,7 @@ public sealed class OAuth2WebExtension : BaseExtension /// /// Gets the service provider this OAuth2 Web module was configured with. /// - public IServiceProvider ServiceProvider - => this.Configuration.ServiceProvider; + public IServiceProvider ServiceProvider { get; private set; } /// /// Gets the authorization code event waiter. @@ -149,7 +154,7 @@ internal OAuth2WebExtension(OAuth2WebConfiguration configuration) // , DiscordCl { this.Configuration = configuration; - this.OAuth2Client = new(this.Configuration.ClientId, this.Configuration.ClientSecret, this.Configuration.RedirectUri, this.ServiceProvider, this.Configuration.Proxy, null, true, null, this.Configuration.MinimumLogLevel, this.Configuration.LogTimestampFormat); // , discordClient: discordClient); + this.OAuth2Client = new(this.Configuration.ClientId, this.Configuration.ClientSecret, this.Configuration.RedirectUri, this.ServiceProvider, this.Configuration.Proxy, default, default, this.Configuration.LoggerFactory, this.Configuration.MinimumLogLevel, this.Configuration.LogTimestampFormat); // , discordClient: discordClient); this._authorizationCodeReceived = new("OAUTH2_AUTH_CODE_RECEIVED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler); this._authorizationCodeExchanged = new("OAUTH2_AUTH_CODE_EXCHANGED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler); @@ -157,6 +162,10 @@ internal OAuth2WebExtension(OAuth2WebConfiguration configuration) // , DiscordCl this._accessTokenRevoked = new("OAUTH2_ACCESS_TOKEN_REVOKED", TimeSpan.Zero, this.OAuth2Client.EventErrorHandler); this._authorizationCodeWaiter = new(this, this.OAuth2Client); + this.AuthorizationCodeExchanged += this.OnAuthorizationCodeExchangedAsync; + this.AccessTokenRefreshed += this.OnAccessTokenRefreshedAsync; + this.AccessTokenRevoked += this.OnAccessTokenRevokedAsync; + var builder = WebApplication.CreateBuilder(); builder.Services.AddRouting(); @@ -172,12 +181,8 @@ internal OAuth2WebExtension(OAuth2WebConfiguration configuration) // , DiscordCl this.WEB_APP.UseAuthorization(); - this.WEB_APP.MapGet("/oauth/{shard}", this.HandleOAuth2Async); + this.WEB_APP.MapGet("/oauth/{shard}/", this.HandleOAuth2Async); this.WEB_APP.MapGet("/oauth/", this.HandleOAuth2Async); - - this.AuthorizationCodeExchanged += this.OnAuthorizationCodeExchangedAsync; - this.AccessTokenRefreshed += this.OnAccessTokenRefreshedAsync; - this.AccessTokenRevoked += this.OnAccessTokenRevokedAsync; } /// @@ -353,6 +358,9 @@ protected internal override void Setup(DiscordClient client) this.Repository = "DisCatSharp.Extensions"; this.PackageId = "DisCatSharp.Extensions.OAuth2Web"; + + this.Logger = (this.Configuration.LoggerFactory ?? this.Client.Configuration.LoggerFactory).CreateLogger(); + this.ServiceProvider = this.Configuration.ServiceProvider; } /// @@ -491,7 +499,10 @@ await this._authorizationCodeReceived.InvokeAsync(this.OAuth2Client, { var stateUserId = ulong.Parse(this.OAuth2Client.ReadSecureState(state).Split("::")[1]); if (stateUserId != info.User?.Id) + { + this.Logger.LogCritical("OAuth2Web::SecurityException - Received user id {receivedUserId} does not matches authorized user id {authorizedUserId} or authorized is null.", stateUserId, info.User?.Id); throw new SecurityException("State mismatch"); + } } var targetPending = this.OAuth2RequestUrls.First(u => this.OAuth2Client.ValidateState(new(u), requestUrl, this.Configuration.SecureStates));